Writing an LLM from scratch, part 5 -- more on self-attention
I'm reading Sebastian Raschka's book "Build a Large Language Model (from Scratch)", and posting about what I found interesting every day that I read some of it. In retrospect, it was kind of adorable that I thought I could get it all done over my Christmas break, given that I managed just the first two-and-a-half chapters! However, now that the start-of-year stuff is out of the way at work, hopefully I can continue. And at least the two-week break since my last post in this series has given things some time to stew.
In the last post I was reading about attention mechanisms and how they work, and was a little thrown by the move from attention to self-attention. In this blog post I hope to get that all fully sorted so that I can move on to the rest of chapter 3, and then the rest of the book. Raschka himself said on X that this chapter "might be the most technical one (like building the engine of a car) but it gets easier from here!" That's reassuring, and hopefully it means that my blog posts will speed up too once I'm done with it.
But first: on to attention and what it means in the LLM sense.
Just to be clear at the outset -- everything important here is covered perfectly well at the start of chapter 3 in the book -- the introduction, 3.1, 3.2 and the intro to 3.3 (up to the start of 3.3.1). But it took me a while to get my head around it, and at least to clarify things in my own mind, I think it's worth writing it here. I'll also be duplicating some things from my last blog post, but perhaps with a different spin on them.
Let's start off with the original encoder-decoder translation systems that evolved into the Transformers architecture that is the foundation for modern GPT-style LLMs. Here's the wording I used in my original post:
- The input text in the source language was run through an encoder, a Recurrent Neural Network (RNN) that processed it a word (or maybe a letter or a token) at a time and built up a vector (specifically in its hidden state) that represented in some abstract way the meaning of the sentence -- basically, an embedding.
- That embedding was then passed in to a decoder, which was an RNN that churned through the embedding to produce an output in the target language. (I believe that the RNN would also modify the embedding as it went along in order to keep track of how far it had got in the sentence).
It's probably worth expanding on that a bit, in particular to show how the "modification of the embedding" I was talking about actually happens.
As a first step, let's do a mini-dive into RNNs. They differ from a normal neural network by the fact that they store state between uses. So let's say:
- You feed in input i1, and get output o1. During its processing of that input, it will stash away some hidden state h1, which will be kept for the next use.
- Now, if you feed in i2, it will use both that and h1 to work out what o2 is, and will also update its hidden state to h2.
A good way of looking at it (as Raschka says) is to think of this as being outputs from one step being fed in to the next calculation. So, to rephrase the example above in those terms:
- You feed in i1 and an "empty" initial hidden state h0. You get two outputs, o1 and h1
- You feed in i2 and h1, and get o2 and h2
...and so on. Now, remembering that the inputs, outputs and hidden states are vectors, you can easily see that this is just a normal neural network with some extra inputs and outputs (the same number of each) that are reserved to be the hidden state.
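To make that "extra inputs and outputs" picture concrete, here is a minimal sketch of one RNN step in plain Python. The function name `rnn_step`, the weight matrices and the `tanh` activation are my own assumptions for illustration (roughly a textbook Elman-style cell), not anything taken from the book.

```python
import math

def rnn_step(x, h_prev, W_xh, W_hh, W_hy):
    """One step of a toy RNN: takes (input, old hidden state),
    returns (output, new hidden state)."""
    hidden_size = len(h_prev)
    h_new = []
    for j in range(hidden_size):
        total = sum(W_xh[j][i] * x[i] for i in range(len(x)))
        total += sum(W_hh[j][k] * h_prev[k] for k in range(hidden_size))
        h_new.append(math.tanh(total))
    # The visible output is computed from the new hidden state.
    o = [sum(row[j] * h_new[j] for j in range(hidden_size)) for row in W_hy]
    return o, h_new

# Made-up weights: 2 inputs, 2 hidden units, 1 output.
W_xh = [[0.5, -0.3], [0.1, 0.8]]
W_hh = [[0.2, 0.0], [-0.4, 0.6]]
W_hy = [[1.0, -1.0]]

h0 = [0.0, 0.0]                      # the "empty" initial hidden state
i1, i2 = [1.0, 0.0], [0.0, 1.0]

o1, h1 = rnn_step(i1, h0, W_xh, W_hh, W_hy)
o2, h2 = rnn_step(i2, h1, W_xh, W_hh, W_hy)  # h1 is fed back in here
print(o1, h1)
print(o2, h2)
```

Feeding `i2` in with `h0` instead of `h1` gives a different `o2`, which is the whole point: the hidden state carries information from earlier inputs.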
(Apparently that specific explanation only strictly holds for single-layer RNNs, because multi-layer ones would hold hidden state per layer, but I think you can keep the same mental model of output going to input even with those -- just imagine that, for an n-layer network, you have a hidden state vector n times larger, and have it fed through to the end via layers where it's just multiplied by 1 at each layer -- that is, layer 1 would output its hidden state to layer 2, which would pass it on unchanged to layer 3 while adding on its own, and so on, so that all layers' hidden states were available on the output. Likewise each layer would ignore the incoming hidden states meant for other layers. That's just a model to keep this "feed the output from one step into the next" concept going rather than how it would work in practice, but it helps me keep it clear in my head.)
As to how you train such a network -- you essentially "unroll" it in time. So let's say you've fed a two-layer network a sequence of ten inputs. Assuming it's a case where you don't care about the contents of the hidden state at the end, which in most RNN cases you would not, you treat it as normal back-propagation across a normal NN with a depth of 2 * 10 = 20 layers, in which the parameters in each set of two layers (that is, each step) are constrained to be the same in each of the ten occurrences. That is one of those things that is easy enough to imagine in an intuitive way but is undoubtedly a nightmare to get your head around and implement if you actually have to do it. (Side quest successfully avoided.)
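For a feel of what unrolling means, here is a heavily simplified sketch: a single-layer RNN whose hidden state is one number, unrolled over ten inputs. The scalar weights `w` and `u`, the `tanh` activation and the choice of "loss = final hidden state" are all assumptions made purely to keep the example tiny. The thing to notice is that the same `w` appears at every step, so its gradient is the sum of the contributions from all ten unrolled copies.

```python
import math

def forward(w, u, xs):
    """Run the scalar RNN over every input, keeping each hidden state."""
    hs = [0.0]  # h0: the "empty" initial hidden state
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def backward(w, u, xs, hs):
    """Back-propagate through the unrolled steps, last step first."""
    grad_w = 0.0
    grad_u = 0.0
    grad_h = 1.0  # loss is the final hidden state, so d(loss)/d(h_T) = 1
    for t in range(len(xs), 0, -1):
        grad_pre = grad_h * (1.0 - hs[t] ** 2)  # derivative of tanh
        grad_w += grad_pre * hs[t - 1]  # shared w: contributions add up
        grad_u += grad_pre * xs[t - 1]  # shared u: likewise
        grad_h = grad_pre * w           # hand the gradient to the previous step
    return grad_w, grad_u

xs = [0.5, -1.0, 0.3, 0.8, -0.2, 0.1, 0.9, -0.7, 0.4, 0.6]  # ten inputs
w, u = 0.7, 0.4
hs = forward(w, u, xs)
grad_w, grad_u = backward(w, u, xs, hs)
print(grad_w, grad_u)
```

A real multi-layer RNN does the same thing with matrices and per-layer hidden states, which is where the "nightmare to implement" part comes in.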
Anyway, let's go back to our translation task. We want to feed it the sentence
The dog, who must have been able to see the cat running through the garden.
We'll assume one token per word, so we:
- Feed our encoder RNN the word "The"; it updates its hidden state and outputs something. For this use case, we just ignore that (or, perhaps, we can just have an RNN with no real outputs -- just those "virtual outputs" we're using for the hidden state).
- Now we feed it "dog". Again, the hidden state gets updated, and there's a null output.
We repeat that for every token. Once we've completed the sentence, the plan is that we have something in the hidden state that is essentially an embedding that captures the meaning of the original sentence.
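The encoder loop described above can be sketched like this. Everything here is a toy assumption: the two-word vocabulary, the hand-picked embeddings in `embed`, the weights, and an RNN that has no real output at all, only a hidden state.

```python
import math

def encoder_step(token_vec, h_prev, W_xh, W_hh):
    """Update the hidden state for one token; there is no other output."""
    size = len(h_prev)
    return [
        math.tanh(
            sum(W_xh[j][i] * token_vec[i] for i in range(len(token_vec)))
            + sum(W_hh[j][k] * h_prev[k] for k in range(size))
        )
        for j in range(size)
    ]

def encode(tokens, embed, W_xh, W_hh, hidden_size):
    """Feed tokens in one at a time; the final hidden state is the 'embedding'."""
    h = [0.0] * hidden_size
    for tok in tokens:
        h = encoder_step(embed[tok], h, W_xh, W_hh)
    return h

# Made-up embeddings and weights, purely for illustration.
embed = {"dog": [1.0, 0.0], "cat": [0.0, 1.0]}
W_xh = [[0.5, -0.3], [0.1, 0.8]]
W_hh = [[0.2, 0.0], [-0.4, 0.6]]

short = encode(["dog", "cat"], embed, W_xh, W_hh, hidden_size=2)
long = encode(["dog", "cat"] * 20, embed, W_xh, W_hh, hidden_size=2)
print(short)
print(long)
```

Note that `short` and `long` are the same size even though one input is twenty times longer than the other: however much text goes in, it all has to fit into the same fixed-size hidden state.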
Next we move on to the decoder. In this case, we "preload" it with the hidden state from the encoder.
In my first draft, I felt that this would add a constraint, because the decoder would have to have a hidden state that was "compatible" in size with the encoder -- specifically, if what I understand about multi-layer RNNs is correct, the decoder would have to have the same number of layers as the encoder, each of which would have the same size of hidden states as their equivalents on the encoder. But apparently real-world implementations would generally have some kind of mapping transformation in between the two -- also learned -- to match things up.
So, once we have that embedding, how do we generate an output?
Just as our original encoder RNN didn't really need an output -- it could just produce a new hidden state -- the decoder can just accept a hidden state as its "input". But that doesn't stop it from having an output in addition to a new hidden state for the next run.
So, it starts off with a hidden state that represents
The dog, who must have been able to see the cat running through the garden.
...and:
- We run it with the initial hidden state and it outputs the word "Der" plus a hidden state that represents "dog, who must have been able to see the cat running through the garden."
- We run it again with that hidden state, and it outputs "Hund" and an updated hidden state.
- And so on.
- Eventually it will reach a state where the hidden state is essentially a representation of "there's nothing more to output" and will produce some kind of end-of-sequence token, and we're done.
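As a sketch of that loop: the `decode` function below is the general shape, while `toy_decoder_step` is a deliberately fake stand-in for a trained decoder RNN. Its "hidden state" is literally a tuple of the words still to be said, which is my own simplification to mirror the mental model above, not how a real hidden state looks. The `<eos>` token name and the `max_steps` guard are also assumptions.

```python
EOS = "<eos>"

def toy_decoder_step(hidden):
    """Fake decoder step: emit the next remaining word and return what is left.
    A real decoder would be a trained RNN working on a vector of floats."""
    if not hidden:
        return EOS, hidden
    return hidden[0], hidden[1:]

def decode(initial_hidden, step, max_steps=50):
    """Run the decoder until it emits the end-of-sequence token."""
    hidden = initial_hidden
    output = []
    for _ in range(max_steps):  # guard against a decoder that never stops
        token, hidden = step(hidden)
        if token == EOS:
            break
        output.append(token)
    return output

# Pretend this came from the encoder.
context = ("Der", "Hund", "der", "die", "Katze")
print(decode(context, toy_decoder_step))
```

The only input to each step is the hidden state from the previous one, so everything the decoder still needs to say has to be recoverable from it.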
That sounds simple enough (for rather abstruse values of simple). But just looking at the first two words misses the issues with word re-ordering. Remember that in our English-to-German translation task,
The dog, who must have been able to see the cat running through the garden.
...translates to
Der Hund, der die Katze durch den Garten hatte jagen sehen können müssen.
...which literally translates as:
The dog, who the cat through the garden had chasing see can must
So let's think about what the hidden state might represent at each step (I'll ignore punctuation):
- "The dog who must have been able to see the cat running through the garden"
- "dog who must have been able to see the cat running through the garden"
- "who must have been able to see the cat running through the garden"
- "must have been able to see the cat running through the garden"
- "must have been able to see cat running through the garden"
- "must have been able to see running through the garden"
- "must have been able to see running the garden"
- "must have been able to see running garden"
- "must have been able to see running"
- "must been able to see running"
- "must been able to see"
- "must been able"
- "must"
We're getting to crazy levels of abstract there. This hidden state is having to do a huge amount of work for us. Perhaps you could have something more like:
- "The dog who must have been able to see the cat running through the garden"
- "dog who must have been able to see the cat running through the garden"
- "who must have been able to see the cat running through the garden"
- "who must have been able to see the cat running through the garden" (but I've already said the "who")
- "who must have been able to see the cat running through the garden" (but I've already said the "who" and the "the" of "the cat")
...and so on. But either way, it's easy to see how the whole thing might work for simpler sentences but collapse with more complex ones. This, plus the problem with using a fixed-size embedding to represent a (potentially long and complex) document, was called the fixed-length bottleneck. There was also a problem with vanishing gradients, which if I understand it correctly was related to the depth of the RNNs when doing back-propagation -- imagine a 5-layer network trained on a 100-length sequence -- your backward pass has to go through 500 layers, and it's easy to see how the gradients at the start of that pass might essentially have disappeared by the time they reach the first layer.
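A bit of back-of-the-envelope arithmetic shows why that depth hurts. The sketch below assumes, very crudely, that every layer in the unrolled network scales the gradient by the same constant factor; real networks are messier, and the `0.9` is just a number picked for illustration.

```python
def gradient_scale(per_layer_factor, layers, steps):
    """Depth of the unrolled network, and how much a gradient is scaled
    after passing back through all of it (constant-factor assumption)."""
    depth = layers * steps
    return depth, per_layer_factor ** depth

depth, scale = gradient_scale(0.9, layers=5, steps=100)
print(depth)  # 500 unrolled layers
print(scale)  # a vanishingly small number, on the order of 1e-23
```

Even a modest 10% shrinkage per layer leaves essentially nothing by the time the gradient gets back to the earliest steps.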
This encoder-decoder setup, at least as far as I can tell, was described by Cho et al. in this paper, submitted to arXiv on 2 June 2014, "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation":
In this paper, we propose a novel neural network model called RNN Encoder–Decoder that consists of two recurrent neural networks (RNN). One RNN encodes a sequence of symbols into a fixed-length vector representation, and the other decodes the representation into another sequence of symbols.
Things seem to have been a little busy after that; on 1 September 2014, the Bahdanau paper (which Raschka mentions in Appendix B) was submitted to arXiv:
In this paper, we conjecture that the use of a fixed-length vector is a bottleneck in improving the performance of this basic encoder–decoder architecture, and propose to extend this by allowing a model to automatically (soft-)search for parts of a source sentence that are relevant to predicting a target word, without having to form these parts as a hard segment explicitly.
This appears to introduce the concept of attention:
Intuitively, this implements a mechanism of attention in the decoder. The decoder decides parts of the source sentence to pay attention to. By letting the decoder have an attention mechanism, we relieve the encoder from the burden of having to encode all information in the source sentence into a fixed-length vector.
If I'm understanding things correctly, the decoder is actually looking at the encoder's hidden states rather than the inputs themselves, but that probably doesn't matter too much for the level of understanding I need right now -- after all, we're looking at the history rather than trying to build one of these systems.
Anyway, the point here is that by adding to the decoder the ability to pay attention to the encoder's input, the embedding passed from the encoder has to do much less work -- and the results are much better.
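Here is a rough sketch of the general idea: at each decoding step, score every encoder hidden state against the decoder's current state, turn the scores into weights that sum to 1, and take a weighted average. One loud caveat: I'm using a plain dot product for the scoring because it is the simplest thing that works, whereas the Bahdanau paper scores with a small learned network. The vectors are made up.

```python
import math

def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(decoder_state, encoder_states):
    """Dot-product attention: returns (weights, context vector)."""
    scores = [
        sum(d * e for d, e in zip(decoder_state, enc)) for enc in encoder_states
    ]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [
        sum(w * enc[i] for w, enc in zip(weights, encoder_states))
        for i in range(dim)
    ]
    return weights, context

# One made-up hidden state per source token, plus a made-up decoder state.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
decoder_state = [0.0, 2.0]

weights, context = attend(decoder_state, encoder_states)
print(weights)  # the second encoder state gets the largest weight
print(context)
```

The context vector is recomputed at every decoding step, so the decoder can focus on different source tokens as it goes rather than relying on one fixed summary.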
(It's also worth noting that on 10 September 2014, Sutskever et al's "Sequence to Sequence Learning with Neural Networks" appeared on arXiv. If I understand it correctly, it got improved performance without an attention mechanism by using an updated form of RNNs called long short-term memory networks (LSTMs). I think this might have been because -- while LSTMs still had some part of the fixed-length bottleneck -- they were less prone to vanishing gradients.)
(Also, for clarity: I have not read the papers in full! That could be an interesting follow-up to this series, perhaps. It's also worth noting that those dates are just when arXiv received them -- they'd probably been floating around for some time prior to that.)
Anyway, when attention was introduced as a concept, it was the decoder attending to tokens in the encoder's input. This, per Wikipedia, is called cross-attention, which makes sense.
A concept that seems to have bubbled up after that, but I can't find a solid source for, is self-attention. If I can trust Wikipedia (which I think I can more than the LLMs), it sounds like it was an idea in the air that was finally made solid in the "Attention Is All You Need" paper. This introduced a bunch of concepts at once, but to focus on attention: as well as having the decoder pay attention to the different inputs that the encoder had seen (or rather, the hidden states they gave rise to), it introduced the idea of the decoder being able to pay attention to its own outputs -- that is, self-attention.
Things get a little messy here in trying to treat this as a step on top of Bahdanau, though, because "Attention Is All You Need" introduced a bunch of ideas at once, as I mentioned in the last post. It was a landmark paper and is regarded as pretty much a revolution in the way language-processing models worked, so I guess that's not all that surprising :-)
Most importantly for what I'm trying to understand right now, it got rid of the RNNs and the hidden states entirely -- including the hidden state that was passed from encoder to decoder carrying the embedding of the sequence's meaning. That made training easier (see my note on RNN training above -- it could handle all of the inputs at once rather than doing them one at a time). In place of the RNNs, it used two kinds of attention:
- Cross-attention, where the decoder was looking at (some kind of representation of) the encoder's inputs.
- Self-attention, where the decoder was looking at its own previously-generated tokens.
- And the encoder also had self-attention, which sounds like an interesting topic but also a side quest right now!
By doing those, it could translate sentences, and do so more effectively than the earlier architectures.
Now, what we're doing in the book -- as we're focusing on modern decoder-only LLMs -- is trying to implement a setup where we only have the second one of those. We want to take an input and predict the next word. There's no encoder to cross-attend to, so we're going to implement something that looks at a sequence of tokens (or rather, the learned embeddings for them) and works out what the next one should be using a self-attention mechanism.
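To pin down what "looks at a sequence of tokens" might mean mechanically, here is a bare-bones sketch of self-attention in which each token only attends to itself and the tokens before it. It is very much a simplification of my own: plain dot products on the raw embeddings, no trainable weights, no scaling, a single head, and made-up three-dimensional embeddings. Real implementations add learned projections on top of this.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def causal_self_attention(embeddings):
    """For each position, build a context vector from that token and
    everything before it (never from later tokens)."""
    contexts = []
    for i, query in enumerate(embeddings):
        visible = embeddings[: i + 1]  # the token itself plus earlier ones
        scores = [sum(q * k for q, k in zip(query, key)) for key in visible]
        weights = softmax(scores)
        contexts.append(
            [
                sum(w * v[d] for w, v in zip(weights, visible))
                for d in range(len(query))
            ]
        )
    return contexts

# Made-up embeddings for a three-token sequence.
embeddings = [[0.4, 0.1, 0.8], [0.5, 0.9, 0.6], [0.6, 0.8, 0.7]]
contexts = causal_self_attention(embeddings)
for c in contexts:
    print(c)
```

There is no encoder anywhere in this picture: the sequence is attending to itself, and the last context vector is the one you would use to predict the next token.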
Phew.
Now I think my head is in a state where I can move on!
I think that the main reason I've been blocked at this point in the book is that the way my mind works might be a little different to the way the book is targeted. Raschka says quite clearly at the start of the chapter that "[w]e will largely look at attention mechanisms in isolation and focus on them at a mechanistic level" -- that is, the book is deliberately aiming to explain the "how" rather than the "why". I find that very hard to do, and feel I really need to understand why something is so in order to be able to get a good understanding as to how the how works.
Perhaps that's an individual learning-style thing. Still, with the help of Claude and ChatGPT (cross-checking and backed up with more reliable sources on the Internet) I can hopefully backfill the parts that I need in addition to what is in the text, and move forward :-)