Implement debug mode for GarbageCollection (#1230)
Summary:
When the debug mode is turned on, 1) `warn_tensor_cycles()` will be called
on rank 0, and 2) `gc.collect()` will be called every iteration, to help
diagnose possible memory (tensor) leaks.
Reference: https://pytorch.org/blog/understanding-gpu-memory-2/
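A minimal sketch of how such a debug mode can be wired up (class layout and parameter names are illustrative assumptions, not the exact torchtitan implementation; only `warn_tensor_cycles()` and the per-iteration `gc.collect()` come from the summary above):
```python
import gc

import torch.distributed as dist
from torch.utils.viz._cycles import warn_tensor_cycles


class GarbageCollection:
    """Illustrative GC controller with an optional debug mode (hypothetical sketch)."""

    def __init__(self, gc_freq: int = 50, debug: bool = False):
        self.gc_freq = gc_freq
        self.debug = debug
        gc.disable()  # take manual control of cycle collection
        if debug and (not dist.is_initialized() or dist.get_rank() == 0):
            # On rank 0, warn (and dump an HTML visualization) whenever a
            # collected reference cycle contains a CUDA tensor.
            warn_tensor_cycles()

    def run(self, step: int) -> None:
        if self.debug:
            # Force a full collection every iteration so leaked tensors and
            # their reference cycles surface immediately in the logs.
            gc.collect()
        elif step > 1 and step % self.gc_freq == 0:
            gc.collect(1)  # periodic collection of the young generation only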
The current TorchTitan shows memory leakage when running:
```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=20
```
```
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - step: 1 loss: 12.2721 memory: 42.16GiB(44.38%) tps: 1,677 tflops: 97.12 mfu: 9.82%
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-05-27 20:44:33,456 - root - INFO - step: 10 loss: 10.0301 memory: 68.72GiB(72.34%) tps: 5,040 tflops: 291.86 mfu: 29.51%
[rank0]:[titan] 2025-05-27 20:44:48,141 - root - INFO - step: 20 loss: 8.4547 memory: 90.29GiB(95.03%) tps: 5,579 tflops: 323.12 mfu: 32.67%
[rank0]:[titan] 2025-05-27 20:44:48,150 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:44:50,152 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:44:50,569 - root - INFO - Process group destroyed.
```
With this PR, we can use the following command to enable the debug mode:
```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=20 --training.gc_debug
```
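On the config side, the `--training.gc_debug` flag presumably maps to a boolean under the `training` section; a hypothetical sketch of the corresponding dataclass field (only `gc_debug` is taken from the flag above, the other field names and defaults are assumptions):
```python
from dataclasses import dataclass


@dataclass
class Training:
    # Existing fields elided; names and defaults below are illustrative only.
    steps: int = 10000
    gc_freq: int = 50        # run a periodic gc.collect(1) every N steps in normal mode
    gc_debug: bool = False   # --training.gc_debug: enable the GC debug mode
```
Running the command with the flag enabled produces: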
```
[rank0]:[titan] 2025-05-27 20:49:04,858 - root - INFO - Force GC to perform collection to get the debug information.
[rank0]:[rank0]:W0527 20:49:05.414000 2423031 torch/utils/viz/_cycles.py:498] Reference cycle includes a CUDA Tensor see visualization of cycle /tmp/tmp9oginps8.html
[rank0]:[rank0]:W0527 20:49:05.687000 2423031 torch/utils/viz/_cycles.py:59] CUDA Memory changed during GC, 2147483648 bytes freed.
[rank0]:[titan] 2025-05-27 20:49:07,157 - root - INFO - step: 20 loss: 8.3943 memory: 49.66GiB(52.27%) tps: 3,573 tflops: 206.93 mfu: 20.92%
[rank0]:[titan] 2025-05-27 20:49:07,167 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:49:09,169 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:49:10,198 - root - INFO - Process group destroyed.
```
`warn_tensor_cycles()` shows that 1) there are reference cycles that
include CUDA tensors, and 2) 2 GiB of GPU memory is freed when `gc.collect()`
is called. The visualization suggests the reference cycle comes from
activation checkpointing.
<img width="1597" alt="Screenshot 2025-05-27 at 8 52 00 PM"
src="https://github.com/user-attachments/assets/2e241baa-16fe-4e87-acce-fb72710babc2"
/>
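For context, a forced `gc.collect()` can free GPU memory because a tensor caught in a Python reference cycle is released only by the cycle collector, not by reference counting. A standalone repro sketch (not from the PR) illustrating the mechanism:
```python
import gc

import torch


class Holder:
    pass


def leak_one_gib():
    a, b = Holder(), Holder()
    a.other, b.other = b, a                       # reference cycle: a -> b -> a
    a.tensor = torch.empty(2**28, device="cuda")  # 1 GiB of fp32 trapped in the cycle


before = torch.cuda.memory_allocated()
leak_one_gib()   # the locals are gone, but the cycle keeps the tensor alive
print(torch.cuda.memory_allocated() - before)     # ~1 GiB still allocated
gc.collect()     # the cycle collector breaks the cycle and frees the tensor
print(torch.cuda.memory_allocated() - before)     # back to ~0
```
With `warn_tensor_cycles()` enabled, collecting such a cycle is what triggers the "Reference cycle includes a CUDA Tensor" warning and the HTML visualization seen in the log above.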