Writing an LLM from scratch, part 7 -- wrapping up non-trainable self-attention
This is the seventh post in my series of notes on Sebastian Raschka's book "Build a Large Language Model (from Scratch)". Each time I read part of it, I'm posting about what I found interesting or needed to think hard about, as a way to help get things straight in my own head -- and perhaps to help anyone else that is working through it too.
This post is a quick one, covering just section 3.3.2, "Computing attention weights for all input tokens". I'm covering it in a post on its own because it gets things in place for what feels like the hardest part to grasp at an intuitive level -- how we actually design a system that can learn how to generate attention weights, which is the subject of the next section, 3.4. My linear algebra is super-rusty, and while going through this one, I needed to relearn some stuff that I think I must have forgotten sometime late last century...
In the last post, I went through subsection 3.3.1, which introduced the concept of the context vector and attention scores and weights. This post will be a mystery to you if you've not read that post (and the previous ones), so I really do recommend that you go back if you're coming in halfway through this series.
That section gave a framework for how the attention mechanism takes a stream of input embeddings (each of which is the sum of a token embedding, which represents in some manner the meaning of its associated token taken on its own, and a position embedding, which represents where it is in the input) and converts it into a sequence of context vectors, one per token, each of which represents the meaning of the associated token in the context of the input as a whole.
Throughout that, the generation of attention scores was kept simple. The attention score for one token when considering another was just the dot product of their respective input embeddings. The dot product is calculated by multiplying the vectors element-wise, then summing the elements of the result. It can be taken as a measure of the similarity of two vectors (how close they are to pointing in the same direction), which makes it unrealistic as an attention score: tokens with similar embeddings are not necessarily relevant to each other. For example, in the sentence "the fat cat sat on the mat" the two "the"s, being the same token and thus having the same token embedding, would have very high attention scores for each other -- their input embeddings differ only because their position embeddings are different -- despite not being particularly relevant to each other in a semantic sense. But we have to start somewhere :-)
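To make that concrete, here's a minimal pure-Python sketch of the dot product (no PyTorch, so it runs anywhere). The `dot` helper and the three-dimensional embedding values are made up purely for illustration; they aren't from the book. The two "the" vectors are deliberately almost identical, standing in for the same token embedding plus slightly different position embeddings.

```python
def dot(u, v):
    """Dot product: multiply element-wise, then sum the products."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    return sum(x * y for x, y in zip(u, v))


# Made-up 3-dimensional input embeddings, purely for illustration.
the_1 = [0.9, 0.1, 0.4]   # first "the"
the_2 = [0.8, 0.2, 0.4]   # second "the": same token, different position
mat = [0.1, 0.7, -0.3]    # an unrelated token

print(dot(the_1, the_2))  # large: the vectors point in similar directions
print(dot(the_1, mat))    # small: the vectors point in different directions
```

The similarity is all the dot product measures, which is exactly why the two "the"s end up paying lots of attention to each other.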
This section keeps with this implementation of attention scores, and optimises it. On the face of it, that seems like a strange thing to do -- if we're using a fake attention score mechanism, then why try to make it more efficient? My guess is that the real attention mechanism with trainable weights is similar enough to this simple one that the optimisation is helpful for both, and might be a good starting point for intuition.
So, let's look at how the simple version is implemented. We have a sequence of input embeddings of length $n$. For each one, we want a list of attention scores, which will also be of length $n$. Here's the table of imaginary attention scores that I used in the last post (with made-up but intuitively "right" numbers -- not dot products):
Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat") |
---|---|---|---|---|---|---|---|
The | 1 | 0.3 | 0.75 | 0.1 | 0 | 0 | 0 |
fat | 0.2 | 1 | 0.8 | 0 | 0 | 0 | 0 |
cat | 0.6 | 0.8 | 1 | 0.7 | 0.3 | 0.2 | 0.4 |
sat | 0.1 | 0 | 0.85 | 1 | 0.3 | 0.2 | 0.75 |
on | 0 | 0.1 | 0.4 | 0.6 | 1 | 0.3 | 0.8 |
the | 0 | 0 | 0 | 0 | 0.1 | 1 | 0.75 |
mat | 0 | 0 | 0.2 | 0.8 | 0.7 | 0.6 | 1 |
That is pretty clearly something we could represent with an $n \times n$ matrix. Each row corresponds to one token in the input stream, and the columns in that row are the attention scores for every token in the input (including itself) when considering that token.
To put this another way, in that matrix, the element $\omega_{i,j}$ is the attention score that we use when considering the $i$th token to determine how much attention we should pay to the $j$th token. (One thing that always trips me up with matrices is that they're indexed (row, column) -- entirely the opposite of plotting things, where the x axis comes first, then the y.)
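Here's a small sketch of that indexing convention, storing the made-up scores from the table above as a plain Python list of rows (the `tokens` and `scores` names are just for illustration). Python indexes from zero, and swapping `i` and `j` asks a different question, so the two lookups below give different answers.

```python
tokens = ["The", "fat", "cat", "sat", "on", "the", "mat"]

# The made-up attention scores from the table: one row per token being considered.
scores = [
    [1.0, 0.3, 0.75, 0.1, 0.0, 0.0, 0.0],   # The
    [0.2, 1.0, 0.8, 0.0, 0.0, 0.0, 0.0],    # fat
    [0.6, 0.8, 1.0, 0.7, 0.3, 0.2, 0.4],    # cat
    [0.1, 0.0, 0.85, 1.0, 0.3, 0.2, 0.75],  # sat
    [0.0, 0.1, 0.4, 0.6, 1.0, 0.3, 0.8],    # on
    [0.0, 0.0, 0.0, 0.0, 0.1, 1.0, 0.75],   # the
    [0.0, 0.0, 0.2, 0.8, 0.7, 0.6, 1.0],    # mat
]

n = len(tokens)
assert len(scores) == n and all(len(row) == n for row in scores)  # n x n

i = tokens.index("cat")  # row: the token we're considering
j = tokens.index("The")  # column: the token we're deciding how much to attend to
print(scores[i][j])  # 0.6 -- "cat" considering "The"
print(scores[j][i])  # 0.75 -- "The" considering "cat": not the same thing
```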
Now, as we're using the dot product of the input embeddings for our attention scores, we can generate that with some simple code. Raschka's example does exactly that, creating a 6 x 6 PyTorch matrix (his example is six tokens long) and then iterating over the input embeddings to fill in one row at a time, within that iteration iterating over the input embeddings again to fill in each column in that row with the dot product of the two embeddings.
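Here's a sketch of that loop-within-a-loop in plain Python rather than PyTorch, so it runs without any dependencies. The six three-dimensional embeddings are made up for illustration (they aren't the book's values), and the names `inputs` and `attn_scores` just mirror the book's.

```python
# Six made-up 3-dimensional input embeddings, one row per token.
inputs = [
    [0.4, 0.1, 0.9],
    [0.5, 0.9, 0.7],
    [0.6, 0.8, 0.6],
    [0.2, 0.6, 0.3],
    [0.8, 0.2, 0.1],
    [0.0, 0.8, 0.5],
]

n = len(inputs)
attn_scores = [[0.0] * n for _ in range(n)]  # n x n matrix of zeros

for i, x_i in enumerate(inputs):        # row: the token we're considering
    for j, x_j in enumerate(inputs):    # column: the token being scored
        # Dot product: multiply element-wise, then sum.
        attn_scores[i][j] = sum(a * b for a, b in zip(x_i, x_j))

for row in attn_scores:
    print([round(s, 2) for s in row])
```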
He then moves on one step with something that I expect is perfectly obvious to anyone who's learned basic linear algebra -- that is, matrices and vectors -- recently, or who has kept that knowledge fresh, but was a bit of a jump for me:
> When computing the preceding attention score tensor, we used `for` loops in Python. However, `for` loops are generally slow, and we can achieve the same results using matrix multiplication:

```python
attn_scores = inputs @ inputs.T
print(attn_scores)
```
He then invites us to check the results to confirm they're the same, and of course they are. Now, I understand why matrix multiplication is a good thing for efficiency's sake -- GPUs, for example, are essentially optimised for exactly that kind of calculation, so if we can use one to get all of the attention scores with what is basically a single operation, that's a big win. But why can we use it here to replace the specific loop-within-a-loop that he provided previously?
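If you'd like to run that check without PyTorch, here's a sketch in plain Python. The `transpose` and `matmul` helpers are hypothetical stand-ins for PyTorch's `.T` and `@`, and the embeddings are made up. It builds the scores both ways and compares them.

```python
def transpose(m):
    """Swap rows and columns: an n x d matrix becomes d x n."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain matrix multiplication: a is n x k, b is k x m, result is n x m."""
    if len(a[0]) != len(b):
        raise ValueError("columns of a must equal rows of b")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


# Made-up 3-dimensional embeddings for six tokens.
inputs = [
    [0.4, 0.1, 0.9],
    [0.5, 0.9, 0.7],
    [0.6, 0.8, 0.6],
    [0.2, 0.6, 0.3],
    [0.8, 0.2, 0.1],
    [0.0, 0.8, 0.5],
]

# The nested-loop version: dot product of every embedding with every other one.
loop_scores = [
    [sum(a * b for a, b in zip(x_i, x_j)) for x_j in inputs] for x_i in inputs
]

# The matrix version: the equivalent of PyTorch's inputs @ inputs.T
matmul_scores = matmul(inputs, transpose(inputs))

# Allow a tiny tolerance rather than relying on exact float equality.
same = all(
    abs(p - q) < 1e-12
    for row_p, row_q in zip(loop_scores, matmul_scores)
    for p, q in zip(row_p, row_q)
)
print(same)  # True
```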
The last time I studied matrices in detail, this song was in the charts, so I wanted to expand on this a bit. Why can we use that matrix multiplication as a shortcut? If it's obvious to you, then I do recommend that you skip the rest of this post :-) But for anyone else whose memories of this kind of thing pre-date the fall of the Berlin Wall, you might want to read on.
I've revisited matrices more recently (in particular when going through various courses on simple neural nets, where you use them to represent the weights connecting layers), but the jump from dot product to normal matrix multiplication was a bit too big for me to get in one step.
Let's start by thinking about what `inputs` is. It's the input embeddings for the input sequence, with one row per input token containing its embedding. That means that it has $n$ rows for $n$ input tokens, and $d$ columns, where $d$ is the size (the number of dimensions) of the embeddings. So it's an $n \times d$ matrix.
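We'll take a really simple example of that -- imagine $n = 3$ and $d = 2$:
Token | Dimension 1 | Dimension 2 |
---|---|---|
$a$ | $a_1$ | $a_2$ |
$b$ | $b_1$ | $b_2$ |
$c$ | $c_1$ | $c_2$ |
So, token $a$ has an input embedding of $(a_1, a_2)$, token $b$ has an input embedding of $(b_1, b_2)$, and token $c$ has an input embedding of $(c_1, c_2)$.
We're calculating the dot product to get our attention scores, and to get the dot product we multiply the elements of the vectors and add them up:
- $a$'s attention score for $a$ is $a_1 a_1 + a_2 a_2$
- $a$'s attention score for $b$ is $a_1 b_1 + a_2 b_2$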
...and so on.
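Plugging some made-up numbers into that (purely illustrative, for three tokens with two-dimensional embeddings), a quick sketch of those attention scores in Python looks like this; the `embeddings` dict and `score` helper are hypothetical:

```python
# Made-up 2-dimensional embeddings for three tokens (n = 3, d = 2).
embeddings = {
    "a": (1.0, 2.0),
    "b": (3.0, 4.0),
    "c": (5.0, 6.0),
}


def score(x, y):
    """Dot-product attention score of token x for token y."""
    (x1, x2), (y1, y2) = embeddings[x], embeddings[y]
    return x1 * y1 + x2 * y2  # multiply element-wise, then add up


print(score("a", "a"))  # 1*1 + 2*2 = 5.0
print(score("a", "b"))  # 1*3 + 2*4 = 11.0
```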
Now let's think in terms of matrices. If we have a matrix like this:
$$\begin{pmatrix} a_1 & a_2 \\ b_1 & b_2 \\ c_1 & c_2 \end{pmatrix}$$
...then it is a $3 \times 2$ matrix. We can multiply it by any other matrix that has 2 rows (ie. it's $2 \times$ something), that much I remember. The number of columns in the first matrix must match the number of rows in the second, and the result has the number of rows from the first and the number of columns from the second.
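That shape rule is easy to capture in a few lines. This is a hypothetical helper (`matmul_shape` isn't a library function), working on (rows, columns) pairs:

```python
def matmul_shape(a_shape, b_shape):
    """Return the shape of A @ B, given (rows, cols) shapes for A and B.

    The number of columns in A must match the number of rows in B; the
    result has A's rows and B's columns.
    """
    a_rows, a_cols = a_shape
    b_rows, b_cols = b_shape
    if a_cols != b_rows:
        raise ValueError(f"can't multiply {a_shape} by {b_shape}")
    return (a_rows, b_cols)


print(matmul_shape((3, 2), (2, 3)))  # (3, 3)
print(matmul_shape((3, 2), (2, 5)))  # (3, 5)
```

Trying `matmul_shape((3, 2), (3, 2))` raises a `ValueError`, since a $3 \times 2$ matrix can't be multiplied by another $3 \times 2$ one.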
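The transpose of a matrix basically swaps rows and columns, so the transpose of a $3 \times 2$ matrix is a $2 \times 3$ one, which makes it "compatible" for matrix multiplication with the original. It would look like this:
$$\begin{pmatrix} a_1 & b_1 & c_1 \\ a_2 & b_2 & c_2 \end{pmatrix}$$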
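In plain Python, a common idiom for transposing a matrix stored as a list of rows is `zip(*m)`. Here's a minimal sketch with made-up numbers (the `transpose` name is just for illustration):

```python
def transpose(m):
    """Swap rows and columns: element [i][j] moves to [j][i]."""
    return [list(column) for column in zip(*m)]


m = [
    [1, 2],
    [3, 4],
    [5, 6],
]  # 3 x 2

t = transpose(m)
print(t)  # [[1, 3, 5], [2, 4, 6]] -- now 2 x 3
```

Transposing twice gets you back where you started, so `transpose(t) == m`.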
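So the calculation in Raschka's PyTorch code above, for this data, is this matrix multiplication:
$$\begin{pmatrix} a_1 & a_2 \\ b_1 & b_2 \\ c_1 & c_2 \end{pmatrix} \begin{pmatrix} a_1 & b_1 & c_1 \\ a_2 & b_2 & c_2 \end{pmatrix}$$
Using this introductory matrix explainer we see that this is:
$$\begin{pmatrix} a_1 a_1 + a_2 a_2 & a_1 b_1 + a_2 b_2 & a_1 c_1 + a_2 c_2 \\ b_1 a_1 + b_2 a_2 & b_1 b_1 + b_2 b_2 & b_1 c_1 + b_2 c_2 \\ c_1 a_1 + c_2 a_2 & c_1 b_1 + c_2 b_2 & c_1 c_1 + c_2 c_2 \end{pmatrix}$$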
And that makes the connection clear: the value at location $(i, j)$ -- that is, at the $i$th row and $j$th column -- is the dot product of row $i$ in the first input matrix (treated as a vector) and column $j$ in the second input matrix (likewise treated as a vector). Indeed, if I'm reading it correctly, that's pretty much the definition of how matrix multiplication works.
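Written as a formula, here is the general definition of matrix multiplication, along with what it reduces to when the second matrix is the transpose of the first. This uses $X$ for the $n \times d$ matrix of input embeddings and $\mathbf{x}_i$ for its $i$th row; that notation is mine, not the book's:

$$
\begin{aligned}
(AB)_{ij} &= \sum_{k} A_{ik} B_{kj} \\
(XX^\top)_{ij} &= \sum_{k=1}^{d} X_{ik} \, (X^\top)_{kj} = \sum_{k=1}^{d} X_{ik} \, X_{jk} = \mathbf{x}_i \cdot \mathbf{x}_j
\end{aligned}
$$

The key step is $(X^\top)_{kj} = X_{jk}$. Column $j$ of the transpose is just row $j$ of the original, so each element of $XX^\top$ is the dot product of two input embeddings.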
(I guess you can take it the other way around and see the normal vector dot product as a minimal case of matrix multiplication. If you have two vectors $\mathbf{u}$ and $\mathbf{v}$ of length $d$, then you work out the dot product by multiplying them element-wise and adding up those products. But you could also regard them as two $1 \times d$ matrices, take the transpose of the second one to get a $d \times 1$ matrix, do a matrix multiplication, and get a $1 \times 1$ matrix as the result, which contains the value you're looking for.)
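Here's that minimal case as a sketch in plain Python, with made-up vectors and a hypothetical `matmul` helper (in PyTorch, `@` would play the same role):

```python
def matmul(a, b):
    """Matrix multiplication for lists of rows: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


u = [1.0, 2.0, 3.0]
v = [4.0, 5.0, 6.0]

# Ordinary dot product: multiply element-wise, add up.
dot = sum(x * y for x, y in zip(u, v))

# The same thing as a matrix multiplication:
# u as a 1 x 3 matrix, v transposed into a 3 x 1 matrix.
u_row = [u]                # 1 x 3
v_col = [[x] for x in v]   # 3 x 1
result = matmul(u_row, v_col)

print(dot)     # 32.0
print(result)  # [[32.0]] -- a 1 x 1 matrix holding the same value
```

Transposing the *first* vector instead would give a $3 \times 1$ times $1 \times 3$ product: a $3 \times 3$ matrix, not a dot product.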
So, by multiplying our matrix of input embeddings with its own transpose, we've essentially got an output matrix containing all of those input embeddings dot-producted (if there is such a word) against all of the others.
I hope that's going to keep this stuff clear in my mind going forward, because I suspect that having an intuitive grasp of even basic linear algebra will become increasingly important as I continue through this book.
I'll leave things there for this post; the next one is when we start building a real attention mechanism, beyond the dot product one, and I suspect it's going to be tough...