Writing an LLM from scratch, part 6 -- starting to code self-attention
This is the sixth in my series of notes on Sebastian Raschka's book "Build a Large Language Model (from Scratch)". Each time I read part of it, I'm posting about what I found interesting as a way to help get things straight in my own head -- and perhaps to help anyone else that is working through it too. This post covers just one subsection of the trickiest chapter in the book -- subsection 3.3.1, "A simple self-attention mechanism without trainable weights". I feel that there's enough in there to make up a post on its own. For me, it certainly gave me one key intuition that I think is a critical part of how everything fits together.
As always, there may be errors in my understanding below -- I've cross-checked and run the whole post through Claude, ChatGPT o1, and DeepSeek r1, so I'm reasonably confident, but caveat lector :-) With all that said, let's go!
In my last two posts, I've spent a lot of time trying to get a solid, intuitive understanding of what self-attention mechanisms are for, and what they mean in an abstract sense. They're the core concept in the implementation of LLMs (in Raschka's specific sense of transformer-based language models, GPTs in particular), so this felt like an important foundation to get exactly right. This chapter is where things start to get a bit more concrete.
Let's start with a summary at the abstract level, though. Using my normal simplification that words == tokens, an LLM is a model that, given some input text, works out what the next word should be. In order to do this, it considers all of the words in the input text. Attention is a system that tells it, when looking at word x in the input, how much it should also take into account each of the other words in that input -- that is, how much attention it should pay to each other word.
So, for the sentence
The fat cat sat on the mat
...when the LLM is considering "cat", it needs to focus on the first "the" (because it's a specific cat, not "a cat" -- sorry, Slavic language speakers, that is important ;-), "sat" because that is the verb that the cat is doing, "fat" because it applies to the cat, and perhaps to a lesser extent "on", "the", and "mat". Likewise, if the LLM is looking at "sat", then it's likely that "cat" and "mat" are the important bits -- the thing doing the sitting and the thing being sat on -- with the other words playing lesser parts.
You can imagine building up for each word in the sentence a list of numbers -- one for each other token. Here's an example of what that might look like, based on my own intuition about the importance of each word to the others in that sentence:
Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat") |
---|---|---|---|---|---|---|---|
The | 1 | 0.3 | 0.75 | 0.1 | 0 | 0 | 0 |
fat | 0.2 | 1 | 0.8 | 0 | 0 | 0 | 0 |
cat | 0.6 | 0.8 | 1 | 0.7 | 0.3 | 0.2 | 0.4 |
sat | 0.1 | 0 | 0.85 | 1 | 0.3 | 0.2 | 0.75 |
on | 0 | 0.1 | 0.4 | 0.6 | 1 | 0.3 | 0.8 |
the | 0 | 0 | 0 | 0 | 0.1 | 1 | 0.75 |
mat | 0 | 0 | 0.2 | 0.8 | 0.7 | 0.6 | 1 |
To unpack that -- each row in the table is a word in the sentence, and the numbers in that row are how much attention one might want to pay to the other words when trying to understand it. I've used a number from 0 ("you can ignore it") to 1 ("this is super-important"). But you can see that I've tried to represent the relative importance that I gave above for "cat" and "sat" numerically, and then done likewise for the other words. Note that each word is most attentive to itself, which makes intuitive sense to me :-)
Now, the cool thing is that this is exactly what attention mechanisms do. The reason I (slightly ab-)used the symbol ω to represent these numbers is that it's a symbol used to represent attention scores. When considering a particular token in the input sequence, the attention mechanism assigns every other token an attention score, and those scores are then used to work out how much each of the other tokens should contribute to the processing of the one being considered.
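If it helps to picture that table as the kind of data structure we'll eventually be working with, here's a minimal sketch in PyTorch, just holding my made-up numbers from above in a matrix:

```python
import torch

# My hand-made attention scores from the table above: one row per token,
# one column for each token it might pay attention to. Purely illustrative numbers.
attn_scores = torch.tensor([
    # The   fat   cat   sat   on    the   mat
    [1.00, 0.30, 0.75, 0.10, 0.00, 0.00, 0.00],  # The
    [0.20, 1.00, 0.80, 0.00, 0.00, 0.00, 0.00],  # fat
    [0.60, 0.80, 1.00, 0.70, 0.30, 0.20, 0.40],  # cat
    [0.10, 0.00, 0.85, 1.00, 0.30, 0.20, 0.75],  # sat
    [0.00, 0.10, 0.40, 0.60, 1.00, 0.30, 0.80],  # on
    [0.00, 0.00, 0.00, 0.00, 0.10, 1.00, 0.75],  # the
    [0.00, 0.00, 0.20, 0.80, 0.70, 0.60, 1.00],  # mat
])
print(attn_scores[2])  # the row of scores for "cat"
```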
Attention is all about building up data like this so that neural networks can understand language enough to be able to predict the next token when given something to start with. So we really have two questions:
- How do we build up a system that can work out sets of attention scores across all tokens, for each token in our input sequence?
- What do we do with those lists of attention scores once we've got them?
The answer to the second of those is actually the easier one, and Raschka explains it in this section -- this was the big "aha" moment for me in this post. We want to create, for each token, something called a context vector.
So far, our LLM has received a bunch of text, and has
- Split it into tokens.
- For each token, it has generated two embeddings -- a token embedding that represents (or rather, will be trained to represent) in some manner the meaning of that token in isolation, and a position embedding, which will represent simply whether it's the first, second, third, or whatever token in the sequence (there's a rough sketch of this setup just after this list).
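For anyone who'd like a reminder of what that looks like in code terms, here's a rough sketch of the kind of setup chapter 2 built -- the sizes and token IDs below are entirely made up for illustration:

```python
import torch

vocab_size, emb_dim, seq_len = 50_000, 8, 7   # made-up sizes, just for illustration

token_embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
pos_embedding_layer = torch.nn.Embedding(seq_len, emb_dim)

# Pretend token IDs for "The fat cat sat on the mat"
token_ids = torch.tensor([5, 981, 2042, 1375, 88, 6, 3301])

tok_embeddings = token_embedding_layer(token_ids)            # one embedding per token: shape (7, 8)
pos_embeddings = pos_embedding_layer(torch.arange(seq_len))  # one embedding per position: shape (7, 8)
# These two, taken together, are what the attention mechanism receives.
```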
Now, for our token "cat" we have at this stage a token embedding that represents "cat" (and is probably similar to the embedding for "kitten", not too dissimilar to the embedding for "dog", and very different to the embedding for "rain"). The purpose of the context vector is to produce something numerical that represents "cat" in the context of all of the other words in the sentence.
So while the token embedding for "cat" just means "cat", the context vector for "cat" in this sentence also has some kind of overtones about it being a cat that is sitting, perhaps less strongly that it's a specific cat ("the" rather than "a"), and hints of it sitting on a mat. By contrast, the context vector for the first "The" would pretty much just represent that it's the definite article, though it might also have some cattishness about it, and maybe a spot of fatness. And so on for the context vectors for all of the other words. As I understand it, it's the context vector that the next-token-prediction layer will care about -- not the token embedding. Which makes sense: it doesn't care about cats in general, it cares about this specific one in this sentence.
If you've read further on in the chapter you'll have noticed I'm playing a bit fast and loose in the above, because I'm not considering causal attention. But I think that's worth doing as an aid to intuition, at least my own, at this stage.
That means that the attention part of our LLM is basically a black box that takes in token and position embeddings, and for each token spits out a context vector that represents the token's meaning in the context of this particular input sequence. (Multi-head attention, introduced later, will build on this a bit more.)
To me the context vector feels a little bit like the hidden states that were used to transfer meaning from the encoder to the decoder in those original translation systems that I was writing about in my last post in this series; perhaps that's how people got from there to where we are now. It's essentially something like an embedding for this token in this context.
So, given a set of attention scores for a particular token -- one for each token in the input -- how do we go about getting this context vector for the token we're considering?
Well, for "cat" we want something that includes sitting, some hint of the mat, and so on, and we have these attention scores that say how strongly the word "cat" is related to each other token. And each of those other tokens has an embedding that is meant to represent its meaning. So we could do a pretty simple bit of maths to get something that represents them all taken together -- we just multiply all of the token embeddings by their respective attention scores and add them together! So, to get the context vector for "cat", given this row of the attention score matrix above:
Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat") |
---|---|---|---|---|---|---|---|
cat | 0.6 | 0.8 | 1 | 0.7 | 0.3 | 0.2 | 0.4 |
We'd calculate 0.6 times the token embedding for "The", 0.8 times the token embedding for "fat", 1 times the token embedding for "cat", and so on, then add them all up. That might give us our context vector for "cat" in this sentence.
And that is pretty much what happens! That, to me, is beautifully simple.
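Here's a minimal sketch of that calculation in PyTorch -- the embeddings are just random placeholders, and the scores are my made-up "cat" row from the table:

```python
import torch

torch.manual_seed(123)
embeddings = torch.rand(7, 4)  # pretend 4-dimensional token embeddings for our seven words

# My hand-made scores relating "cat" to each word in the sentence
cat_scores = torch.tensor([0.6, 0.8, 1.0, 0.7, 0.3, 0.2, 0.4])

# Scale each token embedding by its score, then add them all up element-wise
cat_context = (cat_scores.unsqueeze(1) * embeddings).sum(dim=0)
print(cat_context.shape)  # torch.Size([4]) -- the same shape as a single token embedding
```

(That last line is just a long-winded `cat_scores @ embeddings` -- a hint of where matrix multiplications will come in later.)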
There's only one extra wrinkle, which seems entirely reasonable -- the numbers I showed above would be called attention scores, and before multiplying the token embeddings for all of the tokens in the input sequence by their associated score, we convert them into attention weights, which are essentially the same numbers but normalised so that all of the weights for a given token sum up to 1 -- this is done by running them through the softmax function. This apparently works well with extreme values and turns negative numbers into small positive ones, both of which seem like good things (especially the second -- paying a negative amount of attention to one word when trying to understand another doesn't make any intuitive sense).
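To see what that does in practice, here's my "cat" row of scores run through PyTorch's softmax:

```python
import torch

cat_scores = torch.tensor([0.6, 0.8, 1.0, 0.7, 0.3, 0.2, 0.4])
cat_weights = torch.softmax(cat_scores, dim=0)

print(cat_weights)        # roughly [0.14, 0.17, 0.21, 0.16, 0.11, 0.10, 0.12] -- same ordering, gentler spread
print(cat_weights.sum())  # tensor(1.) -- the weights sum to 1
```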
So, for each token -- or rather the (token embedding, position embedding) pair for each -- we work out a vector of attention scores, which is the length of the full input sequence and for each item in the sequence, says how much attention to pay to it. We then run that vector through softmax to normalise it into attention weights, multiply each token embedding in the input sequence by its attention weight, and add together the results element-wise to get the context vector for the token we're considering.
Maybe it's worth sketching out some pseudocode to make this as clear as possible.
Given a sequence of (token embedding, position embedding) pairs called `inputs`, we can imagine it's doing this:
```python
output_context_vectors = []
for (tok_em, pos_em) in inputs:
    # One attention score per token in the input sequence
    attention_scores = get_attention_scores(tok_em, inputs)
    # Normalise the scores so that they sum to 1
    attention_weights = softmax(attention_scores)
    # The context vector is a weighted sum of all the token embeddings
    context_vector = torch.zeros_like(tok_em)  # start from an all-zero vector
    for ((other_tok_em, _), attention_weight) in zip(inputs, attention_weights):
        context_vector += attention_weight * other_tok_em
    output_context_vectors.append(context_vector)
```
Now, obviously this would not be the right way to implement it (all of that iterating over lists inside iterations over lists can be simplified into matrix multiplications). And of course, `get_attention_scores` is magic -- and the next big thing to understand. But I think that pseudocode gives, at least for me, a clear idea of what is going on when building up a context vector given a vector of attention scores for each input token (or more strictly, its token and position embeddings).
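Just to gesture at the "matrix multiplications" point (the book gets to the real PyTorch version shortly), the two nested loops collapse into something like this sketch, assuming `attn_scores` is a matrix with one row of scores per token and `embeddings` holds one token embedding per row:

```python
import torch

n_tokens, emb_dim = 7, 4
embeddings = torch.rand(n_tokens, emb_dim)     # placeholder token embeddings
attn_scores = torch.rand(n_tokens, n_tokens)   # placeholder scores, one row per token

attn_weights = torch.softmax(attn_scores, dim=-1)  # normalise each row so it sums to 1
context_vectors = attn_weights @ embeddings        # (7, 7) @ (7, 4) -> (7, 4): one context vector per token
```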
That's pretty magical :-) But how do we get the attention scores in the first place? Or more accurately, how do we train something to do that?
In this section, Raschka keeps things super-simple. As he says in the section title, this is a "simple self-attention mechanism without trainable weights". So instead of trying to write a real `get_attention_scores`, he just uses the dot product of the token embeddings to generate a placeholder one. The dot product takes two vectors and returns a scalar, so that means that the attention scores for "cat" are a vector made up of:
- The token embedding for "The", dot-product the token embedding for "cat"
- The token embedding for "fat", dot-product the token embedding for "cat"
- The token embedding for "cat", dot-product the token embedding for "cat"
- The token embedding for "sat", dot-product the token embedding for "cat"
- The token embedding for "on", dot-product the token embedding for "cat"
- The token embedding for "the", dot-product the token embedding for "cat"
- The token embedding for "mat", dot-product the token embedding for "cat"
Now, the dot product for vectors is worked out by multiplying their elements individually and then adding up the results -- that is,
```
[1, 2, 3] . [4, 5, 6]
  -> [1 * 4, 2 * 5, 3 * 6]
   = [4, 10, 18]
  -> 4 + 10 + 18
   = 32
```
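In PyTorch that's just `torch.dot`:

```python
import torch

torch.dot(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]))  # tensor(32.)
```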
So this is a non-crazy "toy" calculation to work out attention scores, because the dot product of two vectors is related to how similar they are in terms of their direction in the vector space (though to get a pure measure of direction -- cosine similarity -- you'd also need to normalise by the vectors' lengths).
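So the toy stand-in for `get_attention_scores` amounts to something like this sketch, again with random placeholder embeddings:

```python
import torch

torch.manual_seed(123)
embeddings = torch.rand(7, 4)  # stand-in embeddings for "The fat cat sat on the mat"
cat_embedding = embeddings[2]  # "cat" is the third token

# One dot product per token in the sentence -- that's the whole toy scoring scheme
cat_scores = torch.empty(7)
for i, other_embedding in enumerate(embeddings):
    cat_scores[i] = torch.dot(other_embedding, cat_embedding)

print(cat_scores)  # seven attention scores for "cat", one per token
# (Equivalent to the one-liner: embeddings @ cat_embedding)
```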
But it's not what we want for a real attention mechanism -- for example, in
The fat cat sat on the mat
...if we ignore case, the two "the"s would be super-closely related to each other, so they would have a really high attention score, much higher than any of the other pairs of words. So there's nothing in there about the meaning of the sentence. But it's a great starting point so that we can start implementing things before actually coding a real system for generating attention scores. And I'm pretty sure that the real attention score system will be using the dot product somewhere, given its ability to say whether two things are similar -- it will just be part of a larger mechanism. And that mechanism will have to involve position embeddings in some way -- right now we're just ignoring them completely.
I'm going to wrap up there. The next section works at a more concrete, implementation level -- given this simple dot-product way of working out the attention scores, how do we in practice calculate the context vectors for an input sequence in PyTorch, using matrix and other tensor operations, so that everything can be done as a batch without complex loops like the ones in my pseudocode above? That can be the topic for my next post.