Messing around with fine-tuning LLMs, part 5 -- exploring memory usage
My goal is to fine-tune an 8B model -- specifically, the Llama 3 8B base model -- on the openassistant-guanaco dataset, without using tricks like quantization or LoRA. I'm doing this as a way to try to understand how to do full-on multi-GPU training of a model that cannot be trained on just one GPU.
I've been building up to this goal gradually; so far, I've:
- Fine-tuned a 0.5B model on my own machine.
- Done the same, but in the cloud using Lambda Labs.
- Run some multi-GPU training, but using the GPUs to run larger batches for the 0.5B model -- which in turn means training faster -- rather than to train a larger model.
- Successfully fine-tuned the 8B model across multiple GPUs using ZeRO and DeepSpeed, but with the optimizer offloaded to CPU.
This time around, I wanted to find out why I had to offload the optimizer, because it didn't seem like it should be necessary. Hugging Face helpfully document a DeepSpeed function that you can call to estimate the VRAM requirements for training a model with ZeRO, and when I ran it against the 8B model, I got this:
(fine-tune) ubuntu@130-61-28-84:~/fine-tune-2024-04$ python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("meta-llama/Meta-Llama-3-8B"); \
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)'
[2024-05-17 23:19:31,667] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
[WARNING] using untested triton version (2.2.0), only 1.0.0 is known to be compatible
Loading checkpoint shards: 100%|============================================================================================================| 4/4 [00:02<00:00, 1.61it/s]
Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 7504M total params, 525M largest layer params.
per CPU | per GPU | Options
188.72GB | 1.96GB | offload_param=cpu , offload_optimizer=cpu , zero_init=1
335.50GB | 1.96GB | offload_param=cpu , offload_optimizer=cpu , zero_init=0
167.75GB | 3.70GB | offload_param=none, offload_optimizer=cpu , zero_init=1
335.50GB | 3.70GB | offload_param=none, offload_optimizer=cpu , zero_init=0
23.48GB | 17.68GB | offload_param=none, offload_optimizer=none, zero_init=1
335.50GB | 17.68GB | offload_param=none, offload_optimizer=none, zero_init=0
It was saying that I only needed 17.68 GiB VRAM per GPU with no optimizer offload -- but I had needed to offload it even though I had 40 GiB per GPU. Why was that? What was I doing wrong? The documents that mention that function also say:
these are just the memory requirements for the parameters, optimizer states and gradients, and you'll need a bit more for the CUDA kernels and activations
...but 22 GiB extra is more than "a bit more". I must have been misunderstanding something.
Digging into this took an embarrassing amount of time -- I started work on it shortly after publishing my last post in this series, so that's been more than a month! And it's embarrassing that I took so long because the reason why I should not trust the number reported by that script was staring me in the face from the start, and involved something I'd discovered in my first explorations into this stuff.
Still, I learned a lot over the course of these investigations, so I think it's worth showing at least some of the journey. The post below is a distilled version of my lab notes and is a little rambling, but you might find it interesting if you're also digging into memory usage during LLM training as a beginner. If not, and you're looking for more carefully planned experiments and results, hopefully the next post in this series will have more of those :-)
Let's get going.
Can I repro locally?
The problem I was trying to solve was that on the machine I was renting for US$10/hour, the DeepSpeed helper function said that I needed just under 18 GiB of VRAM per GPU to train the 8B model without offloading anything, but in reality I needed more than 40 GiB to do it without offloading. To fit into 40 GiB, I had to offload the optimizer.
I didn't want to spend lots of money while debugging this, so the first question was, could I repro the problem on my own machine? I knew that I could train a 0.5B model locally. So what did the helper function say I needed for that model if I wanted to train it with DeepSpeed on my 1-GPU local machine?
Updating the command above to reduce the number of GPUs and to change the model was simple enough:
(fine-tune) ubuntu@130-61-28-84:~/fine-tune-2024-04$ python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("Qwen/Qwen1.5-0.5B"); \
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=1, num_nodes=1)'
[2024-05-25 21:54:57,905] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
[WARNING] using untested triton version (2.2.0), only 1.0.0 is known to be compatible
Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 1 GPU per node.
SW: Model with 463M total params, 155M largest layer params.
per CPU | per GPU | Options
11.67GB | 0.58GB | offload_param=OffloadDeviceEnum.cpu, offload_optimizer=OffloadDeviceEnum.cpu, zero_init=1
11.67GB | 0.58GB | offload_param=OffloadDeviceEnum.cpu, offload_optimizer=OffloadDeviceEnum.cpu, zero_init=0
10.37GB | 1.44GB | offload_param=none, offload_optimizer=OffloadDeviceEnum.cpu, zero_init=1
10.37GB | 1.44GB | offload_param=none, offload_optimizer=OffloadDeviceEnum.cpu, zero_init=0
0.87GB | 8.36GB | offload_param=none, offload_optimizer=none, zero_init=1
2.59GB | 8.36GB | offload_param=none, offload_optimizer=none, zero_init=0
So, I should need 8.36 GiB to train locally. I knew that wasn't the case, at least when I had trained earlier without using DeepSpeed -- I could only just fit it into the GPU, which has 24 GiB of VRAM.
An obvious quick sanity check: to see if I got similar VRAM usage with DeepSpeed, I put together a script, third-0.5b-fine-tune-as-script-with-deepspeed.py, that combined my original fine-tune code for the 0.5B model with the extra config that was required for DeepSpeed, and then ran it locally:
$ deepspeed --num_gpus=1 third-0.5b-fine-tune-as-script-with-deepspeed.py
As soon as it started training, according to nvtop, it was using 17 GiB, and after a little while it jumped to just over 20 GiB.
So the estimated VRAM usage from the function and the amount used in reality were very different, just as it had been for the 8B model when I ran it on the larger machine. This meant that I was able to repro the problem on my own machine and wasn't going to have to spend lots to investigate this. So that was a win :-)
Time to dig down a bit.
ZeRO 2 or ZeRO 3?
My initial thought was "is the estimate for 16-bit but we're training in 32-bit?" I figured that the best way to work out how the calculation was done would be to go straight to the source code. The function I was calling, estimate_zero3_model_states_mem_needs_all_live, delegated most of the real work to estimate_zero3_model_states_mem_needs, and there I found this line:
largest_layer_memory = (4 * largest_layer_params)
So that definitely looked like it was allowing for 4 bytes per parameter -- 32-bit rather than 16.
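To make the arithmetic concrete, here is a minimal sketch of how the no-offload GPU figure might be assembled. The 4 bytes per parameter for the largest layer comes from the line quoted above; the 18 bytes per parameter for the sharded params, gradients and optimizer state is an assumption on my part, chosen because it reproduces the numbers the helper printed. The function name zero3_gpu_estimate_gib is hypothetical, and the parameter counts are the rounded ones from the helper's output, so the results only match to within rounding.

```python
GIB = 1024 ** 3


def zero3_gpu_estimate_gib(total_params, largest_layer_params, num_gpus):
    """Rough per-GPU estimate with nothing offloaded (assumed formula)."""
    # 4 bytes per parameter of the largest layer, as in the quoted line.
    largest_layer_bytes = 4 * largest_layer_params
    # Assumption: 18 bytes per parameter covering params, gradients and
    # optimizer state, split evenly across every GPU.
    sharded_state_bytes = 18 * total_params / num_gpus
    return (largest_layer_bytes + sharded_state_bytes) / GIB


# Rounded parameter counts taken from the helper's "SW:" lines.
llama = zero3_gpu_estimate_gib(7504e6, 525e6, num_gpus=8)
qwen = zero3_gpu_estimate_gib(463e6, 155e6, num_gpus=1)
print(f"Llama 3 8B on 8 GPUs:  {llama:.2f} GiB (helper printed 17.68)")
print(f"Qwen1.5-0.5B on 1 GPU: {qwen:.2f} GiB (helper printed 8.36)")
```

Nothing in this sum scales with batch size or sequence length, which fits the documentation's caveat that activations and CUDA kernels are extra.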
But poking around a little further, I came across a much more in-depth description of ZeRO than the Hugging Face one I had been relying on so far. One thing that leapt out at me was that I was using ZeRO stage 2 in my JSON configuration -- and the function was estimate_zero3_model_states_mem_needs_all_live. Changing from 2 to 3 looked like it would make things worse rather than better, though -- in the documentation for the stage parameter it said:
Stage 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively
It also referred to 16-bit model parameters, despite the 4 bytes per parameter that the code I had found was using. Perplexing.
Even more confusingly, why was that code even trying to work out the memory required for the largest layer in the model, given that ZeRO (at least as I understand it) splits up the model horizontally rather than vertically? My best guess is that this is for activations. Each card needs all of the activations from outputs of the previous layer, which means that even if that layer itself is distributed over all of the cards, the activations will need to be copied to all of the GPUs once the layer has been computed. Perhaps that's also why it's 4 bytes rather than 2; they say that the parameters are 16-bit, but perhaps the activations are 32.
Either way, the function I was using to estimate memory usage was for ZeRO 3, I was using ZeRO 2, so obviously if I wanted to work out why the memory estimates were wrong, I would have to switch to ZeRO 3 to find out!
So, I set the stage to 3 in the JSON file and ran the 0.5B fine-tune again, with a batch size of one.
Memory usage rapidly ramped up to 21 GiB -- close to the amount of VRAM that I found with my first fine-tune of that model. In a way, that was kind of reassuring. If you're using a single-GPU machine, you'd expect to have roughly the same memory usage when using a multi-GPU training system as you would when not using that system -- a bit of extra overhead for the multi-GPU framework, but that's it.
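A rough sketch of why the stage should not matter on one GPU, using the accounting from the ZeRO paper rather than anything taken from the DeepSpeed code: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for mixed-precision Adam state. Those byte counts and the function name zero_model_state_gib are assumptions for illustration, and activations and buffers are ignored entirely.

```python
GIB = 1024 ** 3


def zero_model_state_gib(params, num_gpus, stage):
    """Per-GPU model-state memory under the ZeRO paper's accounting (assumed)."""
    weights, grads, optim = 2 * params, 2 * params, 12 * params  # bytes
    if stage >= 1:
        optim /= num_gpus    # stage 1 partitions the optimizer state
    if stage >= 2:
        grads /= num_gpus    # stage 2 also partitions the gradients
    if stage >= 3:
        weights /= num_gpus  # stage 3 also partitions the parameters
    return (weights + grads + optim) / GIB


# 463e6 is the rounded parameter count the estimator reported for the 0.5B model.
for stage in range(4):
    one = zero_model_state_gib(463e6, num_gpus=1, stage=stage)
    eight = zero_model_state_gib(463e6, num_gpus=8, stage=stage)
    print(f"stage {stage}: 1 GPU {one:.2f} GiB, 8 GPUs {eight:.2f} GiB")
```

With a single GPU every division is by one, so all four stages come out the same; the stages only start to differ once there is more than one card to partition across.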
But the question still remained: the worst-case VRAM usage with the estimation function for ZeRO stage 3 was 8.36 GiB, and I was seeing about 2.5x that much.
What was I doing wrong?
Measuring memory usage more carefully
I decided to try just loading up the model and see how much VRAM it used without any training. That came out as 2.48 GiB according to nvtop, well below the estimate.
But when measuring that, I realised that when I'd been noting down VRAM usage numbers from nvtop, I had not consistently been noting down the actual VRAM usage for the Python process that was running the train; sometimes I'd been looking at the total VRAM usage for the system. I'm on a minimalist Arch setup, so usage from other processes is really low, but I decided I should work out some way to be more precise.
The first step was to run the train again and make sure that I got the numbers for the process. I did that, and saw that it was 20,336 MiB. Also, there was a step-by-step ramp-up as the process started:
- From zero to 1,960 MiB
- From there to about 8,000 MiB
- Then up to 14,000 MiB
- Then finally up to the 20,336 MiB
What interested me about that in particular was that the first jump -- which I suspected was the loading of the model -- was smaller than the usage of the script that just loaded it and nothing else.
16-bit vs 32-bit
At this point I noticed something in my training arguments:
fp16=True,
OK, finally something that made immediate sense. The script that loaded the model and did nothing else was just loading it straight into CUDA with no trainer. When I did use a trainer, it used less VRAM because that trainer was loading it in 16 bits. There was obviously an overhead -- using the number from just loading the model, 2,482 MiB, and dividing it by 2 to allow for the change from 32-bit to 16-bit should be 1,241 MiB, not 1,960 MiB, but it's entirely plausible that other stuff was being loaded at the same time -- maybe the training data and/or the CUDA kernels?
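The arithmetic in that paragraph, written out as a tiny sketch. The two inputs are the nvtop readings quoted above; the variable names are just for illustration.

```python
fp32_load_mib = 2482          # nvtop reading for the load-only script
trainer_first_step_mib = 1960  # first plateau seen when the trainer started

# Going from 4 bytes per parameter to 2 should halve the weights.
expected_fp16_mib = fp32_load_mib / 2
unexplained_mib = trainer_first_step_mib - expected_fp16_mib

print(f"expected at 16-bit: {expected_fp16_mib:.0f} MiB")
print(f"left unexplained:   {unexplained_mib:.0f} MiB")
```

So roughly 700 MiB of that first plateau is something other than halved weights.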
I decided to see what removing that fp16 parameter did to memory usage. I did so, and running the script again led to 22,932 MiB used in nvtop, and the estimated time to completion roughly doubled. The memory usage was kinda-sorta in-line with the model taking up twice as much space -- a bit more than I would have expected, but within reasonable bounds.
I tried fiddling around a bit by adding a flag to the JSON to try to get DeepSpeed to take full responsibility for the 16-bit side of things rather than leaving it to Transformers; the DeepSpeed docs had an example:
{
    "zero_optimization": {
        "stage": 3
    },
    "fp16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 0.001,
            "betas": [
                0.8,
                0.999
            ],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    }
}
Now, I had this minimal JSON file:
{
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3
    }
}
...but also some config in my TrainingArguments object that looked like it would conflict with a simple merge of the two:
args = TrainingArguments(
    'outputs',
    learning_rate=8e-5,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    fp16=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size * 2,
    num_train_epochs=2,
    weight_decay=0.01,
    deepspeed="ds_config.json",
    report_to='none',
)
So I came up with this to combine them:
{
    "zero_optimization": {
        "stage": 3
    },
    "fp16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": 1e-8,
            "weight_decay": "auto"
        }
    },
    "train_micro_batch_size_per_gpu": 1
}
With that, my memory usage was 17,052 MiB. So that was pretty good! But still about double what the estimation function was coming up with.
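One way to avoid hand-editing mistakes when merging configs like this is to build the dictionary in Python and serialize it, since strict JSON parsers reject things like trailing commas. This is only a sketch of that idea; the values are the ones from the combined config above, and the variable names ds_config and config_text are mine.

```python
import json

ds_config = {
    "zero_optimization": {"stage": 3},
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "AdamW",
        "params": {
            # "auto" leaves the value to the Hugging Face integration, which
            # takes it from TrainingArguments, so the two cannot disagree.
            "lr": "auto",
            "betas": "auto",
            "eps": 1e-8,
            "weight_decay": "auto",
        },
    },
    "train_micro_batch_size_per_gpu": 1,
}

config_text = json.dumps(ds_config, indent=4)

# Round-trip through the parser to confirm the text is valid JSON.
assert json.loads(config_text) == ds_config
print(config_text)
```

The printed text can then be saved as ds_config.json, the file name that the TrainingArguments above point at.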
This bittedness thing felt like a blind alley. I felt that the mistake I was making that was pushing my memory usage so far over the estimated amount was somewhere else.
With some misgivings, given the problems I had with hallucinated APIs in my previous experiment, I decided to see if ChatGPT might be able to help.
More detailed memory measurements: in which ChatGPT is actually pretty helpful
The first suggestion it had was to drop calls to
print(torch.cuda.memory_summary(device=None, abbreviated=False))
...at various points in the code -- at the start, after the dataset was tokenized, after the trainer was created, and after the training. This script was the result. The memory_summary function produces tons of output, so I won't paste it in here, but essentially the first three were zero, and I didn't let it run to completion because what I was trying to debug was the memory usage during training, so a summary at the end was not going to help all that much. That all made sense. Nothing was being loaded onto the GPU until we started training, so VRAM usage would be zero until that point.
But this felt like a thread I could pull on. How might I get some kind of memory usage logging during the training?
I couldn't see an immediate way to do that, and for a while I went down a bit of a rabbit hole of trying to offload everything to the CPU so that I could see if that made VRAM usage go down to zero. Unfortunately there were no docs that I could find to help, and ChatGPT went off on another acid trip and dreamt up a few new APIs for me, so I won't give the details here.
Back on track: ChatGPT suggested adding this code to hook into the trainer to track memory usage over time, by overriding the training_step method in the Trainer and printing out memory usage before and after calling the original superclass version of the function:
import torch
from transformers import Trainer

def print_memory_usage(step, stage):
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    print(f"Step {step} ({stage}): Allocated: {allocated / (1024 ** 3):.2f} GB, Reserved: {reserved / (1024 ** 3):.2f} GB")

class MemoryLoggingTrainer(Trainer):
    def training_step(self, *args, **kwargs):
        step = self.state.global_step
        print_memory_usage(step, "before training step")
        output = super().training_step(*args, **kwargs)
        print_memory_usage(step, "after training step")
        return output
...and then using that trainer. This sounded like an excellent idea; inserting logging into third-party libraries is something I do (and recommend) for drilling down into what's going on inside any framework I use. Indeed, often I'll just create a throwaway virtualenv and hack the source of the library I'm using directly in order to work out what is going on.
Luckily, there was no need to go quite so all-in here. The subclassing version printed out some useful stuff -- there's a lot of it, so apologies to anyone reading on mobile devices:
0%| | 0/19692 [00:00<?, ?it/s]
Step 0 (before training step): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 1/19692 [00:00<1:50:01, 2.98it/s]
Step 1 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 1 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 2/19692 [00:00<2:37:06, 2.09it/s]
Step 2 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 2 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 3/19692 [00:01<1:59:33, 2.74it/s]
Step 3 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 3 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 4/19692 [00:01<1:41:37, 3.23it/s]
Step 4 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 4 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 5/19692 [00:01<1:31:41, 3.58it/s]
Step 5 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 5 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 6/19692 [00:01<1:25:44, 3.83it/s]
Step 6 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 6 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 7/19692 [00:02<1:21:51, 4.01it/s]
Step 7 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 7 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 8/19692 [00:02<1:19:34, 4.12it/s]
Step 8 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 8 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 9/19692 [00:02<1:17:48, 4.22it/s]
Step 9 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 9 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 10/19692 [00:02<1:16:37, 4.28it/s]
Step 10 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 10 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 11/19692 [00:02<1:15:40, 4.33it/s]
Step 11 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 11 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 12/19692 [00:03<1:15:05, 4.37it/s]
Step 12 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 12 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 13/19692 [00:03<1:14:37, 4.40it/s]
Step 13 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 13 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 14/19692 [00:03<1:14:19, 4.41it/s]
Step 14 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 14 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 15/19692 [00:03<1:14:07, 4.42it/s]
Step 15 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 15 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 16/19692 [00:04<1:13:56, 4.43it/s]
Step 16 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 16 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 17/19692 [00:04<1:13:57, 4.43it/s]
Step 17 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 17 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 18/19692 [00:04<1:13:54, 4.44it/s]
Step 18 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 18 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 19/19692 [00:04<1:13:46, 4.44it/s]
Step 19 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 19 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 20/19692 [00:04<1:13:52, 4.44it/s]
Step 20 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 20 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 21/19692 [00:05<1:13:52, 4.44it/s]
Step 21 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 21 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 22/19692 [00:05<1:13:47, 4.44it/s]
Step 22 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 22 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 23/19692 [00:05<1:13:41, 4.45it/s]
Step 23 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 23 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 24/19692 [00:05<1:13:40, 4.45it/s]
Step 24 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 24 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 25/19692 [00:06<1:21:47, 4.01it/s]
Step 25 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 26/19692 [00:06<1:26:33, 3.79it/s]
Step 26 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 26 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 27/19692 [00:06<1:29:52, 3.65it/s]
Step 27 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 27 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 28/19692 [00:07<1:32:07, 3.56it/s]
Step 28 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 28 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 29/19692 [00:07<1:33:50, 3.49it/s]
Step 29 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 29 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 30/19692 [00:07<1:34:54, 3.45it/s]
Step 30 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 30 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 31/19692 [00:07<1:35:41, 3.42it/s]
Step 31 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 32/19692 [00:08<1:36:23, 3.40it/s]
Step 32 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 33/19692 [00:08<1:36:51, 3.38it/s]
Step 33 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 33 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 34/19692 [00:08<1:37:02, 3.38it/s]
Step 34 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 34 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 35/19692 [00:09<1:37:10, 3.37it/s]
Step 35 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 35 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 36/19692 [00:09<1:37:23, 3.36it/s]
Step 36 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 36 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 37/19692 [00:09<1:37:23, 3.36it/s]
Step 37 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 37 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 38/19692 [00:10<1:37:22, 3.36it/s]
Step 38 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 38 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▎ | 39/19692 [00:10<1:37:48, 3.35it/s]
Step 39 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 39 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▎ | 40/19692 [00:10<1:37:50, 3.35it/s]
Step 40 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 40 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▎ | 41/19692 [00:10<1:37:48, 3.35it/s]
Step 41 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 41 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▎ | 42/19692 [00:11<1:37:50, 3.35it/s]
Step 42 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
...and that's where I hit control-C. Note that I was not using the 16-bit DeepSpeed config for this run (or further down). The "allocated" number capped out at just less than the estimated memory usage of the trainer! But the "reserved" was much higher than that, and much closer to the actual VRAM usage I was seeing. There was also that interesting jump at iteration 24:
Step 24 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 24 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Digging into why there was that step change there felt like something to investigate later. But the difference between allocated and reserved memory definitely sounded relevant.
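With this much output, spotting a step change by eye is error-prone. Here is a small sketch that parses lines in the format print_memory_usage emits and reports where the allocated figure moves by more than a threshold. The helper names parse_memory_log and find_jumps and the 1 GB threshold are arbitrary choices of mine; the sample lines are copied from the log above.

```python
import re

LOG_LINE = re.compile(
    r"Step (\d+) \((.+?)\): Allocated: ([\d.]+) GB, Reserved: ([\d.]+) GB"
)


def parse_memory_log(text):
    """Return (step, stage, allocated_gb, reserved_gb) tuples found in text."""
    return [
        (int(step), stage, float(allocated), float(reserved))
        for step, stage, allocated, reserved in LOG_LINE.findall(text)
    ]


def find_jumps(records, threshold_gb=1.0):
    """Yield neighbouring records whose allocated memory differs by more than the threshold."""
    for prev, curr in zip(records, records[1:]):
        if abs(curr[2] - prev[2]) > threshold_gb:
            yield prev, curr


sample = """\
Step 23 (after training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 24 (before training step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 24 (after training step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (before training step): Allocated: 8.05 GB, Reserved: 19.39 GB
"""

records = parse_memory_log(sample)
jumps = list(find_jumps(records))
for prev, curr in jumps:
    print(f"step {curr[0]} ({curr[1]}): {prev[2]:.2f} -> {curr[2]:.2f} GB allocated")
```

The progress-bar lines in between do not match the pattern, so the full captured output can be fed in unchanged.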
Looking into the PyTorch docs:
- The allocated memory is "the current GPU memory occupied by tensors in bytes for a given device."
- The reserved memory is "the current GPU memory managed by the caching allocator in bytes for a given device."
This was beginning to sound to me like -- if there was some way to free up space used by caches -- we could potentially reduce the memory usage. But I decided to dig in to these numbers a little more first. Was there some way to find out what was going on during the training step?
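To put a number on the gap between the two figures, here is the subtraction done on the two plateaus from the log above. It is just arithmetic on the printed values; the labels and variable names are mine.

```python
# (allocated_gb, reserved_gb) pairs copied from the log output
plateaus = {
    "steps 1 to 23": (4.60, 12.47),
    "step 24 onwards": (8.05, 19.39),
}

gaps = {}
for label, (allocated, reserved) in plateaus.items():
    # Memory held by the caching allocator but not occupied by tensors
    # at the moment the measurement was taken.
    gaps[label] = reserved - allocated
    print(f"{label}: {gaps[label]:.2f} GB reserved but not allocated "
          f"({gaps[label] / reserved:.0%} of reserved)")
```

In both cases more than half of the reserved memory was not occupied by tensors at the point where the measurement was taken.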
ChatGPT suggested overriding the training_step method in Trainer again, but this time, instead of printing some memory usage stuff and then calling the superclass's method like this:
def training_step(self, *args, **kwargs):
    step = self.state.global_step
    print_memory_usage(step, "before training step")
    output = super().training_step(*args, **kwargs)
    print_memory_usage(step, "after training step")
    return output
...it wanted to duplicate the superclass's code but with logging in between each step. Its proposed code was:
class MemoryLoggingTrainer(Trainer):
    def training_step(self, model, inputs):
        step = self.state.global_step
        print_memory_usage(step, "before forward")
        output = model(**inputs)
        print_memory_usage(step, "after forward")
        loss = output.loss
        print_memory_usage(step, "before backward")
        loss.backward()
        print_memory_usage(step, "after backward")
        return loss
The actual code of training_step in Trainer in the version of Transformers that I was using was this:
model.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
    loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
    return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
    loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
    loss = loss.mean()  # mean() to average on multi-gpu parallel training
if self.use_apex:
    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
So by mixing the two I got:
class MemoryLoggingTrainer(Trainer):
    def training_step(self, model, inputs):
        step = self.state.global_step
        print_memory_usage(step, "before training_step")
        model.train()
        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            loss = loss_mb.reduce_mean().detach().to(self.args.device)
        else:
            with self.compute_loss_context_manager():
                print_memory_usage(step, "before forward pass")
                loss = self.compute_loss(model, inputs)
                print_memory_usage(step, "after forward pass")
            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            print_memory_usage(step, "before backward pass")
            if self.use_apex:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                self.accelerator.backward(loss)
            print_memory_usage(step, "after backward pass")
        print_memory_usage(step, "after training_step")
        return loss.detach() / self.args.gradient_accumulation_steps
I didn't put any new code inside the is_sagemaker_mp_enabled or the self.use_apex branches -- that seemed likely to be safe because neither of those was a feature that I was (knowingly) using.
I had to fix a couple of imports -- the code is here -- and running it, I got this:
Parameter Offload: Total persistent parameters: 123904 in 121 params
0%| | 0/19692 [00:00<?, ?it/s]
Step 0 (before training_step): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (before forward pass): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (after forward pass): Allocated: 8.00 GB, Reserved: 10.01 GB
Step 0 (before backward pass): Allocated: 8.00 GB, Reserved: 10.01 GB
Step 0 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 0 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 1/19692 [00:00<1:49:46, 2.99it/s]
Step 1 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 1 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 1 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 1 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 1 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 1 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 2/19692 [00:00<2:39:25, 2.06it/s]
Step 2 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 2 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 2 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 2 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 2 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 2 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 3/19692 [00:01<2:00:44, 2.72it/s]
Step 3 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 3 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 3 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 3 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 3 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 3 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 4/19692 [00:01<1:42:28, 3.20it/s]
Step 4 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 4 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 4 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 4 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 4 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 4 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
  0%|          | 5/19692 [00:01<...
Step 5 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 5 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 5 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 5 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 5 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 5 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 6/19692 [00:01<1:26:16, 3.80it/s]
Step 6 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 6 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 6 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 6 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 6 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 6 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 7/19692 [00:02<1:22:24, 3.98it/s]
Step 7 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 7 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 7 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 7 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 7 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 7 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 8/19692 [00:02<1:19:40, 4.12it/s]
Step 8 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 8 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 8 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 8 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 8 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 8 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 9/19692 [00:02<1:17:55, 4.21it/s]
Step 9 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 9 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 9 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 9 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 9 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 9 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 10/19692 [00:02<1:16:42, 4.28it/s]
Step 10 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 10 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 10 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 10 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 10 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 10 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 11/19692 [00:02<1:15:57, 4.32it/s]
Step 11 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 11 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 11 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 11 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 11 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 11 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 12/19692 [00:03<1:15:23, 4.35it/s]
Step 12 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 12 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 12 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 12 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 12 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 12 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 13/19692 [00:03<1:14:51, 4.38it/s]
Step 13 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 13 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 13 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 13 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 13 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 13 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 14/19692 [00:03<1:14:37, 4.39it/s]
Step 14 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 14 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 14 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 14 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 14 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 14 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 15/19692 [00:03<1:14:26, 4.41it/s]
Step 15 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 15 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 15 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 15 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 15 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 15 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 16/19692 [00:04<1:14:16, 4.42it/s]
Step 16 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 16 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 16 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 16 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 16 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 16 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 17/19692 [00:04<1:14:11, 4.42it/s]
Step 17 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 17 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 17 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 17 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 17 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 17 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 18/19692 [00:04<1:14:02, 4.43it/s]
Step 18 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 18 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 18 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 18 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 18 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 18 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 19/19692 [00:04<1:14:02, 4.43it/s]
Step 19 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 19 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 19 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 19 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 19 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 19 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 20/19692 [00:04<1:14:00, 4.43it/s]
Step 20 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 20 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 20 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 20 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 20 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 20 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 21/19692 [00:05<1:14:03, 4.43it/s]
Step 21 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 21 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 21 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 21 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 21 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 21 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 22/19692 [00:05<1:13:59, 4.43it/s]
Step 22 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 22 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 22 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 22 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 22 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 22 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 23/19692 [00:05<1:13:53, 4.44it/s]
Step 23 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 23 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 23 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 23 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 23 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 23 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%|▏ | 24/19692 [00:05<1:13:53, 4.44it/s]
Step 24 (before training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 24 (before forward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 24 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 24 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 24 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 24 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 25/19692 [00:06<1:21:57, 4.00it/s]
Step 25 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 25 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 26/19692 [00:06<1:26:42, 3.78it/s]
Step 26 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 26 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 26 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 26 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 26 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 26 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 27/19692 [00:06<1:29:57, 3.64it/s]
Step 27 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 27 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 27 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 27 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 27 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 27 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 28/19692 [00:07<1:32:10, 3.56it/s]
Step 28 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 28 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 28 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 28 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 28 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 28 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 29/19692 [00:07<1:33:47, 3.49it/s]
Step 29 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 29 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 29 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 29 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 29 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 29 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 30/19692 [00:07<1:34:54, 3.45it/s]
Step 30 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 30 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 30 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 30 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 30 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 30 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 31/19692 [00:07<1:35:49, 3.42it/s]
Step 31 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 31 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 31 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 32/19692 [00:08<1:36:24, 3.40it/s]
Step 32 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 32 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 32 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 33/19692 [00:08<1:36:49, 3.38it/s]
So it looked like allocated memory was at its highest between the end of the forward pass and the start of the backward pass. In my earlier results, it had looked like the estimated VRAM usage matched the "allocated" amount, but with this more detailed logging I could see that during the training step, the allocated amount was significantly higher than the estimate (though nowhere near as high as the total reserved amount).
This sounded to me suspiciously like it might be the gradients taking up the VRAM, as I assumed that the activations would have been thrown away after the forward pass.
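A log like the one above is easier to compare across phases once it's boiled down to one line per phase. Here's a minimal sketch of how that could be done. It isn't part of the training script used for this post: the summarise_phases name is made up for illustration, and the regular expression simply assumes the log format shown above.

```python
import re
from collections import defaultdict

# Assumed log format, copied from the output above:
#   Step 3 (after forward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
LINE_RE = re.compile(
    r"Step (\d+) \(([^)]+)\): Allocated: ([\d.]+) GB, Reserved: ([\d.]+) GB"
)


def summarise_phases(log_text):
    """Return {phase: (max_allocated, max_reserved)} across all steps in the log."""
    peaks = defaultdict(lambda: (0.0, 0.0))
    for match in LINE_RE.finditer(log_text):
        _step, phase, allocated, reserved = match.groups()
        prev_allocated, prev_reserved = peaks[phase]
        peaks[phase] = (
            max(prev_allocated, float(allocated)),
            max(prev_reserved, float(reserved)),
        )
    return dict(peaks)


# A few lines taken from the run above, around the step-up at iteration 24.
sample = """\
Step 24 (before backward pass): Allocated: 8.00 GB, Reserved: 12.47 GB
Step 24 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
"""

for phase, (allocated, reserved) in summarise_phases(sample).items():
    print(f"{phase}: allocated up to {allocated:.2f} GB, reserved up to {reserved:.2f} GB")
```

Because it uses finditer rather than matching whole lines, anything else printed alongside the memory lines (progress bars, for example) is simply ignored.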
This run left me with two questions:
- Why was that high-water mark for the "allocated" VRAM still higher than the 8.36 GiB estimated by the function I ran back at the start of this investigation?
- Is the extra "reserved" VRAM something that can be cleared down so that it would not affect whether or not the model would run? That is, could I do something that would make the model train on a machine with 11.46 GiB of VRAM, that being the high-water mark of the allocated VRAM? It sounded like the extra memory was all caches, after all, so maybe it would be possible to run without them.
I decided to focus on the first of these initially.
The novice asked ChatGPT, and was immediately enlightened
I asked ChatGPT what it thought about that output, and it came up with this:
- Check DeepSpeed Configuration:
Ensure that the DeepSpeed configuration is optimized for your hardware and model size. For example, you might want to adjust the stage3_prefetch_bucket_size, stage3_param_persistence_threshold, and reduce_bucket_size parameters.
- Model Initialization:
Check if there is any lazy initialization happening within the first few iterations. This could explain why the memory usage jumps after several steps.
- Gradient Accumulation:
Ensure that gradient accumulation is correctly configured. Misconfigurations here can lead to unexpected memory usage spikes.
- Profile Memory Usage:
Use more detailed profiling tools like NVIDIA Nsight Systems or the PyTorch Profiler to get a detailed breakdown of memory usage.
- Reduce Batch Size or Sequence Length:
Try reducing the batch size or the maximum sequence length to see if it affects the memory usage pattern.
Wait, what? Sequence length?
This was my d'oh moment. VRAM usage during training is highly dependent on the sequence length -- I learned that in my first blog post on fine-tuning.
But let's look at the code I was running to work out the memory usage for training this 0.5B model (formatted for legibility):
from transformers import AutoModel
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
model = AutoModel.from_pretrained("Qwen/Qwen1.5-0.5B")
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=1, num_nodes=1)
There is nothing in there to do with the sequence length. But, as I knew (and had now been effectively reminded), you cannot accurately estimate the memory usage for training an LLM without that important parameter.
My best guess (as of this post) is that the estimate_zero3_model_states_mem_needs_all_live function is not designed specifically for LLMs. Many AI models work on fixed input sizes, so won't have a parameter like sequence length, and perhaps the function is designed for estimating the memory usage for them. Maybe experienced LLM researchers just know that you need to allow some extra factor to allow for that.
But anyway, the obvious next step was to try training on a heavily reduced sequence length, to see what happened to the memory usage then. My tokenize_function looked like this:
def tokenize_function(examples):
    tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=2048)
    tokenized["labels"] = tokenized["input_ids"][:]
    return tokenized
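Since max_length is the knob being turned here, it can help to make it a parameter rather than a hard-coded number. The following is a sketch only: make_tokenize_function is a hypothetical wrapper around the function above, and fake_tokenizer is a crude stand-in (one ID per word, padding with 0) so that the example runs without downloading a real tokenizer. What it shows is that with truncation=True and padding="max_length", every sample comes out at exactly max_length tokens, however long the original text was.

```python
def make_tokenize_function(tokenizer, max_length):
    """Build a tokenize_function like the one above, with max_length as a parameter."""
    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )
        # As in the original function, the labels are a copy of the input IDs.
        tokenized["labels"] = tokenized["input_ids"][:]
        return tokenized
    return tokenize_function


def fake_tokenizer(texts, truncation, padding, max_length):
    """Hypothetical stand-in for a real tokenizer: one ID per whitespace-separated word."""
    rows = []
    for text in texts:
        ids = [len(word) for word in text.split()][:max_length]  # truncate
        ids += [0] * (max_length - len(ids))                      # pad with 0
        rows.append(ids)
    return {"input_ids": rows}


batch = {"text": ["a short example", "a much longer example that will get truncated"]}
for max_length in (4, 6):
    tokenized = make_tokenize_function(fake_tokenizer, max_length)(batch)
    # Every row has exactly max_length entries, regardless of the text length.
    print(max_length, [len(ids) for ids in tokenized["input_ids"]])
```

With a real Hugging Face tokenizer passed in place of fake_tokenizer, the returned function should be usable wherever the original tokenize_function was.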
I changed the max_length from 2048 to 10, and:
0%| | 0/19692 [00:00<?, ?it/s]
Step 0 (before training_step): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (before forward pass): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (after forward pass): Allocated: 4.60 GB, Reserved: 5.77 GB
Step 0 (before backward pass): Allocated: 4.60 GB, Reserved: 5.77 GB
Step 0 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 0 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 1/19692 [00:00<1:16:19, 4.30it/s]
Step 1 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 1 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 1 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 1 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 1 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 1 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 2/19692 [00:00<2:20:14, 2.34it/s]
Step 2 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 2 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 2 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 2 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 2 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 2 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 3/19692 [00:00<1:31:40, 3.58it/s]
Step 3 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 3 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 3 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 3 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 3 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 3 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 4/19692 [00:01<1:08:54, 4.76it/s]
Step 4 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 4 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 4 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 4 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 4 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 4 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 5/19692 [00:01<56:30, 5.81it/s]
Step 5 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 5 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 5 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 5 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 5 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 5 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 6/19692 [00:01<48:48, 6.72it/s]
Step 6 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 6 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 6 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 6 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 6 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 6 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 7/19692 [00:01<43:53, 7.48it/s]
Step 7 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 7 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 7 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 7 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 7 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 7 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 8/19692 [00:01<40:35, 8.08it/s]
Step 8 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 8 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 8 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 8 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 8 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 8 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 9/19692 [00:01<38:22, 8.55it/s]
Step 9 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 9 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 9 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 9 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 9 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 9 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 10/19692 [00:01<36:56, 8.88it/s]
Step 10 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 10 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 10 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 10 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 10 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 10 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 11/19692 [00:01<36:03, 9.09it/s]
Step 11 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 11 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 11 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 11 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 11 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 11 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 12/19692 [00:01<35:27, 9.25it/s]
Step 12 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 12 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 12 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 12 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 12 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 12 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 13/19692 [00:01<34:58, 9.38it/s]
Step 13 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 13 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 13 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 13 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 13 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 13 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 14/19692 [00:02<34:35, 9.48it/s]
Step 14 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 14 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 14 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 14 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 14 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 14 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 15/19692 [00:02<34:29, 9.51it/s]
Step 15 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 15 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 15 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 15 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 15 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 15 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 16/19692 [00:02<34:15, 9.57it/s]
Step 16 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 16 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 16 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 16 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 16 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 16 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 17/19692 [00:02<34:07, 9.61it/s]
Step 17 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 17 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 17 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 17 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 17 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 17 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%| | 18/19692 [00:02<34:00, 9.64it/s]
Step 18 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 18 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 18 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 18 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 18 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 18 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%|▏ | 19/19692 [00:02<33:55, 9.66it/s]
Step 19 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 19 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 19 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 19 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 19 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 19 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%|▏ | 20/19692 [00:02<33:51, 9.68it/s]
Step 20 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 20 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 20 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 20 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 20 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 20 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%|▏ | 21/19692 [00:02<33:48, 9.70it/s]
Step 21 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 21 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 21 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 21 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 21 (after backward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 21 (after training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
0%|▏ | 22/19692 [00:02<33:51, 9.68it/s]
Step 22 (before training_step): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 22 (before forward pass): Allocated: 4.60 GB, Reserved: 7.36 GB
Step 22 (after forward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 22 (before backward pass): Allocated: 4.61 GB, Reserved: 7.36 GB
Step 22 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 22 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 23/19692 [00:03<42:02, 7.80it/s]
Step 23 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 23 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 23 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 23 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 23 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 23 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 24/19692 [00:03<46:44, 7.01it/s]
Step 24 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 24 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 24 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 24 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 24 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 24 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 25/19692 [00:03<49:52, 6.57it/s]
Step 25 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 25 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 25 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 25 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 25 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 26/19692 [00:03<52:12, 6.28it/s]
Step 26 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 26 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 26 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 26 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 26 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 26 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 27/19692 [00:03<53:40, 6.11it/s]
Step 27 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 27 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 27 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 27 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 27 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 27 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 28/19692 [00:03<54:48, 5.98it/s]
Step 28 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 28 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 28 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 28 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 28 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 28 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 29/19692 [00:04<55:54, 5.86it/s]
Step 29 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 29 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 29 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 29 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 29 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 29 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
0%|▏ | 30/19692 [00:04<56:31, 5.80it/s]
Step 30 (before training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 30 (before forward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 30 (after forward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 30 (before backward pass): Allocated: 8.07 GB, Reserved: 15.43 GB
Step 30 (after backward pass): Allocated: 8.05 GB, Reserved: 15.43 GB
Step 30 (after training_step): Allocated: 8.05 GB, Reserved: 15.43 GB
And there we go. At iteration 25, after the step-up in memory usage, I had seen that with 2048-token training data, allocated memory usage jumped from 8.05 GiB to 11.46 GiB during the forward pass. But with 10-token data, I just got a jump from 8.05 GiB to 8.07 GiB. That 100% makes sense.
Like I said before, experienced practitioners presumably know that they need to add on some kind of fiddle-factor -- let's call it f(n) for a sequence length of n. It would be really interesting to work out roughly what the shape of f is -- is it linear? Log? Exponential? I think that will be an interesting experiment later.
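As a first, very rough look at the shape of f, the two runs above already give two data points. The sketch below is just arithmetic on those logged numbers. The helper names are made up for illustration, and fitting a power law through two points is an assumption, not a result: an exponent near 1 would hint at something linear and near 2 at something quadratic, but with only two points (one of them a 0.02 GiB jump rounded to two decimal places) it can't settle the question.

```python
import math

# Allocated memory (before, after) the forward pass, as reported in the two
# runs above. Keys are max_length in tokens.
measurements = {
    10: (8.05, 8.07),
    2048: (8.05, 11.46),
}


def forward_delta(before, after):
    """Extra allocated memory that appears during the forward pass."""
    return after - before


def loglog_slope(n1, d1, n2, d2):
    """Exponent k of a power law d ~ n**k passing through both points."""
    return math.log(d2 / d1) / math.log(n2 / n1)


deltas = {n: forward_delta(*pair) for n, pair in measurements.items()}
for n, delta in deltas.items():
    print(f"max_length={n}: +{delta:.2f} GiB during the forward pass")

# Assumption: f(n) is a power law. Two points always fit one exactly, so this
# is a hint about the shape at best.
k = loglog_slope(10, deltas[10], 2048, deltas[2048])
print(f"power-law exponent implied by these two points: {k:.2f}")
```

A proper version of this experiment would need several sequence lengths and more precise memory readings than two decimal places.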
So now I had the beginnings of a reasonable answer to the question of why the allocated memory usage was so much higher than the estimate provided by that function. The function was not allowing for the sequence length, and that parameter strongly influences VRAM usage.
But were the numbers that I was seeing accurate? After all, they were point-in-time measurements, and even if we had 8 GiB allocated before the backward pass, it could potentially spike during that pass.
What is the real high-water-mark for allocated memory?
Again, ChatGPT gave an interesting suggestion. There is a function called torch.cuda.memory_summary(). I added that to the code at the end of my print_memory_usage function and ran it. Here's what we had at iteration 54:
Step 54 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 8246 MiB | 14106 MiB | 1653 GiB | 1645 GiB |
| from large pool | 8245 MiB | 14090 MiB | 1648 GiB | 1640 GiB |
| from small pool | 0 MiB | 16 MiB | 5 GiB | 5 GiB |
|---------------------------------------------------------------------------|
| Active memory | 8246 MiB | 14106 MiB | 1653 GiB | 1645 GiB |
| from large pool | 8245 MiB | 14090 MiB | 1648 GiB | 1640 GiB |
| from small pool | 0 MiB | 16 MiB | 5 GiB | 5 GiB |
|---------------------------------------------------------------------------|
| Requested memory | 8241 MiB | 14048 MiB | 1641 GiB | 1632 GiB |
| from large pool | 8240 MiB | 14032 MiB | 1635 GiB | 1627 GiB |
| from small pool | 0 MiB | 16 MiB | 5 GiB | 5 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 19852 MiB | 19852 MiB | 20688 MiB | 836 MiB |
| from large pool | 19832 MiB | 19832 MiB | 20668 MiB | 836 MiB |
| from small pool | 20 MiB | 20 MiB | 20 MiB | 0 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 1207 MiB | 2655 MiB | 2187 GiB | 2185 GiB |
| from large pool | 1206 MiB | 2653 MiB | 2180 GiB | 2179 GiB |
| from small pool | 1 MiB | 3 MiB | 6 GiB | 6 GiB |
|---------------------------------------------------------------------------|
| Allocations | 92 | 605 | 201681 | 201589 |
| from large pool | 56 | 447 | 133996 | 133940 |
| from small pool | 36 | 264 | 67685 | 67649 |
|---------------------------------------------------------------------------|
| Active allocs | 92 | 605 | 201681 | 201589 |
| from large pool | 56 | 447 | 133996 | 133940 |
| from small pool | 36 | 264 | 67685 | 67649 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 140 | 140 | 154 | 14 |
| from large pool | 130 | 130 | 144 | 14 |
| from small pool | 10 | 10 | 10 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 77 | 110 | 128843 | 128766 |
| from large pool | 48 | 80 | 99900 | 99852 |
| from small pool | 29 | 34 | 28943 | 28914 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
This was with the 2048-token sequence length. That's actually quite interesting. Although the peak usage shown in our previous logging was 11.46 GiB (consistent with the previous runs), the memory summary shows a high-water mark of 14,106 MiB, or about 13.8 GiB. Perhaps that's the activations plus the gradients or something like that; it would make sense that a certain amount of VRAM is used over the forward pass but freed at the end of it.
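The high-water mark can also be read directly rather than picked out of the summary table. Here is a minimal sketch of that, using torch.cuda.max_memory_allocated() and torch.cuda.reset_peak_memory_stats(). The helper names are hypothetical and not taken from the script used in this post, and the division by 1024**3 is an assumption about how the "GB" figures in the logs were produced.

```python
def to_gib(num_bytes: int) -> float:
    """Convert bytes to GiB (assumed to be the unit the log lines label "GB")."""
    return num_bytes / 1024**3


def format_peak_line(step: int, label: str, current: int, peak: int) -> str:
    """Build a log line in the same style as the ones above, with the peak added."""
    return (
        f"Step {step} ({label}): Allocated: {to_gib(current):.2f} GB, "
        f"Peak allocated: {to_gib(peak):.2f} GB"
    )


def print_peak_memory(step: int, label: str, reset: bool = False) -> None:
    """Hypothetical helper: report current and peak allocated memory on the current GPU."""
    import torch  # imported lazily so the pure helpers above work without PyTorch

    current = torch.cuda.memory_allocated()
    peak = torch.cuda.max_memory_allocated()
    print(format_peak_line(step, label, current, peak))
    if reset:
        # Start a fresh high-water mark, for example once per training step.
        torch.cuda.reset_peak_memory_stats()
```

Resetting the peak at the start of each step would show whether the high-water mark is hit on every iteration or only occasionally.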
Re-running with a sequence length of 1 gave a high-water mark of 11,781 MiB, and re-introducing the 16-bit stuff into the JSON config brought that down to 10,896 MiB. So this was interesting; we were almost down to the estimated number, at least for allocated memory -- about 10.6 GiB instead of 8 GiB.
I decided that this was close enough, and that it was time to focus on the reserved memory.
Digging into reserved memory a bit more
Considering the PyTorch docs I mentioned earlier:
- The allocated memory is "the current GPU memory occupied by tensors in bytes for a given device."
- The reserved memory is "the current GPU memory managed by the caching allocator in bytes for a given device."
I read this as meaning that the extra VRAM between the peak 14106 MiB allocated and the 19852 MiB peak reserved was likely to have been used by PyTorch caching stuff.
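To put a number on that gap, here is the arithmetic using the Peak Usage column of the summary above. Nothing here is new data; it just converts the two figures to GiB and subtracts them.

```python
MIB_PER_GIB = 1024

peak_allocated_mib = 14106  # "Allocated memory", Peak Usage column
peak_reserved_mib = 19852   # "GPU reserved memory", Peak Usage column

gap_mib = peak_reserved_mib - peak_allocated_mib

print(f"Peak allocated: {peak_allocated_mib / MIB_PER_GIB:.2f} GiB")
print(f"Peak reserved:  {peak_reserved_mib / MIB_PER_GIB:.2f} GiB")
print(f"Gap:            {gap_mib} MiB ({gap_mib / MIB_PER_GIB:.2f} GiB)")
```

The gap comes to 5,746 MiB, or about 5.6 GiB. As a sanity check on units, 19,852 MiB is 19.39 GiB, which matches the "Reserved: 19.39 GB" figure in the per-step logging, so those "GB" numbers appear to be GiB.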
As an experiment, I decided to see what would happen if I disabled all caching. This would obviously ruin performance, but it would give useful data. In the PyTorch forums I found reference to a likely-looking environment variable:
export PYTORCH_NO_CUDA_MEMORY_CACHING=1
So I tried that (back with 32-bit and a 2048 sequence length)... and success! Sort of. VRAM usage reported by the script was zero at every step:
Step 98 (after forward pass): Allocated: 0.00 GB, Reserved: 0.00 GB
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Active memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Requested memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |
| from large pool | 0 B | 0 B | 0 B | 0 B |
| from small pool | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Allocations | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Active allocs | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 0 | 0 | 0 | 0 |
| from large pool | 0 | 0 | 0 | 0 |
| from small pool | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
...but I could see in nvtop that it was mostly using between 5 and 10 GiB, with at least one spike up to 13 GiB in the steps up to iteration 98. That was a touch below the "peak usage" allocated memory numbers from the previous run without that setting, and there was no indication of any reserved memory.
However, our iteration speed plummeted, which makes sense: with caching disabled, every tensor allocation and free has to go all the way to the CUDA driver rather than being served from PyTorch's own pool.
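Because PyTorch's own counters read zero with caching disabled, the usage has to be measured from outside the allocator. Watching nvtop works, but the same figure can be logged programmatically. This is a sketch that shells out to nvidia-smi; the function names are hypothetical, and it assumes nvidia-smi is on the PATH.

```python
import subprocess


def parse_memory_used(output: str) -> list[int]:
    """Parse nvidia-smi output with one integer (MiB) per line, one line per GPU."""
    return [int(line) for line in output.splitlines() if line.strip()]


def query_memory_used_mib() -> list[int]:
    """Ask the driver, via nvidia-smi, how much memory each GPU currently has in use."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_memory_used(result.stdout)


# Example use inside a logging helper (needs a machine with an NVIDIA GPU):
#   print(f"Driver-reported usage per GPU (MiB): {query_memory_used_mib()}")
```

Note that this reports memory in use on the whole device, including the CUDA context and any other processes, so it will not line up exactly with PyTorch's "allocated" figure.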
An alternative option from ChatGPT was to set the PYTORCH_CUDA_ALLOC_CONF environment variable:
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
This appeared to be real, but, if anything, it seemed to make things worse!
Step 86 (after forward pass): Allocated: 11.49 GB, Reserved: 21.23 GB
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 11764 MiB | 14140 MiB | 2698 GiB | 2687 GiB |
| from large pool | 11748 MiB | 14124 MiB | 2690 GiB | 2679 GiB |
| from small pool | 16 MiB | 16 MiB | 8 GiB | 8 GiB |
|---------------------------------------------------------------------------|
| Active memory | 11764 MiB | 14140 MiB | 2698 GiB | 2687 GiB |
| from large pool | 11748 MiB | 14124 MiB | 2690 GiB | 2679 GiB |
| from small pool | 16 MiB | 16 MiB | 8 GiB | 8 GiB |
|---------------------------------------------------------------------------|
| Requested memory | 11675 MiB | 14048 MiB | 2670 GiB | 2659 GiB |
| from large pool | 11659 MiB | 14032 MiB | 2662 GiB | 2651 GiB |
| from small pool | 16 MiB | 16 MiB | 8 GiB | 7 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 21742 MiB | 21742 MiB | 22598 MiB | 856 MiB |
| from large pool | 21722 MiB | 21722 MiB | 22578 MiB | 856 MiB |
| from small pool | 20 MiB | 20 MiB | 20 MiB | 0 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 300632 KiB | 574803 KiB | 1163 GiB | 1163 GiB |
| from large pool | 298880 KiB | 573184 KiB | 1153 GiB | 1153 GiB |
| from small pool | 1752 KiB | 3745 KiB | 9 GiB | 9 GiB |
|---------------------------------------------------------------------------|
| Allocations | 555 | 605 | 328710 | 328155 |
| from large pool | 397 | 447 | 212296 | 211899 |
| from small pool | 158 | 264 | 116414 | 116256 |
|---------------------------------------------------------------------------|
| Active allocs | 555 | 605 | 328710 | 328155 |
| from large pool | 397 | 447 | 212296 | 211899 |
| from small pool | 158 | 264 | 116414 | 116256 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 204 | 204 | 219 | 15 |
| from large pool | 194 | 194 | 209 | 15 |
| from small pool | 10 | 10 | 10 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 102 | 107 | 219054 | 218952 |
| from large pool | 72 | 77 | 170332 | 170260 |
| from small pool | 30 | 34 | 48722 | 48692 |
|---------------------------------------------------------------------------|
| Oversize allocations | 7 | 9 | 1553 | 1546 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 18 | 18 | 20 | 2 |
|===========================================================================|
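Putting the Peak Usage columns of the two summaries side by side makes the change easier to see. The figures below are copied from the tables above; the only processing is converting the KiB values in the second table to MiB.

```python
KIB_PER_MIB = 1024

# Peak Usage column from each summary, in MiB.
default_run = {"allocated": 14106, "reserved": 19852, "non_releasable": 2655}
split_128_run = {
    "allocated": 14140,
    "reserved": 21742,
    "non_releasable": 574803 / KIB_PER_MIB,  # reported in KiB in the second table
}

for metric in default_run:
    before = default_run[metric]
    after = split_128_run[metric]
    print(f"{metric:>15}: {before:8.0f} -> {after:8.0f} MiB ({after - before:+.0f})")
```

Peak allocated memory is essentially unchanged, peak reserved memory is up by 1,890 MiB, and peak non-releasable memory is down by roughly 2 GiB. One possible reading is that the setting reduced fragmentation inside cached blocks at the cost of reserving more blocks overall, but that is an interpretation rather than something these numbers prove.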
So that didn't help. ChatGPT's final suggestion was to try calling torch.cuda.empty_cache() at strategic points. This sounded worth a look. It suggested putting it in the print_memory_usage function, but that didn't seem right. Instead, I put it in the training loop, right at the end (after the "after training_step" print). I also removed the memory_summary prints, as they were getting in the way a bit.
Here's the code and here's what I got:
Step 0 (before training_step): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (before forward pass): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (after forward pass): Allocated: 8.00 GB, Reserved: 10.01 GB
Step 0 (before backward pass): Allocated: 8.00 GB, Reserved: 10.01 GB
Step 0 (after backward pass): Allocated: 4.60 GB, Reserved: 12.47 GB
Step 0 (after training_step): Allocated: 4.60 GB, Reserved: 12.47 GB
0%| | 1/19692 [00:00<2:07:35, 2.57it/s]
Step 1 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 1 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 1 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 1 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 1 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 1 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 2/19692 [00:01<2:58:03, 1.84it/s]
Step 2 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 2 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 2 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 2 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 2 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 2 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 3/19692 [00:01<2:27:14, 2.23it/s]
Step 3 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 3 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 3 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 3 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 3 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 3 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 4/19692 [00:01<2:13:35, 2.46it/s]
Step 4 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 4 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 4 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 4 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 4 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 4 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 5/19692 [00:02<2:05:09, 2.62it/s]
Step 5 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 5 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 5 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 5 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 5 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 5 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 6/19692 [00:02<2:00:07, 2.73it/s]
Step 6 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 6 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 6 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 6 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 6 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 6 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 7/19692 [00:02<1:57:10, 2.80it/s]
Step 7 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 7 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 7 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 7 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 7 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 7 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 8/19692 [00:03<1:55:06, 2.85it/s]
Step 8 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 8 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 8 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 8 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 8 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 8 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 9/19692 [00:03<1:53:40, 2.89it/s]
Step 9 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 9 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 9 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 9 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
0%| | 10/19692 [00:03<18:53:57 [224/1846]
Step 10 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 10 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 10 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 10 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 10 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 10 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 11/19692 [00:04<1:52:15, 2.92it/s]
Step 11 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 11 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 11 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 11 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 11 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 11 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 12/19692 [00:04<1:51:36, 2.94it/s]
Step 12 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 12 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 12 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 12 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 12 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 12 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 13/19692 [00:04<1:51:23, 2.94it/s]
Step 13 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 13 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 13 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 13 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 13 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 13 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 14/19692 [00:05<1:50:58, 2.96it/s]
Step 14 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 14 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 14 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 14 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 14 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 14 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 15/19692 [00:05<1:50:48, 2.96it/s]
Step 15 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 15 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 15 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 15 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 15 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 15 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 16/19692 [00:05<1:50:48, 2.96it/s]
Step 16 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 16 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 16 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 16 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 16 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 16 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 17/19692 [00:06<1:50:50, 2.96it/s]
Step 17 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 17 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 17 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 17 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 17 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 17 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 18/19692 [00:06<1:50:55, 2.96it/s]
Step 18 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 18 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 18 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 18 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 18 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 18 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%| | 19/19692 [00:06<1:50:43, 2.96it/s]
Step 19 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 19 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 19 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 19 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 19 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 19 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%|▏ | 20/19692 [00:07<1:50:53, 2.96it/s]
Step 20 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 20 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 20 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 20 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 20 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 20 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%|▏ | 21/19692 [00:07<1:50:37, 2.96it/s]
Step 21 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 21 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 21 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 21 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 21 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 21 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%|▏ | 22/19692 [00:07<1:50:31, 2.97it/s]
Step 22 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 22 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 22 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 22 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 22 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 22 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%|▏ | 23/19692 [00:08<1:50:21, 2.97it/s]
Step 23 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 23 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 23 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 23 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 23 (after backward pass): Allocated: 4.60 GB, Reserved: 12.48 GB
Step 23 (after training_step): Allocated: 4.60 GB, Reserved: 12.48 GB
0%|▏ | 24/19692 [00:08<1:50:18, 2.97it/s]
Step 24 (before training_step): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 24 (before forward pass): Allocated: 4.60 GB, Reserved: 5.78 GB
Step 24 (after forward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 24 (before backward pass): Allocated: 8.01 GB, Reserved: 10.02 GB
Step 24 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 24 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 25/19692 [00:08<1:58:54, 2.76it/s]
Step 25 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 25 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 25 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 25 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 25 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 26/19692 [00:09<2:05:37, 2.61it/s]
Step 26 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 26 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 26 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 26 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 26 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 26 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 27/19692 [00:09<18:54:03 [105/1846]
Step 27 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 27 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 27 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 27 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 27 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 27 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 28/19692 [00:10<2:13:49, 2.45it/s]
Step 28 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 28 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 28 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 28 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 28 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 28 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 29/19692 [00:10<2:15:59, 2.41it/s]
Step 29 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 29 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 29 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 29 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 29 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 29 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 30/19692 [00:11<2:17:36, 2.38it/s]
Step 30 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 30 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 30 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 30 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 30 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 30 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 31/19692 [00:11<2:18:43, 2.36it/s]
Step 31 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 31 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 31 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 31 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 31 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 32/19692 [00:11<2:19:31, 2.35it/s]
Step 32 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 32 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 32 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 32 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 32 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 33/19692 [00:12<2:20:02, 2.34it/s]
Step 33 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 33 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 33 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 33 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 33 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 33 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 34/19692 [00:12<2:20:17, 2.34it/s]
Step 34 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 34 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 34 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 34 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 34 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 34 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 35/19692 [00:13<2:20:29, 2.33it/s]
Step 35 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 35 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 35 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 35 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 35 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 35 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 36/19692 [00:13<2:20:39, 2.33it/s]
Step 36 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 36 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 36 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 36 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 36 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 36 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 37/19692 [00:14<2:23:44, 2.28it/s]
Step 37 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 37 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 37 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 37 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 37 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 37 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 38/19692 [00:14<2:23:02, 2.29it/s]
That was clearly having some effect; compare the last two steps there with a previous run with the same 2048 sequence length:
Step 31 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 31 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 31 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 31 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏ | 32/19692 [00:08<1:36:24, 3.40it/s]
Step 32 (before training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (before forward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (after forward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 32 (before backward pass): Allocated: 11.46 GB, Reserved: 19.39 GB
Step 32 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 32 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
0%|▏
The allocated memory was following the same pattern, as you'd expect given that the only change was in the caching, but the reserved memory was starting each training step only about 1.2 GiB higher than allocated (9.23 GiB against 8.05 GiB), then rising over the course of the training step to the same 19.39 GiB as before. This is pretty much as you might expect from having put code to empty the cache at the end of the training step. Note also that it had slowed down: in the test with the cache-emptying I was getting about 2.3 iterations/second instead of the 3.4 without that code. Again, that sounds pretty much as you'd expect.
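The placement of the cache-clearing call is what matters here, so it is worth sketching in isolation. This is not the code from the training script; it is a generic, hypothetical step function with hook points, written so that a call like torch.cuda.empty_cache could be dropped in at the end of the step (or elsewhere) without touching the rest of the loop.

```python
from typing import Callable, Optional

Hook = Optional[Callable[[], None]]


def run_training_step(
    forward: Callable[[], object],
    backward: Callable[[object], None],
    after_forward: Hook = None,
    after_step: Hook = None,
) -> None:
    """Run one step, calling optional hooks at points where a cache clear could go."""
    loss = forward()
    if after_forward is not None:
        after_forward()
    backward(loss)
    if after_step is not None:
        # This is the position used for the run above: right at the end of the step.
        after_step()


# With PyTorch this might be wired up as follows (not executed here):
#   run_training_step(forward, backward, after_step=torch.cuda.empty_cache)
```

Bear in mind that torch.cuda.empty_cache() only hands cached blocks that are not currently in use back to the driver; memory held by live tensors stays allocated, which fits with the allocated figures being unchanged in the logs.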
Now, the big rise in cache usage appears to be over the course of the backward pass:
Step 37 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 37 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 37 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 37 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 37 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 37 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
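The size of that rise is easier to see as the difference between reserved and allocated at each logging point. This is a small sketch that parses the step 37 lines above and prints the gap; the regular expression is written against the log format shown in this post, and the function name is hypothetical.

```python
import re

LOG = """\
Step 37 (before training_step): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 37 (before forward pass): Allocated: 8.05 GB, Reserved: 9.23 GB
Step 37 (after forward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 37 (before backward pass): Allocated: 11.46 GB, Reserved: 13.47 GB
Step 37 (after backward pass): Allocated: 8.05 GB, Reserved: 19.39 GB
Step 37 (after training_step): Allocated: 8.05 GB, Reserved: 19.39 GB
"""

LINE_RE = re.compile(
    r"Step (\d+) \((.+?)\): Allocated: ([\d.]+) GB, Reserved: ([\d.]+) GB"
)


def cached_but_unused(log: str) -> list[tuple[str, float]]:
    """Return (label, reserved - allocated) for each log line that matches."""
    rows = []
    for match in LINE_RE.finditer(log):
        _, label, allocated, reserved = match.groups()
        rows.append((label, round(float(reserved) - float(allocated), 2)))
    return rows


for label, gap in cached_but_unused(LOG):
    print(f"{label:<22} {gap:5.2f} GB reserved but not allocated")
```

The gap goes from 1.18 at the start of the step, to 2.01 after the forward pass, to 11.34 after the backward pass, so almost all of the growth in cached-but-unallocated memory happens during the backward pass.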
What would happen if I cleared the cache between the forward and the backward pass?
0%| | 0/19692 [00:00<?, ?it/s]
Step 0 (before training_step): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (before forward pass): Allocated: 4.58 GB, Reserved: 5.76 GB
Step 0 (after forward pass): Allocated: 8.00 GB, Reserved: 10.01 GB
Step 0 (before backward pass): Allocated: 8.00 GB, Reserved: 8.25 GB
Step 0 (after backward pass): Allocated: 4.60 GB, Reserved: 11.87 GB
Step 0 (after training_step): Allocated: 4.60 GB, Reserved: 11.87 GB
0%| | 1/19692 [00:00<2:07:21, 2.58it/s]
Step 1 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 1 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 1 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 1 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 1 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 1 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 2/19692 [00:01<2:54:15, 1.88it/s]
Step 2 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 2 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 2 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 2 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 2 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 2 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 3/19692 [00:01<2:14:59, 2.43it/s]
Step 3 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 3 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 3 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 3 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 3 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 3 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 4/19692 [00:01<1:56:21, 2.82it/s]
Step 4 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 4 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 4 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 4 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 4 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 4 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 5/19692 [00:01<
Step 5 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 5 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 5 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 5 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 5 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 5 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 6/19692 [00:02<1:39:47, 3.29it/s]
Step 6 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 6 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 6 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 6 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 6 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 6 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 7/19692 [00:02<1:35:48, 3.42it/s]
Step 7 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 7 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 7 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 7 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 7 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 7 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 8/19692 [00:02<1:33:21, 3.51it/s]
Step 8 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 8 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 8 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 8 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 8 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 8 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 9/19692 [00:02<1:31:31, 3.58it/s]
Step 9 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 9 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 9 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 9 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 9 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 9 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 10/19692 [00:03<1:30:35, 3.62it/s]
Step 10 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 10 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 10 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 10 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 10 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 10 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 11/19692 [00:03<1:29:47, 3.65it/s]
Step 11 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 11 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 11 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 11 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 11 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 11 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 12/19692 [00:03<1:29:18, 3.67it/s]
Step 12 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 12 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 12 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 12 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 12 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 12 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 13/19692 [00:03<1:28:54, 3.69it/s]
Step 13 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 13 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 13 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 13 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 13 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 13 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 14/19692 [00:04<1:28:28, 3.71it/s]
Step 14 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 14 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 14 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 14 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 14 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 14 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 15/19692 [00:04<1:28:15, 3.72it/s]
Step 15 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 15 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 15 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 15 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 15 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 15 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 16/19692 [00:04<1:28:03, 3.72it/s]
Step 16 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 16 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 16 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 16 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 16 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 16 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 17/19692 [00:05<1:27:57, 3.73it/s]
Step 17 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 17 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 17 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 17 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 17 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 17 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 18/19692 [00:05<1:28:01, 3.72it/s]
Step 18 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 18 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 18 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 18 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 18 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 18 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%| | 19/19692 [00:05<1:27:56, 3.73it/s]
Step 19 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 19 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 19 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 19 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 19 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 19 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%|▏ | 20/19692 [00:05<1:28:04, 3.72it/s]
Step 20 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 20 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 20 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 20 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 20 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 20 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%|▏ | 21/19692 [00:06<1:28:01, 3.72it/s]
Step 21 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 21 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 21 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 21 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 21 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 21 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%|▏ | 22/19692 [00:06<1:27:54, 3.73it/s]
Step 22 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 22 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 22 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 22 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 22 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 22 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%|▏ | 23/19692 [00:06<1:27:56, 3.73it/s]
Step 23 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 23 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 23 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 23 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 23 (after backward pass): Allocated: 4.60 GB, Reserved: 11.84 GB
Step 23 (after training_step): Allocated: 4.60 GB, Reserved: 11.84 GB
0%|▏ | 24/19692 [00:06<1:27:53, 3.73it/s]
Step 24 (before training_step): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 24 (before forward pass): Allocated: 4.60 GB, Reserved: 6.92 GB
Step 24 (after forward pass): Allocated: 7.96 GB, Reserved: 9.98 GB
Step 24 (before backward pass): Allocated: 7.96 GB, Reserved: 8.22 GB
Step 24 (after backward pass): Allocated: 8.05 GB, Reserved: 18.75 GB
Step 24 (after training_step): Allocated: 8.05 GB, Reserved: 18.75 GB
0%|▏ | 25/19692 [00:07<1:36:17, 3.40it/s]
Step 25 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 25 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 25 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 25 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 25 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 26/19692 [00:07<1:41:08, 3.24it/s]
Step 26 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 26 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 26 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 26 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 26 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 26 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 27/19692 [00:07<1:44:35, 3.13it/s]
Step 27 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 27 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 27 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 27 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 27 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 27 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 28/19692 [00:08<1:46:55, 3.07it/s]
Step 28 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 28 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 28 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 28 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 28 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 28 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 29/19692 [00:08<1:48:33, 3.02it/s]
Step 29 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 29 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 29 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 29 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 29 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 29 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 30/19692 [00:08<1:49:37, 2.99it/s]
Step 30 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 30 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 30 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 30 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 30 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 30 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 31/19692 [00:09<1:50:21, 2.97it/s]
Step 31 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 31 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 31 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 31 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 31 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 31 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 32/19692 [00:09<1:50:54, 2.95it/s]
Step 32 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 32 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 32 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 32 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 32 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 32 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 33/19692 [00:10<1:51:31, 2.94it/s]
Step 33 (before training_step): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 33 (before forward pass): Allocated: 8.05 GB, Reserved: 10.37 GB
Step 33 (after forward pass): Allocated: 11.42 GB, Reserved: 13.44 GB
Step 33 (before backward pass): Allocated: 11.42 GB, Reserved: 11.68 GB
Step 33 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB
Step 33 (after training_step): Allocated: 8.05 GB, Reserved: 18.76 GB
0%|▏ | 34/19692 [00:10<1:51:52, 2.93it/s]
Huh. That seemed to help a tiny bit: before the backward pass, immediately after the new call to empty the cache, the allocated and reserved amounts were essentially the same. But reserved still went up to 18.76 GiB during the backward pass. Oddly enough, the number of iterations per second went up compared with the previous run -- without clearing caches I got 3.4 iterations/second, clearing the cache at the end of each step I got 2.3, and clearing it both at the end and before the backward pass I got 2.9. Weird.
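For reference, here is a minimal sketch of what clearing the cache between the forward and backward passes could look like. It is not the code from these runs: format_memory, log_memory and training_step_with_cache_clear are hypothetical names, the forward and backward arguments are stand-in callables, and it assumes the "GB" figures in the logs are bytes divided by 1024**3. The PyTorch import is guarded so the sketch still runs on a machine without a GPU.

```python
try:
    import torch  # optional, so the sketch still imports without PyTorch
except ImportError:
    torch = None

GIB = 1024 ** 3  # assumption: the logged "GB" values are really GiB


def format_memory(step, label, allocated_bytes, reserved_bytes):
    """Format one line in the same shape as the log output above."""
    return (
        f"Step {step} ({label}): "
        f"Allocated: {allocated_bytes / GIB:.2f} GB, "
        f"Reserved: {reserved_bytes / GIB:.2f} GB"
    )


def cuda_available():
    return torch is not None and torch.cuda.is_available()


def log_memory(step, label):
    """Print allocated/reserved memory; reports zeros when there is no GPU."""
    if cuda_available():
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
    else:
        allocated = reserved = 0
    line = format_memory(step, label, allocated, reserved)
    print(line)
    return line


def training_step_with_cache_clear(step, forward, backward):
    """Run forward, release cached blocks, then run backward.

    `forward` returns a loss-like object and `backward` consumes it; both are
    hypothetical callables standing in for the real training code.
    """
    log_memory(step, "before forward pass")
    loss = forward()
    log_memory(step, "after forward pass")
    if cuda_available():
        # Hands unused cached blocks back to the driver; live tensors are untouched.
        torch.cuda.empty_cache()
    log_memory(step, "before backward pass")
    backward(loss)
    log_memory(step, "after backward pass")
    return loss
```

Called as training_step_with_cache_clear(0, forward, backward), this prints four lines per step in the same format as the logs, with the cache released just before the "before backward pass" measurement.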
But hang on, surely PyTorch should be able to look after its own caches?
At this point I felt that I was heading in the wrong direction again. I'd been finding some interesting stuff out, but my initial goal with this set of experiments wasn't to dig into the details of PyTorch's caching, but rather to find out what was wrong with my code that meant that it used up so much more VRAM than the DeepSpeed estimation function predicted. I had discovered that:
- The VRAM usage was made up of "allocated" and "reserved". The latter was all of the memory on the GPU that PyTorch was using. The allocated amount was the portion of the reserved amount that was in active use for things like parameters, activations, gradients and so on. The remaining portion of the reserved amount appeared to be used for caching.
- With a sequence length of 1 and 16-bit parameters, the allocated portion was actually pretty close to the amount estimated. The main issue with that was just that I'd not noticed that the estimation function didn't take a parameter for sequence length.
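To put a number on the "remaining portion", here is a small sketch that parses log lines in the format shown above and subtracts allocated from reserved. The names parse_memory_line and cached_amount are hypothetical, and the regular expression only assumes the exact line shape seen in these logs.

```python
import re

# Matches lines like:
# "Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB"
LINE_RE = re.compile(
    r"Step (?P<step>\d+) \((?P<label>[^)]+)\): "
    r"Allocated: (?P<allocated>[\d.]+) GB, Reserved: (?P<reserved>[\d.]+) GB"
)


def parse_memory_line(line):
    """Return (step, label, allocated, reserved), or None for non-matching lines."""
    match = LINE_RE.search(line)
    if match is None:
        return None
    return (
        int(match["step"]),
        match["label"],
        float(match["allocated"]),
        float(match["reserved"]),
    )


def cached_amount(line):
    """Reserved minus allocated: memory held by the allocator but not in active use."""
    parsed = parse_memory_line(line)
    if parsed is None:
        return None
    _, _, allocated, reserved = parsed
    return round(reserved - allocated, 2)


sample = "Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 18.76 GB"
print(cached_amount(sample))
```

Progress-bar lines and warnings simply return None, so the same function can be run over a whole pasted log.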
Now, I'd been fiddling around with inserting cache control code directly into PyTorch's training loop in order to try to bring the reserved amount closer to allocated. But surely that's not something that people normally do; you'd expect PyTorch to be smart enough to manage its own caching and not cache if there's not enough memory to do so.
I decided to see how it behaved with less VRAM. ChatGPT told me that there was a function to limit it:
torch.cuda.set_per_process_memory_fraction(0.5, device=torch.cuda.current_device())
Which does exist.
Looking at the memory stats above for a run with no manual clearing of caches, and (as usual) with a 2048-token sequence length:
| Active memory | 8246 MiB | 14106 MiB | 1653 GiB | 1645 GiB |
| from large pool | 8245 MiB | 14090 MiB | 1648 GiB | 1640 GiB |
| from small pool | 0 MiB | 16 MiB | 5 GiB | 5 GiB |
The peak memory usage was 14,106 MiB, which is about 13.8 GiB. The card I was using had 24 GiB of VRAM, so that's roughly 0.57 of the total. So I tried running with a little more than that:
torch.cuda.set_per_process_memory_fraction(0.6)
...at the start (no device specified means "default device").
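As a quick check on that fraction, here is a small sketch that works from the 14106 MiB peak in the table, using 1024 MiB per GiB and the nominal 24 GiB card size; working in MiB gives a slightly lower figure than a decimal reading of the peak would. fraction_of_vram and cap_process_memory are hypothetical helper names, and the PyTorch import is guarded so the arithmetic runs without a GPU.

```python
try:
    import torch  # optional, so the arithmetic below runs without PyTorch
except ImportError:
    torch = None

MIB_PER_GIB = 1024


def fraction_of_vram(peak_mib, total_gib):
    """Peak usage as a fraction of the card's nominal memory."""
    return peak_mib / (total_gib * MIB_PER_GIB)


def cap_process_memory(fraction):
    """Apply the cap on the default CUDA device; returns False if there is no GPU."""
    if torch is None or not torch.cuda.is_available():
        return False
    torch.cuda.set_per_process_memory_fraction(fraction)
    return True


peak_fraction = fraction_of_vram(14106, 24)
print(f"peak fraction of VRAM: {peak_fraction:.3f}")
print(f"cap applied: {cap_process_memory(0.6)}")
```

On those numbers a cap of 0.6 leaves well under 1 GiB of headroom above the observed peak.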
At iteration 25 I got this error:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.16 GiB.
GPU 0 has a total capacity of 23.68 GiB of which 8.35 GiB is free. Including
non-PyTorch memory, this process has 14.37 GiB memory in use. Of the allocated
memory 12.57 GiB is allocated by PyTorch, and 1.29 GiB is reserved by PyTorch
but unallocated. If reserved but unallocated memory is large try setting
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See
documentation for Memory Management
(https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
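At first glance it looks odd to run out of memory with 8.35 GiB free, but the numbers are consistent with the 0.6 cap. The sketch below is a deliberately simplified model, not PyTorch's actual allocator logic: it assumes the cap applies to everything the allocator has reserved, and compares the case where the reserved-but-unallocated memory cannot be reused for the request (for example because it is fragmented) with the case where it can. fits_under_cap is a hypothetical name; the figures are copied from the error message.

```python
# Figures copied from the error message; FRACTION is the cap set earlier.
TOTAL_GIB = 23.68      # "total capacity"
FRACTION = 0.6         # value passed to set_per_process_memory_fraction
ALLOCATED_GIB = 12.57  # "allocated by PyTorch"
CACHED_GIB = 1.29      # "reserved by PyTorch but unallocated"
REQUESTED_GIB = 1.16   # "Tried to allocate"


def fits_under_cap(allocated, cached, requested, cap, cached_reusable):
    """Simplified check of whether a new allocation fits under the cap.

    If the cached memory is reusable it can absorb the request, so only the
    allocated amount counts; otherwise the cached memory stays reserved and
    the request has to fit on top of it.
    """
    in_use = allocated if cached_reusable else allocated + cached
    return in_use + requested <= cap


cap = TOTAL_GIB * FRACTION
fragmented = fits_under_cap(
    ALLOCATED_GIB, CACHED_GIB, REQUESTED_GIB, cap, cached_reusable=False
)
compact = fits_under_cap(
    ALLOCATED_GIB, CACHED_GIB, REQUESTED_GIB, cap, cached_reusable=True
)
print(f"cap: {cap:.2f} GiB")
print(f"fits if cached memory is unusable: {fragmented}")
print(f"fits if cached memory is reusable: {compact}")
```

Under this model the request only fails when the cached 1.29 GiB is unusable, which fits with the message's own hint about fragmentation.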
Now that was interesting! The error message was suggesting the same environment variable that I had tried before, but with a different value. I decided to set it:
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
...and run it:
0%| | 0/19692 [00:00<?, ?it/s]
Step 0 (before training_step): Allocated: 4.58 GB, Reserved: 5.94 GB
Step 0 (before forward pass): Allocated: 4.58 GB, Reserved: 5.94 GB
Step 0 (after forward pass): Allocated: 7.94 GB, Reserved: 9.92 GB
Step 0 (before backward pass): Allocated: 7.94 GB, Reserved: 9.92 GB
Step 0 (after backward pass): Allocated: 4.59 GB, Reserved: 12.38 GB
Step 0 (after training_step): Allocated: 4.59 GB, Reserved: 12.38 GB
0%| | 1/19692 [00:00<1:45:23, 3.11it/s]
Step 1 (before training_step): Allocated: 4.59 GB, Reserved: 12.38 GB
Step 1 (before forward pass): Allocated: 4.59 GB, Reserved: 12.38 GB
Step 1 (after forward pass): Allocated: 7.95 GB, Reserved: 12.38 GB
Step 1 (before backward pass): Allocated: 7.95 GB, Reserved: 12.38 GB
Step 1 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 1 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 2/19692 [00:00<2:35:49, 2.11it/s]
Step 2 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 2 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 2 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 2 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 2 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 2 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 3/19692 [00:01<1:58:21, 2.77it/s]
Step 3 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 3 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 3 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 3 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 3 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 3 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 4/19692 [00:01<1:40:46, 3.26it/s]
Step 4 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 4 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 4 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 4 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 4 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 4 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 5/19692 [00:01<
Step 5 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 5 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 5 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 5 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 5 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 5 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 6/19692 [00:01<1:25:14, 3.85it/s]
Step 6 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 6 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 6 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 6 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 6 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 6 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 7/19692 [00:02<1:21:25, 4.03it/s]
Step 7 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 7 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 7 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 7 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 7 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 7 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 8/19692 [00:02<1:18:56, 4.16it/s]
Step 8 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 8 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 8 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 8 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 8 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 8 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 9/19692 [00:02<1:17:23, 4.24it/s]
Step 9 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 9 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 9 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 9 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 9 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 9 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 10/19692 [00:02<1:16:11, 4.31it/s]
Step 10 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 10 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 10 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 10 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 10 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 10 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 11/19692 [00:02<1:15:24, 4.35it/s]
Step 11 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 11 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 11 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 11 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 11 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 11 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 12/19692 [00:03<1:14:50, 4.38it/s]
Step 12 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 12 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 12 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 12 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 12 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 12 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 13/19692 [00:03<1:14:33, 4.40it/s]
Step 13 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 13 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 13 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 13 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 13 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 13 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 14/19692 [00:03<1:14:21, 4.41it/s]
Step 14 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 14 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 14 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 14 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 14 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 14 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 15/19692 [00:03<1:14:15, 4.42it/s]
Step 15 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 15 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 15 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 15 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 15 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 15 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 16/19692 [00:04<1:14:03, 4.43it/s]
Step 16 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 16 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 16 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 16 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 16 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 16 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 17/19692 [00:04<1:14:00, 4.43it/s]
Step 17 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 17 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 17 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 17 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 17 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 17 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 18/19692 [00:04<1:13:56, 4.43it/s]
Step 18 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 18 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 18 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 18 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 18 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 18 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%| | 19/19692 [00:04<1:14:00, 4.43it/s]
Step 19 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 19 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 19 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 19 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 19 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 19 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%|▏ | 20/19692 [00:04<1:13:57, 4.43it/s]
Step 20 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 20 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 20 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 20 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 20 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 20 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%|▏ | 21/19692 [00:05<1:13:53, 4.44it/s]
Step 21 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 21 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 21 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 21 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 21 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 21 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%|▏ | 22/19692 [00:05<
Step 22 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 22 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 22 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 22 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 22 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 22 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%|▏ | 23/19692 [00:05<1:13:58, 4.43it/s]
Step 23 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 23 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 23 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 23 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 23 (after backward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 23 (after training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
0%|▏ | 24/19692 [00:05<1:13:53, 4.44it/s]
Step 24 (before training_step): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 24 (before forward pass): Allocated: 4.59 GB, Reserved: 12.40 GB
Step 24 (after forward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
Step 24 (before backward pass): Allocated: 7.95 GB, Reserved: 12.40 GB
[2024-06-16 19:31:36,035] [WARNING] [stage3.py:2069:step] 1 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 24 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 24 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 25/19692 [00:06<1:28:27, 3.71it/s]
Step 25 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 25 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 25 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 25 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:36,540] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 25 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 25 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 26/19692 [00:06<1:51:34, 2.94it/s]
Step 26 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 26 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 26 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 26 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:37,045] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 26 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 26 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 27/19692 [00:07<2:07:42, 2.57it/s]
Step 27 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 27 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 27 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 27 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:37,550] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 27 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 27 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 28/19692 [00:07<2:19:03, 2.36it/s]
Step 28 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 28 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 28 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 28 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:38,055] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrim
ental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consid
er adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 28 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 28 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 29/19692 [00:08<2:26:58, 2.23it/s]
Step 29 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 29 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 29 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 29 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:38,559] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrim
ental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consid
er adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 29 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 29 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 30/19692 [00:08<2:32:25, 2.15it/s]
Step 30 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 30 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 30 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 30 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:39,063] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrim
ental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consid
er adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 30 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 30 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 31/19692 [00:09<2:36:15, 2.10it/s]
Step 31 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 31 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 31 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 31 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:39,569] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrim
ental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consid
er adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 31 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 31 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 32/19692 [00:09<2:39:02, 2.06it/s]
Step 32 (before training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 32 (before forward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 32 (after forward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
Step 32 (before backward pass): Allocated: 11.40 GB, Reserved: 13.40 GB
[2024-06-16 19:31:40,073] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrim
ental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consid
er adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 32 (after backward pass): Allocated: 8.05 GB, Reserved: 12.03 GB
Step 32 (after training_step): Allocated: 8.05 GB, Reserved: 12.03 GB
0%|▏ | 33/19692 [00:10<2:40:53, 2.04it/s]
It worked! It appeared that PyTorch caches are not a problem, so long as you set that environment variable. But then, why wasn't it set by default?
You can see some warnings cutting in at iteration 24; with the appropriate settings, PyTorch won't use VRAM for caches if there isn't enough VRAM, but it will complain that it's not got enough to run efficiently. That's totally reasonable.
One final test: previously I'd determined that with a sequence length of 10 tokens, VRAM usage was essentially what was predicted by the DeepSpeed helper function. So what would happen if I set the sequence length to 10, and capped VRAM usage at the predicted size, 8.36 GiB? With 24 GiB VRAM, that should be 0.355 in the call to set_per_process_memory_fraction.
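As a rough sketch of that calculation (not the code I actually ran): the fraction is just the target size divided by the card's total memory. The helper names `memory_fraction` and `cap_process_memory` are made up for illustration; `torch.cuda.get_device_properties` and `torch.cuda.set_per_process_memory_fraction` are the real PyTorch calls. The total that the device reports may be slightly different from the nominal figure on the box, so the result may not exactly match a back-of-the-envelope division.

```python
def memory_fraction(target_gib, total_gib):
    """Fraction of total device memory that corresponds to target_gib (capped at 1.0)."""
    if total_gib <= 0:
        raise ValueError("total_gib must be positive")
    return min(1.0, target_gib / total_gib)


def cap_process_memory(target_gib, device=0):
    """Hypothetical helper: limit this process to roughly target_gib of VRAM."""
    import torch  # imported lazily so the arithmetic above works without a GPU

    # total_memory is reported in bytes; convert to GiB
    total_gib = torch.cuda.get_device_properties(device).total_memory / 2**30
    fraction = memory_fraction(target_gib, total_gib)
    torch.cuda.set_per_process_memory_fraction(fraction, device)
    return fraction
```

Calling something like `cap_process_memory(8.36)` early in the training script, before any large allocations, would apply the limit.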
I tried it, and interestingly enough, it didn't work -- a CUDA out-of-memory error. Fiddling around with the number determined that 0.51 was enough -- just over 12 GiB. Running with that:
Step 43 (before training_step): Allocated: 8.05 GB, Reserved: 12.01 GB
Step 43 (before forward pass): Allocated: 8.05 GB, Reserved: 12.01 GB
Step 43 (after forward pass): Allocated: 8.06 GB, Reserved: 12.03 GB
Step 43 (before backward pass): Allocated: 8.06 GB, Reserved: 12.03 GB
[2024-06-16 19:55:21,591] [WARNING] [stage3.py:2069:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrim
ental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consid
er adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Step 43 (after backward pass): Allocated: 8.05 GB, Reserved: 12.01 GB
Step 43 (after training_step): Allocated: 8.05 GB, Reserved: 12.01 GB
Sadly I've lost the code I used for that so I don't know if it was running in 32-bit, PyTorch-managed 16-bit, or DeepSpeed-managed 16-bit, but it's still reasonably close.
As a sanity-check, I decided to try that memory limit with 2048 token sequences -- as I expected, it ran out of memory at iteration 25, which was a relief: it would have broken my entire mental model if it had worked.
Time to put an end to this madness
At this point I decided to wrap up this set of somewhat disorganised exploratory experiments.
The question I wanted to answer was "why does that DeepSpeed script that estimates memory usage report that I need X GiB of VRAM to train my model, but in reality I need more than that?"
What I learned is that [later: note that I had to re-evaluate much of this in the next post, so there are mistakes here]:
- There are two kinds of memory usage: "allocated" and "reserved". Reserved is a superset of allocated.
- The difference between the two appeared to be explained by PyTorch caches.
- PyTorch is smart enough that if there is not enough VRAM for the caches, it will purge them at appropriate times, though you need to set an environment variable to tell it that doing so is OK, and obviously it will run more slowly. It did seem odd that this wasn't the default, though.
- The sequence length matters a lot when you're talking about memory usage. This is obvious in retrospect, as I had to adjust batch size downwards to fit into memory when I increased the number of tokens back in one of my earlier posts.
- Even with trivially short sequence lengths, the caches use up quite a lot of VRAM.
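For reference, here is a minimal sketch of how "Allocated" and "Reserved" lines like the ones in the logs above could be produced. It is not necessarily the code used for these runs: the helper names `format_memory_line` and `log_memory` are hypothetical, and it assumes that "GB" in the output means 2**30 bytes. `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved` are the real PyTorch calls for the two numbers.

```python
GIB = 2**30  # assumption: the "GB" in the log lines is really GiB


def format_memory_line(step, label, allocated_bytes, reserved_bytes):
    """Build a log line in the same shape as the ones shown in this post."""
    return (
        f"Step {step} ({label}): "
        f"Allocated: {allocated_bytes / GIB:.2f} GB, "
        f"Reserved: {reserved_bytes / GIB:.2f} GB"
    )


def log_memory(step, label, device=0):
    """Hypothetical helper: print current allocated/reserved VRAM for one device."""
    import torch  # imported lazily so the formatter works without a GPU

    print(
        format_memory_line(
            step,
            label,
            torch.cuda.memory_allocated(device),  # bytes held by live tensors
            torch.cuda.memory_reserved(device),   # bytes held by the caching allocator
        )
    )
```

The gap between the two numbers is the memory that the caching allocator is holding on to but that is not currently backing any tensor.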
My underlying question was, can I run that fine-tune of the larger 8B model on the multi-GPU machine without having to offload the optimizer? The answer to that is still uncertain. The DeepSpeed script was saying that I'd need 17.68 GiB per GPU, but that isn't allowing for the sequence length, nor for caches.
I think it would be interesting to try that run again with
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
...to try to keep the cache usage under control. And, of course, it would be really interesting to try to work out how to calculate the effects of sequence lengths on the VRAM required. I think the second of those might be a good next step.
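If exporting the variable in the shell is awkward (for example, when launching from a notebook), one alternative is to set it from Python. This is only a sketch, and the function name `configure_allocator` is made up; the assumption is that the setting needs to be in place before PyTorch initialises its CUDA allocator, so the safest place for it is at the very top of the script, before torch is imported.

```python
import os


def configure_allocator(conf="expandable_segments:True"):
    """Set the PyTorch CUDA allocator config for this process and its children."""
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = conf
    return conf


# Call this before importing torch so the allocator sees the setting.
configure_allocator()
```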