RNN-based seq2seq architectures process input elements sequentially: the output up to the previous element is used to process the current element. This is slow, can’t be parallelised, suffers from vanishing gradients and struggles to keep track of context in long sequences.
Attention, in the form of transformers, addresses all of these problems.
Attention
Notice that in the figure above, the first element of the sequence influences the final encoded feature vector only through a chain of 3 arrows, whereas with attention it influences the encoded feature vector directly.
Attention is just a way to look at the entire sequence at once, irrespective of which position in the sequence is currently being encoded or decoded.
It was introduced so that seq2seq architectures would not have to rely on hacks like a single fixed-size memory vector, and could instead use attention to look up the original sequence as needed. Transformers showed that attention by itself, without any recurrence, is capable of learning from sequential data and outperforming recurrent seq2seq architectures on it.
How do we look up the entire sequence? You can agree that some parts of a sentence are more important than others when predicting certain words, say in translation. Let’s say the word to be decoded is a verb; then you might need to look at (pay attention to) the verbs used in the input sentence to realise which tense is needed.
So what we are after is some sort of weighted average of the input embeddings, with the weights tuned to the various needs of the network.
The general framework of attention revolves around three things: a query, keys and values (the keys and values come in pairs).
Query — the current state, position or time step of the network
Values — things the network is going to pay attention to
Keys — used in determining how much attention to pay to its corresponding value
Given a query and a list of key-value pairs, we determine how much attention each value gets by computing a similarity between the query and each of the keys. These similarities are then used as the weights of a weighted average of the values:
$$\text{Attention}(q, K, V) = \sum_i \frac{\exp(a(q, k_i))}{\sum_j \exp(a(q, k_j))}\, v_i$$
Here a could be any function that outputs a scalar similarity (compatibility) score between the query and a key vector.
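A minimal pure-Python sketch of this general framework, assuming the similarity scores are normalised with a softmax so the weights sum to 1. The names `attention`, `dot` and `softmax` are illustrative, not from any library, and the dot product is just one possible choice for the compatibility function a.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot-product compatibility: one possible choice for the function a."""
    return sum(x * y for x, y in zip(u, v))


def softmax(scores: List[float]) -> List[float]:
    """Turn raw scores into positive weights that sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(query: Vector, keys: List[Vector], values: List[Vector],
              a: Callable[[Vector, Vector], float] = dot) -> List[float]:
    """Weighted average of the values, weighted by how well each key matches the query."""
    weights = softmax([a(query, k) for k in keys])
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


# The query matches the first key much better, so the output is close to the first value.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
print(attention([5.0, 0.0], keys, values))
```

Any other scoring function with the same signature, for example a negative squared distance, can be passed as `a`; the rest of the computation stays the same.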
Scaled Dot Product Attention
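It is attention with a specific compatibility function, of the form
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where the rows of Q, K and V are the queries, keys and values, and d_k is the dimension of the keys.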
It uses a simple dot product to find similarity.
Since attention is a weighted average of the values, we need the attention scores to form a probability distribution, i.e. to be non-negative and sum up to 1, which is why we apply softmax over them.
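The division by the square root of the key dimension d_k is done so that, in higher dimensions, the magnitude of the dot product doesn’t blow up; very large scores push the softmax into regions with extremely small gradients, which causes backprop issues.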
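A sketch of the matrix form softmax(QK^T / √d_k) V in plain Python, using lists of rows instead of NumPy to stay dependency-free. The helper names (`matmul`, `transpose`, `softmax_rows`, `scaled_dot_product_attention`) are hypothetical. Each row of `q` is one query, and each row of the softmaxed weight matrix sums to 1.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: rows of a against columns of b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax_rows(m: Matrix) -> Matrix:
    """Softmax applied independently to each row."""
    out = []
    for row in m:
        top = max(row)  # numerical stability
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V, one query per row of q."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = softmax_rows(scaled)
    return matmul(weights, v)


q = [[1.0, 0.0, 1.0, 0.0]]
k = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(scaled_dot_product_attention(q, k, v))
```

Here the raw scores are [2, 0]; dividing by √4 = 2 gives [1, 0] before the softmax, which keeps the weights from saturating as d_k grows.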
Self Attention
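It is the special case where the keys and values come from the same sequence (in the transformer’s encoder, the queries come from that same sequence too). We compute a similarity score for each embedding with the query, using the embedding itself as the key.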
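If attention over queries Q, keys K and values V is written as
$$\text{Attention}(Q, K, V)$$
then self attention over an input sequence X is
$$\text{Attention}(X, X, X)$$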
For example, in translation tasks the keys and the values are both the embeddings of the input words. Since we just want to know which words to pay attention to, it makes sense to use the word embeddings themselves as keys to compute their own attention scores.
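A toy sketch of self attention, assuming no learned projections (the transformer itself first multiplies the sequence by learned weight matrices) and taking the queries from the same sequence as the keys and values, as the transformer encoder does. The function name and the 2-d word vectors are made up for illustration.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def self_attention(x: Matrix) -> Tuple[Matrix, Matrix]:
    """Scaled dot-product attention with queries = keys = values = rows of x."""
    d = len(x[0])
    weights, outputs = [], []
    for q in x:  # every element queries the whole sequence, itself included
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in x]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        w = [e / total for e in exps]
        weights.append(w)
        outputs.append([sum(wj * xj[c] for wj, xj in zip(w, x)) for c in range(d)])
    return weights, outputs


# Hypothetical toy embeddings: "run" and "ran" point in similar directions.
words = ["run", "ran", "blue"]
x = [[1.0, 0.9], [0.9, 1.0], [-1.0, 0.2]]
weights, _ = self_attention(x)
for word, row in zip(words, weights):
    print(word, [round(v, 2) for v in row])
```

Because the scores come from the embeddings themselves, "run" ends up attending far more to "ran" than to "blue".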
Multi Head Attention
Why only one attention module? Why not have several, so that for the same query each module can learn to pay attention to particular parts of the sequence based on different needs?
A naive example could be one attention head paying attention to time-related words (verbs and their tenses), another paying attention to gendered terms, and so on.
To achieve this, the query, keys and values are linearly projected by h different sets of learned weights. Attention is then computed on each of these h sets of queries, keys and values, giving h different attention outputs.
These h outputs of the attention heads are concatenated and passed through one more learned linear projection to form the output of the multi head attention block.
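A sketch of a multi head block in plain Python. Assumptions: it is used as self attention (queries, keys and values all come from `x`), each head projects to d_k = d_model / h dimensions as in the original transformer, and random Gaussian matrices stand in for weights that would be learned. `w_o` is the output projection applied after concatenation; all function names are hypothetical.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention, one query per row."""
    d_k = len(k[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        w = [e / total for e in exps]
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
    return out


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    """Stand-in for a learned weight matrix."""
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def multi_head_attention(x: Matrix, h: int, rng: random.Random) -> Matrix:
    d_model = len(x[0])
    d_k = d_model // h  # each head works in a smaller subspace
    heads = []
    for _ in range(h):
        # A separate set of projections per head.
        w_q, w_k, w_v = (random_matrix(d_model, d_k, rng) for _ in range(3))
        heads.append(attention(matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)))
    # Concatenate the h head outputs position by position...
    concat = [[val for head in heads for val in head[i]] for i in range(len(x))]
    # ...then mix them with the output projection back to d_model.
    w_o = random_matrix(h * d_k, d_model, rng)
    return matmul(concat, w_o)


rng = random.Random(0)
x = [[float(i + j) / 10 for j in range(8)] for i in range(5)]
print(len(multi_head_attention(x, h=4, rng=rng)[0]))
```

Splitting d_model across the heads keeps the total cost close to that of a single full-width head while still giving each head its own projections.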
Masked Attention
During generation, the decoder in the transformer still operates sequentially: it uses the previous word predictions to predict the current word. This means the decoder must not look to the right of its current position when computing attention. At inference time there is nothing there yet, and during training, when the whole target sentence is fed in at once, looking right would let it peek at the answer. So we set the attention scores for all positions to the right of the decoder’s current position to -∞ before the softmax, which gives those “garbage” positions 0 attention weight.
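A minimal sketch of the causal mask, assuming the raw scores (QK^T / √d_k) have already been computed; `causal_mask` and `masked_softmax` are illustrative names. Adding -∞ to a score makes its exponential exactly 0, so each row’s weights still sum to 1 over the visible positions.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_mask(n: int) -> Matrix:
    """0 where position j is visible from position i (j <= i), -inf to its right."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]


def masked_softmax(scores: Matrix) -> Matrix:
    """Row-wise softmax after hiding every position to the right of the current one."""
    mask = causal_mask(len(scores))
    out = []
    for row, mask_row in zip(scores, mask):
        masked = [s + m for s, m in zip(row, mask_row)]
        top = max(masked)  # finite, because the diagonal is never masked
        exps = [math.exp(s - top) for s in masked]  # math.exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


print(masked_softmax([[1.0, 2.0, 3.0], [0.5, 0.5, 9.0], [1.0, 1.0, 1.0]]))
```

Even a very large score to the right (the 9.0 above) gets zero weight, which is exactly what stops the decoder from peeking ahead.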
Other components
Input/Output Embedding
It is simply the vector representation of whatever your input/output sequence is. If it’s a sentence, each word is converted to a vector using something like GloVe (or a learned embedding table, as in the original transformer). If it’s an image, the image is split into patches and each patch is flattened into a vector in row- or column-major order. In essence, you need a vector representation for each element of your sequences since, at the end of the day, neural networks need numbers.
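For the image case, a sketch of turning a single-channel H × W image into one flattened vector per p × p patch, in row-major order. It assumes the image divides evenly into patches, and `image_to_patch_vectors` is a hypothetical name. Models such as the Vision Transformer then pass each flattened patch through a learned linear layer, which is omitted here.

```python
from typing import List

Image = List[List[float]]


def image_to_patch_vectors(image: Image, p: int) -> List[List[float]]:
    """Split an H x W image into p x p patches, each flattened in row-major order."""
    h, w = len(image), len(image[0])
    if h % p or w % p:
        raise ValueError("this sketch assumes the image tiles exactly into p x p patches")
    patches = []
    for top in range(0, h, p):  # patches are visited row by row as well
        for left in range(0, w, p):
            patches.append([image[top + r][left + c] for r in range(p) for c in range(p)])
    return patches


img = [[float(4 * r + c) for c in range(4)] for r in range(4)]
print(image_to_patch_vectors(img, 2))
```

The 4 × 4 image above becomes a sequence of four 4-dimensional vectors, one per 2 × 2 patch.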
Positional Encoder
Transformers process all elements in parallel, unlike RNNs. How do they retain the order of the sequence? Through the positional encoder. It adds a position-dependent vector P directly to the embedding of each element of the sequence.
Why addition instead of concatenation? Say we concatenated a natural-number index to the front of each embedding: with really long sequences we might reach INT_MAX, or end up with really large numbers in the forward pass, or the network might mistake the index for a feature with a meaningful magnitude and have trouble during backprop.
The other problem with concatenation is the increased input dimension. Remember the curse of dimensionality.
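How is P determined? Let’s say your input has size n × d_model (n elements, each embedded in d_model dimensions). Then for the element at position pos, dimensions 2i and 2i + 1 are
$$P(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$P(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$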
That definitely looks weird, but it has 2 advantages:
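- the use of both sin and cos means that relative positions can be reached through a simple linear transformation, allowing the network to learn to exploit them if the need arises. For each frequency ω_i = 1/10000^{2i/d_model}, the network can jump from position t to t + dt by learning the weight matrix below, which depends only on dt and not on t:
$$\begin{bmatrix} \sin(\omega_i (t + dt)) \\ \cos(\omega_i (t + dt)) \end{bmatrix} = \begin{bmatrix} \cos(\omega_i\, dt) & \sin(\omega_i\, dt) \\ -\sin(\omega_i\, dt) & \cos(\omega_i\, dt) \end{bmatrix} \begin{bmatrix} \sin(\omega_i t) \\ \cos(\omega_i t) \end{bmatrix}$$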
- there is no trouble handling long sequences or very high-dimensional embeddings, because sin and cos are bounded in [-1, 1], so the values never blow up or overflow the way a raw index would
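A plain-Python sketch of the sinusoidal encoding and of both advantages; `positional_encoding`, `add_positions` and `shift_pair` are hypothetical helper names. `shift_pair` applies the 2 × 2 sin/cos matrix for the frequency of one column pair, and because that matrix depends only on dt, the same matrix moves any position t to t + dt.

```python
import math
from typing import List

Matrix = List[List[float]]


def positional_encoding(n: int, d_model: int) -> Matrix:
    """P[pos][2i] = sin(pos / 10000^(2i/d_model)), P[pos][2i+1] = cos(same angle)."""
    pe = [[0.0] * d_model for _ in range(n)]
    for pos in range(n):
        for col in range(0, d_model, 2):  # col plays the role of 2i
            angle = pos / 10000 ** (col / d_model)
            pe[pos][col] = math.sin(angle)
            if col + 1 < d_model:
                pe[pos][col + 1] = math.cos(angle)
    return pe


def add_positions(embeddings: Matrix) -> Matrix:
    """Add (not concatenate) P to the embeddings, so the dimension is unchanged."""
    pe = positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(e_row, p_row)] for e_row, p_row in zip(embeddings, pe)]


def shift_pair(pe_row: List[float], col: int, dt: int, d_model: int) -> List[float]:
    """Map the (sin, cos) pair at columns col, col + 1 from position t to t + dt.

    The 2x2 matrix depends only on dt and the pair's frequency, not on t.
    """
    omega = 1 / 10000 ** (col / d_model)
    c, s = math.cos(omega * dt), math.sin(omega * dt)
    sin_t, cos_t = pe_row[col], pe_row[col + 1]
    return [c * sin_t + s * cos_t, -s * sin_t + c * cos_t]


pe = positional_encoding(200, 8)
print(shift_pair(pe[17], 2, 5, 8), pe[22][2:4])  # equal up to rounding
```

Every entry stays in [-1, 1] however long the sequence, and comparing `shift_pair(pe[t], ...)` against row t + dt for a few values of t is a quick way to check that the linear relationship holds exactly, up to floating-point rounding.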
A good analogue as to what it’s trying to achieve can be seen through binary numbers.
Writing decimal numbers in binary, we can see that the least significant bit fluctuates the most and the most significant bit fluctuates the least, i.e. the fluctuation decreases as we move from the least to the most significant bit.
If you look at the formula for P, you’ll realise that the frequency does indeed decrease for columns from left to right (from lower to higher dimensions). The sinusoidal position encoder is, in effect, emulating binary numbers with smooth waves, so each row (position) gets an approximately unique code, just as each binary number represents a unique decimal number.
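A small sketch of the analogy: `flips_per_bit` counts how often each bit changes while counting up, and `pe_frequencies` lists the frequency of each sin/cos pair of P. Both are illustrative helpers. For 3 bits the counts are 7, 3 and 1 from the least to the most significant bit, and the frequencies likewise fall from the first column pair to the last.

```python
from typing import List


def flips_per_bit(n_bits: int) -> List[int]:
    """How often each bit (least significant first) changes when counting 0 .. 2**n_bits - 1."""
    flips = []
    for b in range(n_bits):
        bits = [(x >> b) & 1 for x in range(2 ** n_bits)]
        flips.append(sum(1 for prev, cur in zip(bits, bits[1:]) if prev != cur))
    return flips


def pe_frequencies(d_model: int) -> List[float]:
    """Angular frequency 1 / 10000^(2i/d_model) of each sin/cos column pair, left to right."""
    return [1 / 10000 ** (col / d_model) for col in range(0, d_model, 2)]


print(flips_per_bit(3))
print(pe_frequencies(8))
```

The difference is that binary bits jump between 0 and 1, while each column of P oscillates smoothly, which is friendlier to gradient-based learning.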