date-created: 2024-07-09 05:32:53 date-modified: 2024-07-09 05:35:19

What exactly are keys, queries, and values in attention mechanisms?

anchored to 116.00_anchor_machine_learning requires and proceeds from 116.16_self_supervised_learning and 116.18_transformer_model

First Answer:

I was also puzzled by the keys, queries, and values in attention mechanisms for a while. After searching the web and digesting the relevant information, I now have a clear picture of how the keys, queries, and values work, and why they work!

Let’s see how they work, followed by why they work.

Attention to replace context vector

In a seq2seq model, we encode the input sequence into a context vector and then feed this context vector to the decoder to produce the expected output.

However, when the input sequence becomes long, relying on a single context vector becomes less effective. We need the information from all the hidden states of the input sequence (encoder) for better decoding; this is the attention mechanism.

One way to utilize the input hidden states is illustrated in the figure at https://towardsdatascience.com/attn-illustrated-attention-5ec4ad276ee3 (image not reproduced here).

In other words, in this attention mechanism, the context vector is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key (this is a slightly modified sentence from Attention Is All You Need).

Here, the query comes from the decoder hidden state, while the keys and values come from the encoder hidden states (key and value are the same in this figure). The score is the compatibility between the query and a key, which can be their dot product (or another form of compatibility). The scores then go through the softmax function to yield a set of weights whose sum equals 1. Each weight multiplies its corresponding value, and the weighted values are summed to yield the context vector, which utilizes all the input hidden states.

Note that if we manually set the weight of the last input to 1 and all the preceding weights to 0, we reduce the attention mechanism to the original seq2seq context vector mechanism. That is, there is no attention to the earlier input encoder states.
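
As a minimal sketch of this weighting (illustrative code, not from the original answer; the names `decoder_state` and `encoder_states` are my own), the context vector can be computed like this:

```
import torch
import torch.nn.functional as F

def attention_context(decoder_state, encoder_states):
    """Context vector as a softmax-weighted sum of encoder states.

    decoder_state:  (d,)   current decoder hidden state (the query)
    encoder_states: (T, d) all encoder hidden states (keys and values here)
    """
    scores = encoder_states @ decoder_state   # (T,) dot-product compatibility scores
    weights = F.softmax(scores, dim=0)        # (T,) attention weights summing to 1
    return weights @ encoder_states           # (d,) weighted sum = context vector
```

Setting the last weight to 1 and the rest to 0 in this sketch would recover the plain seq2seq context vector described above.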

Self-Attention uses Q, K, V all from the input

Now, let’s consider the self-attention mechanism as shown in the figure below:

(Figure: self-attention mechanism; image and source not reproduced here.)

The difference from the figure above is that the queries, keys, and values are transformations of the corresponding input state vectors. Everything else remains the same.

Note that we could still use the original encoder state vectors as the queries, keys, and values. So why do we need the transformation? The transformation is simply a matrix multiplication:

Query = I x W(Q)

Key = I x W(K)

Value = I x W(V)

where I is the input (encoder) state vector, and W(Q), W(K), and W(V) are the corresponding matrices to transform the I vector into the Query, Key, Value vectors.
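
As a sketch of these projections (assuming they are implemented as bias-free `torch.nn.Linear` layers, which is one common choice; the dimensions are illustrative):

```
import torch
import torch.nn as nn

d_model, d_k = 512, 64                     # illustrative dimensions

W_Q = nn.Linear(d_model, d_k, bias=False)  # plays the role of W(Q)
W_K = nn.Linear(d_model, d_k, bias=False)  # plays the role of W(K)
W_V = nn.Linear(d_model, d_k, bias=False)  # plays the role of W(V)

I = torch.randn(10, d_model)               # 10 input (encoder) state vectors

Q = W_Q(I)   # queries, shape (10, d_k)
K = W_K(I)   # keys,    shape (10, d_k)
V = W_V(I)   # values,  shape (10, d_k)
```

Note that the projection also changes the dimension from d_model to d_k, which is the third benefit listed further below.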

What are the benefits of this matrix multiplication (vector transformation)?

The obvious reason is that, if we do not transform the input vectors, the dot product used to compute the weight for each input's value will always yield the maximum score for the input token itself. In other words, when we compute the n attention weights (j = 1, 2, …, n) for the input token at position i, the weight at j == i is always larger than the weights at the other positions (j ≠ i). This may not be the desired behavior. For example, a pronoun token should attend to its referent, not to the pronoun token itself.

Another less obvious but important reason is that the transformation may yield better representations for the Query, Key, and Value. Recall the effect of Singular Value Decomposition (SVD), as in the following figure:

(Figure: application of SVD; image and source not reproduced here.)

By multiplying an input vector by a matrix V (from the SVD), we obtain a better representation for computing the compatibility between two vectors when those vectors are similar in the topic space, as in the example shown in the figure.
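
A hedged sketch of this idea (illustrative only; an SVD of a random matrix stands in for the term-document example in the figure):

```
import torch

A = torch.randn(100, 5)                  # e.g. a term-by-feature matrix
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
V = Vh.T                                 # columns span the latent "topic" space

x1, x2 = A[0], A[1]                      # two raw input vectors
z1 = x1 @ V[:, :2]                       # 2-dimensional latent representation
z2 = x2 @ V[:, :2]
similarity = torch.dot(z1, z2)           # compatibility computed in the latent space
```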

And these matrices for transformation can be learned in a neural network!

In short, by multiplying the input vector with a matrix, we get:

  1. a higher chance for each input token to attend to other tokens in the input sequence, instead of only to itself

  2. possibly better (latent) representations of the input vector

  3. conversion of the input vector into a space with a desired dimension, say from dimension 5 to 2, or from n to m, etc., which is practically useful

I hope this helps you understand the queries, keys, and values in the (self-)attention mechanism of deep neural networks.

Second Answer

author: SeanSean source: stats.stackexchange.com

Big picture

Basically, the Transformer builds a graph network in which each node is a position-encoded token in a sequence.

During training:

  1. Take unconnected tokens as a sequence (e.g. a sentence).
  2. Wire connections among the tokens by looking at their co-occurrences across billions of sequences.

What roles do Q and K play in building this graph network? You could be Q in your society, trying to build the social graph with other people. Each other person is a K, and you build connections with them. Eventually, after billions of interactions with other people, the connections become context dependent, even with the same person K.

You may be senior to a person K at work, but K may be your martial arts master. Just as you remember such connections/relations with others depending on the context, a Transformer model (trained on a specific dataset) figures out such context-dependent connections from Q to K (or from you to other people), which is the memory it offers.

As the layers go higher, your individual identity as K gets blended into larger units through the bag-of-words (BoW) aggregation described later.

By contrast, a Markov Chain (MC) has only one static connection from Q to K, namely P(K|Q), because an MC does not have the context memory that a Transformer model offers.

First, understand Q and K

First, focus on the objective of the first MatMul in scaled dot-product attention, which uses Q and K.


Intuition on what is Attention

Consider the sentence “jane visits africa”.

When your eyes see jane, your brain looks for the most related word in the rest of the sentence to understand what jane is about (query). Your brain focuses or attends to the word visit (key).

This process happens for each word in the sentence as your eyes progress through the sentence.

First MatMul as Inquiry System using Vector Similarity

The first MatMul implements an inquiry system or question-answer system that imitates this brain function, using vector similarity calculation. Watch CS480/680 Lecture 19: Attention and Transformer Networks by Professor Pascal Poupart to understand this further.

Think of attention as essentially an approximation of the SELECT you would do in a database.
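
A hedged illustration of the difference (my own toy example): a database SELECT returns the one value whose key matches exactly, while attention returns a weighted blend of all values, weighted by query-key similarity.

```
import torch
import torch.nn.functional as F

# Hard lookup: an exact key match returns exactly one value.
table = {"jane": "person", "visits": "action", "africa": "place"}
hard_result = table["jane"]

# Soft lookup (attention): every value contributes, weighted by similarity.
q = torch.randn(64)                  # query vector for one word
K = torch.randn(3, 64)               # one key vector per word in the sentence
V = torch.randn(3, 64)               # one value vector per word in the sentence
weights = F.softmax(K @ q, dim=0)    # similarity of q to every key, normalized
soft_result = weights @ V            # weighted blend of all values
```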

Think of the MatMul as an inquiry system that processes the inquiry: “For the word q that your eyes see in the given sentence, what is the most related word k in the sentence to understand what q is about?” The inquiry system provides the answer as a probability.

| q      | k      | probability |
| ------ | ------ | ----------- |
| jane   | visit  | 0.94        |
| visit  | africa | 0.86        |
| africa | visit  | 0.76        |

Note that the softmax is used to normalize values into probabilities so that their sum becomes 1.0.


There are multiple ways to calculate the similarity between vectors, such as cosine similarity. Transformer attention uses a simple dot product.

Where are Q and K from

The transformer encoder training builds the weight parameter matrices W_Q and W_K so that Q and K form the inquiry system that answers the inquiry “What is k for the word q?”.

The calculation goes as below, where X is a sequence of position-encoded word embedding vectors that represents an input sentence.

  1. Pick up a word vector (position-encoded) from the input sentence sequence and transform it into the vector space Q. This becomes the query.
    Q = X · W_Q^T

  2. Pick all the words in the sentence and transform them into the vector space K. They become the keys, and each of them is used as a key.
    K = X · W_K^T

  3. For each (q, k) pair, their relation strength is calculated using the dot product.
    q_to_k_similarity_scores = matmul(Q, K^T)

  4. The weight matrices W_Q and W_K are trained via backpropagation during the Transformer training.

We first need to understand this part, which involves Q and K, before moving on to V.
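
A compact sketch of steps 1–3 (illustrative shapes; W_Q and W_K are shown as plain weight matrices applied with the transpose convention of the equations above):

```
import torch

T, d_model, d_k = 3, 512, 64          # illustrative: 3 words, model dim, key dim
X = torch.randn(T, d_model)           # position-encoded word embeddings
W_Q = torch.randn(d_k, d_model)       # learned via backpropagation during training
W_K = torch.randn(d_k, d_model)

Q = X @ W_Q.T                         # step 1: queries, shape (T, d_k)
K = X @ W_K.T                         # step 2: keys,    shape (T, d_k)
q_to_k_similarity_scores = Q @ K.T    # step 3: (T, T) dot-product scores
```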


Borrowing the code from Let’s build GPT: from scratch, in code, spelled out. by Andrej Karpathy.

# B: Batch size
# T: Sequence length or max token size e.g. 512 for BERT. 'T' because of 'Time steps = Sequence length'
# D: Dimensions of the model embedding vector, which is d_model in the paper.
# H or h: Number of multi attention heads in Multi-head attention


from torch import Tensor


def calculate_dot_product_similarities(
        query: Tensor,
        key: Tensor,
) -> Tensor:
    """
    Calculate similarity scores between queries and keys using dot product.

    Args:
        query: embedding vector of query of shape (B, h, T, d_k)
        key: embedding vector of key of shape (B, h, T, d_k)

    Returns: Similarities (closeness) between q and k of shape (B, h, T, T) where
        last (T, T) represents relations between all query elements in T sequence
        against all key elements in T sequence. If T is people in an organization,
        (T,T) represents all (cartesian product) social connections among them.
        The relation considers d_k number of features.
    """
    # --------------------------------------------------------------------------------
    # Relationship between k and q as the first MatMul using dot product similarity:
    # (B, h, T, d_k) @ (B, h, d_k, T) ---> (B, h, T, T)
    # --------------------------------------------------------------------------------
    similarities = query @ key.transpose(-2, -1)            # dot product
    return similarities                                     # shape:(B, h, T, T)

Then, understand how V is created using Q and K

Second Matmul

Self-attention then generates the embedding vector called the attention value as a bag of words (BoW), in which each word contributes in proportion to its relationship strength with q. This occurs for each q in the sentence sequence. The resulting embedding vector encodes the relations from q to all the words in the sentence.

Citing the words from Andrej Karpathy:

What is the easiest way for tokens to communicate. The easiest way is just average.

He keeps it simple for the sake of the tutorial, but the essence is BoW.


def calculate_attention_values(
        similarities,
        values
):
    """
    For every q element, create a Bag of Words that encodes the relationships with
    other elements (including itself) in T, using (q,k) relationship value as the
    strength of the relationships.

    Citation:
    > On each of these projected versions of queries, keys and values we then perform
    > the attention function in parallel, yielding d_v-dimensional output values.

    ```
    bows = []
    for row in similarities:                    # similarity matrix of shape (T,T)
        bow = sum([                             # bow:shape(d_v,)
            # each column in row is (q,k) similarity score s
            s*v for (s,v) in zip(row,values)    # s: scalar (q,k) score, v: shape (d_v,)
        ])
        bows.append(bow)                        # bows:shape(T,d_v)
    ```

    Args:
        similarities: q to k relationship strength matrix of shape (B, h, T, T)
        values: elements of sequence with length T of shape (B, h, T, d_v)

    Returns: Bag of Words for every q element of shape (B, h, T, d_v)
    """
    return similarities @ values     # (B,h,T,T) @ (B,h,T,d_v) -> (B,h,T,d_v)
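
Putting the two functions together, a hedged end-to-end usage sketch (the √d_k scaling and the softmax come from scaled dot-product attention in the paper; the simplified snippets above omit them):

```
import math
import torch

B, h, T, d_k, d_v = 2, 8, 10, 64, 64           # illustrative shapes
query = torch.randn(B, h, T, d_k)
key = torch.randn(B, h, T, d_k)
values = torch.randn(B, h, T, d_v)

similarities = calculate_dot_product_similarities(query, key)    # (B, h, T, T)
weights = torch.softmax(similarities / math.sqrt(d_k), dim=-1)   # rows sum to 1.0
attention = calculate_attention_values(weights, values)          # (B, h, T, d_v)
```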

References

There are multiple concepts that help in understanding how self-attention in the Transformer works, e.g. embeddings that group similar items in a vector space, and data retrieval that answers a query Q using a neural network and vector similarity.