Writing an LLM from scratch, part 3

Posted on 26 December 2024 in Programming, Python, AI

I'm reading Sebastian Raschka's book "Build a Large Language Model (from Scratch)", and posting about what I found interesting every day that I read some of it.

Today I was working through the second half of Chapter 2, "Working with text data", which I'd started just before Christmas. Only two days off, so it was reasonably fresh in my mind :-)

Data sampling with a sliding window

The first section is about how we take a relatively unstructured document -- the book uses "The Verdict" by Edith Wharton -- and turn it into structured data that we can feed into an LLM so that it can predict the next word and be trained to do so with reasonable accuracy.

There's nothing conceptually difficult in this section, but I did find myself with some unanswered questions and a little bit of confusion. Let's work through it.

The goal of our LLM is to take some context and predict the next word (strictly, token). Let's say that the text that we're training it on is just

Once more unto the breach

We might want to train it on examples like this (written as input_token_list -> target_output_token):

["Once"] -> "more"
["Once", "more"] -> "unto"
["Once", "more", "unto"] -> "the"
["Once", "more", "unto", "the"] -> "breach"

The book goes through some simple code to do exactly that kind of thing, and that's all pretty clear.
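Just to make that concrete for myself, here's a quick sketch (mine, not the book's exact code) of generating those growing-context pairs from a list of tokens:

tokens = ["Once", "more", "unto", "the", "breach"]

# Each prefix of the token list becomes an input, and the token that
# follows it becomes the target.
for i in range(1, len(tokens)):
    print(tokens[:i], "->", tokens[i])

# ['Once'] -> more
# ['Once', 'more'] -> unto
# ['Once', 'more', 'unto'] -> the
# ['Once', 'more', 'unto', 'the'] -> breach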

It's left implicit, but I also get the impression that we'd normally be using fixed input sizes. Let's say that this was three; then that sample input sentence would give us just two training samples:

["Once", "more", "unto"] -> "the"
["more", "unto", "the"] -> "breach"

(Later it's mentioned that we try to avoid overlaps between training samples, so that might not be strictly accurate -- but you can imagine how it would work if the training text was longer.)

Up to this point, everything was crystal-clear to me :-)

But then things changed, and it wasn't clear to me why this was. It has the feel of something that will become clear later, but at this point I am a bit confused.

The sentence that introduces the change is "we are interested in returning two tensors: an input tensor containing the text the LLM sees, and a target tensor that includes the targets for the LLM to predict" (italics mine). The target has become pluralised, and the next word we want predicted is now just one element of a whole target tensor rather than being the sole target.

What this means is that we switch so that the target is a list of the same length as the input, but starting with the second element from the input and including the word that comes next -- the one we're trying to predict. To make that more concrete, let's say we were targeting an input size of three tokens like we were above -- we change to having training data that looks like this:

["Once", "more", "unto"] -> ["more", "unto", "the"]
["more", "unto", "the"] -> ["unto", "the", "breach"]

It's pretty clearly the same data as before with some extra stuff in the target, and the code to generate it is simple and easy to understand. I just don't understand why we're adding on the extra leading elements at this stage. It will doubtless become clearer later! I'm just confused right now. However, I've decided I'm going to follow a strict "no side quests" rule while reading this book, so it will have to remain a mystery for now.
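Just to convince myself I understood the new format, here's a little sketch (my own, not necessarily the book's code) that generates those shifted input/target pairs with a window of three tokens:

tokens = ["Once", "more", "unto", "the", "breach"]
max_length = 3

# Slide a window of max_length tokens along the text; the target window is
# the same size, but shifted one position to the right.
for i in range(len(tokens) - max_length):
    inputs = tokens[i : i + max_length]
    targets = tokens[i + 1 : i + max_length + 1]
    print(inputs, "->", targets)

# ['Once', 'more', 'unto'] -> ['more', 'unto', 'the']
# ['more', 'unto', 'the'] -> ['unto', 'the', 'breach']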

Anyway, having established this data format, Raschka introduces some code to create a PyTorch Dataset subclass that can generate it for us. It's called GPTDatasetV1, so presumably it will be enhanced later on -- in particular, I noted that it loads all of the text into memory, which would have to change with larger amounts of data. The GPTDatasetV1 takes as its parameters that text, a tokenizer, and two more interesting params: a max_length, which is the size of each input list, and a stride, which is how far to step forward through the tokens between one sample and the next.

I was finding it a bit hard to understand how these interacted at the edge cases. For example, considering

1 2 3 4 5 6

...with a max length of 3 and a stride also of 3, it clearly splits into two potential input lists, 1 2 3 and 4 5 6. But given that we need, for each of these input lists, a target list which is the second and third elements from the original input list plus the next one from the original sequence, what do we do at the end? There's no element after 6 to put into a target sequence.

Playing with things in IPython made it clear what happens (at least, with a bit of fiddling with the PyTorch docs in order to find out how to convert tensor([[ 16, 362, 513]]) into a list I could pass to the tokenizer to decode):

In [135]: for stride in range(1, 4):
     ...:     print(f"\nWith stride={stride}:")
     ...:     dataloader = create_dataloader_v1("1 2 3 4 5 6", batch_size=1, max_length=3, stride=stride, shuffle=False)
     ...:     for inputs, target in iter(dataloader):
     ...:         print(tokenizer.decode(inputs[0].tolist()), " --> ", tokenizer.decode(target[0].tolist()))
     ...:

With stride=1:
1 2 3  -->   2 3 4
 2 3 4  -->   3 4 5
 3 4 5  -->   4 5 6

With stride=2:
1 2 3  -->   2 3 4
 3 4 5  -->   4 5 6

With stride=3:
1 2 3  -->   2 3 4

So, if it can't generate a target list of the given size, it won't output an input/target pair at all. To double-check what happens with one extra element in the input list:

In [137]: for stride in range(1, 4):
     ...:     print(f"\nWith stride={stride}:")
     ...:     dataloader = create_dataloader_v1("1 2 3 4 5 6 7", batch_size=1, max_length=3, stride=stride, shuffle=False)
     ...:     for inputs, target in iter(dataloader):
     ...:         print(tokenizer.decode(inputs[0].tolist()), " --> ", tokenizer.decode(target[0].tolist()))
     ...:

With stride=1:
1 2 3  -->   2 3 4
 2 3 4  -->   3 4 5
 3 4 5  -->   4 5 6
 4 5 6  -->   5 6 7

With stride=2:
1 2 3  -->   2 3 4
 3 4 5  -->   4 5 6

With stride=3:
1 2 3  -->   2 3 4
 4 5 6  -->   5 6 7

So that's pretty clear.
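That behaviour would make sense if, under the hood, the dataset only creates a window when there's room for the full shifted-by-one target window. Something roughly like this sketch -- my own reconstruction of the idea behind GPTDatasetV1, with a made-up name, not the book's exact code:

import torch
from torch.utils.data import Dataset

class SlidingWindowDataset(Dataset):
    """A rough reconstruction of the idea behind GPTDatasetV1."""

    def __init__(self, text, tokenizer, max_length, stride):
        token_ids = tokenizer.encode(text)
        self.inputs = []
        self.targets = []
        # Step through the token IDs `stride` at a time, stopping while there's
        # still room for a target window shifted one position to the right.
        for i in range(0, len(token_ids) - max_length, stride):
            self.inputs.append(torch.tensor(token_ids[i : i + max_length]))
            self.targets.append(torch.tensor(token_ids[i + 1 : i + max_length + 1]))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]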

(One interesting thing to note is the extra spaces before the numbers after 1 -- as I noted last time, tokenizers tend to have separate tokens for, say, "2", and " 2". Due to the structure of my input string, it chose the without-space one for 1, and the with-space ones for the other numbers.)

Anyway -- apart from my ongoing confusion about why the target isn't just the next word that we want the LLM to predict, but instead all of the words in the input apart from the first one, plus the next word at the end, all of this seemed pretty logical and clear.

The next part just explained how you can get the dataloader to provide batches (something I was familiar with from my fine-tuning experiments).

For anyone reading this who doesn't already know: when you feed an input into a neural network, that input is a vector, but you can also feed in a matrix -- that is, a bunch of inputs simultaneously. So for a single input you're essentially feeding in a 1 x n matrix -- one row, n columns, where n is the number of inputs the neural network expects. But if you feed in a b x n matrix, with b separate inputs, one on each row, the maths is exactly the same and the code doesn't need changing.

This takes up more memory and processor power than a single input, but running (say) a batch of 8 inputs and getting 8 outputs in one go is cheaper in processing time than going through the full calculations 8 times, once for each input. In particular, when you're training, you can save quite a lot of time by using batches -- so long as you have enough RAM (or, more likely, VRAM) to hold them.
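A quick shape check makes the point (a toy example of my own, not anything from the book):

import torch

layer = torch.nn.Linear(4, 2)   # a layer expecting n=4 inputs

single = torch.randn(1, 4)      # one input, as a 1 x n matrix
batch = torch.randn(8, 4)       # eight inputs stacked into a b x n matrix

print(layer(single).shape)      # torch.Size([1, 2])
print(layer(batch).shape)       # torch.Size([8, 2]) -- same code, no changes needed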

There was an interesting side note about this, though -- Raschka mentions that "small batch sizes require less memory during training but lead to more noisy model updates". I'd certainly seen the memory effects of batch sizes during my fine-tuning experiments, and (during a conversation with Claude when looking into gradient checkpointing) I got the impression that larger batches might lead to better-generalised models:

[A]pparently, larger batch sizes can lead to better, more generalised training. Intuitively this makes sense. For an LLM, you can imagine that training it on the sentence "I like bacon", and then "I like cheese", and then "I like eggs", adjusting the weights after each one, might lead it to "learn" more from the first example than from the second, and more from the second than the third. Training on a batch of all three of them in one go and then doing a single adjustment of the weights afterwards seems like it might consider them all more equally.

It was good to read some confirmation of this underlying process (though I'm sure my analogy is absurdly simplified).

The other thing that I found interesting was a mention that we would normally try to avoid overlap between batches to avoid overfitting. That sounds to me (and perhaps I'm misunderstanding) like in practice the max_length and the stride would be set to the same number.

Anyway, that ended the section on setting up the dataset for training. Now it was on to the section on token embeddings.

Token embeddings

This one I was expecting to be really tricky -- how do you work out embeddings for each of the tokens in your vocabulary, or at least, how do you put things in place so that you can train them? But I was forgetting that because the embeddings are trained as part of the general LLM training, all we really needed to put in place was the infrastructure for that.

One interesting thing that came out of this and the various notes was why we use embeddings at all. It's relatively easy to see why you can't just feed in token IDs directly -- they're discrete numbers that don't really mean anything in and of themselves. Neural nets work best when they're dealing with some kind of continuous data, where (say) 1.5 really does mean something halfway between 1 and 2. But if "once" is token 123, and "more" is token 124, what does 123.5 mean? Nothing meaningful.

I'd read in the past about "one-hot" encodings to work around issues like this, where you have discrete inputs. That's where you have a vector of length n, where n is the number of options for a choice -- you fill it with zeros, apart from one specific element that is 1, representing the particular choice in this case. It crops up in Jeremy Howard's Fast.AI course, for example -- IIRC it's specifically in the case where he's working through building a system to predict who survived the Titanic, and one of the features he's looking at is the class of the passenger: first, second or third. He explains that this is a bad feature as it stands, because there's no (for example) 2.5th class, so he creates new is_first_class, is_second_class, and is_third_class synthetic features. You can think of that as adding on a new feature, class, which is a vector with three elements: [1, 0, 0] for first class, [0, 1, 0] for second, and [0, 0, 1] for third. That is more meaningful because fractions start meaning something. If your model wanted to express "second class passengers were more likely to survive, then first, then third", it could multiply the passenger's one-hot encoded class vector by [0.3, 0.5, 0.1].
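In PyTorch terms (a toy illustration of my own, nothing from the book or the Fast.AI course), that looks something like this:

import torch

# Passenger classes encoded as 0, 1, 2 for first, second and third
passenger_class = torch.tensor([0, 1, 2])
one_hot = torch.nn.functional.one_hot(passenger_class, num_classes=3).float()
print(one_hot)
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])

# "Second class passengers most likely to survive, then first, then third"
weights = torch.tensor([0.3, 0.5, 0.1])
print(one_hot @ weights)        # tensor([0.3000, 0.5000, 0.1000])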

So, I'd had a vague thought in the past that perhaps LLMs received each token as a one-hot vector -- that is, each token was a vector as long as the vocab size, with a single one in the position related to its token ID, and zeros everywhere else. These one-hot vectors would then be fed into the LLM, which could make better sense of them. Let's say that token 1001 was "also" and 3109 was "additionally" -- they have very similar meanings, and some part of the network that was trying to deal with that kind of concept could have weights that were high for the 1001st and 3109th positions and low everywhere else.

In a side note, Raschka points out that this is essentially the same as the embedding system. If you take those one-hot vectors and pass them through a single fully-connected layer of neurons, you are performing the same calculations as you would to generate the embeddings. So -- embeddings here are just a neat way of creating an easily-trainable input layer. This makes a lot of intuitive sense to me, but I need to think it through to make sure I really do understand it -- something about it feels like there's an extra order involved -- order in the sense that vectors are order 1, matrices order 2, and so on -- that doesn't exactly match my mental model of how data flows through a neural network. But still, an intuitive "OK, that sort of kind of makes sense" is what I need right now.
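Here's a quick experiment I put together to convince myself of the equivalence (my own sketch, not the book's code): an nn.Embedding lookup gives the same numbers as multiplying a one-hot vector through a bias-free linear layer that shares the same weights.

import torch

torch.manual_seed(123)
vocab_size, embed_dim = 5, 3

embedding = torch.nn.Embedding(vocab_size, embed_dim)
linear = torch.nn.Linear(vocab_size, embed_dim, bias=False)
# Give the linear layer the same weights as the embedding (transposed, because
# nn.Linear stores its weights as out_features x in_features).
linear.weight = torch.nn.Parameter(embedding.weight.detach().T)

token_ids = torch.tensor([2, 4, 1])
one_hot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size).float()

print(embedding(token_ids))     # lookup by token ID
print(linear(one_hot))          # matrix multiply -- same values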

All that aside: the goal with embeddings was to have a set of vectors of some particular dimensionality (you might remember from the last post that GPT-2's had 768 dimensions and GPT-3 had 12,288). Each token would have one specific vector associated with it, so the process for getting the text into the LLM would look something like this pseudocode:

tokens = tokenizer.encode(input_text)
embeddings = [embeddings_dict[token] for token in tokens]
result = do_LLM(embeddings)

...where embeddings_dict would just be a map from each token ID to its associated embedding vector. Indeed, when I started reading the explanatory text I was expecting to see exactly that code -- a bunch of PyTorch vectors, each with the requires_grad property set to make them trainable.

As it turns out, though, PyTorch has a torch.nn.Embedding class that handles that for you. From the docs it looks like it's not much more than a wrapper around the kind of dictionary I was thinking about -- though I can easily imagine that having it exist as a more abstract entity allows optimisations that make it faster than a lower-level implementation would.
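For example (toy sizes of my own choosing, not the book's), a quick look at what it gives you:

import torch

torch.manual_seed(123)
embedding_layer = torch.nn.Embedding(num_embeddings=6, embedding_dim=3)

print(embedding_layer.weight.requires_grad)               # True -- the whole lookup table is trainable
print(embedding_layer(torch.tensor([2, 3, 5, 1])).shape)  # torch.Size([4, 3]) -- one row per token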

So, the code that Raschka goes through gives a few simple examples with low numbers of dimensions, just to make the concepts clear.

And next, it was on to encoding token positions.

Encoding word positions

The core here is that the self-attention mechanism (that magical thing I'm so looking forward to learning about!) apparently does not "understand" the concept of ordering -- if it gets the same token twice in a stream, it will treat it the same way each time. So in:

The cat ate the Christmas turkey

...the two "the"s would be treated the same, which is obviously not what we want, as they relate to "cat" and "Christmas turkey" respectively.

So once we've worked out encodings for each token, we also need to add on information about where they are in the input. These are called position embeddings; there are two flavours -- relative and absolute -- and we simply add them element-wise to the token embeddings. It's not explained how relative position embeddings work in practice (another interesting side quest I'm very carefully avoiding), but for absolute ones, we just need another dict (or it might as well be a list) mapping from the index of a token in the input stream to a vector -- that is, position 1 in the input will always have the same position embedding, position 2 its own different one, and so on. This, of course, gives us a hard cap on the number of tokens we can accept (one of the reasons LLMs tend to have a fixed context window, I imagine). But it's simple, and apparently OpenAI's GPT models use it, so it's presumably a good way to do things :-)

Here's some sample pseudocode:

tokens = tokenizer.encode(input_text)
token_embeddings = [token_embeddings_dict[token] for token in tokens]
position_embeddings = [position_embeddings_list[ii] for ii in range(len(tokens))]
input_embeddings = [(te + pe) for (te, pe) in zip(token_embeddings, position_embeddings)]
result = do_LLM(input_embeddings)

That means that the position embeddings need to be the same dimensionality as the token embeddings, of course (because otherwise we can't do the element-wise additions in the calculations of input_embeddings).

But again, these are things that are trained as part of the LLM training, so we just start off with numbers -- say all-ones for the one for position 1, all 2s for position 2, and so on. Training will adjust them appropriately.
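Putting that together as real PyTorch rather than pseudocode (with made-up sizes -- a 4-token context and 256-dimensional embeddings, purely for illustration -- and random initial values rather than my all-ones/all-2s example):

import torch

torch.manual_seed(123)
vocab_size, embed_dim, context_length = 50257, 256, 4

token_embedding_layer = torch.nn.Embedding(vocab_size, embed_dim)
pos_embedding_layer = torch.nn.Embedding(context_length, embed_dim)

token_ids = torch.tensor([[40, 367, 2885, 1464]])                    # a batch of one 4-token input
token_embeddings = token_embedding_layer(token_ids)                  # shape (1, 4, 256)
pos_embeddings = pos_embedding_layer(torch.arange(context_length))   # shape (4, 256)

input_embeddings = token_embeddings + pos_embeddings                 # broadcast across the batch
print(input_embeddings.shape)                                        # torch.Size([1, 4, 256])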

Summary

Apart from my ongoing confusion about the targets in our dataset being lists with the next token to predict at the end rather than just the next token, I think that was all pretty clear to me, and hopefully my notes are reasonably clear too :-)

I'm fighting hard to resist the temptation to read up on that, anyway, as I've given myself strict "no side quests" instructions for this run through the book. If I still don't understand it at the end, then I'll definitely be reading up on it.

Next time around, it's chapter 3 -- self-attention. Definitely looking forward to that one!