
Commit f8f545c

Implement debug mode for GarbageCollection (#1230)
Summary: When debug mode is turned on, 1) `warn_tensor_cycles()` is called on rank 0, and 2) `gc.collect()` is called every iteration, to help diagnose possible memory (tensor) leaks. Reference: https://pytorch.org/blog/understanding-gpu-memory-2/

The current TorchTitan shows a memory leak:

```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=20
```

```
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - step: 1 loss: 12.2721 memory: 42.16GiB(44.38%) tps: 1,677 tflops: 97.12 mfu: 9.82%
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-05-27 20:44:33,456 - root - INFO - step: 10 loss: 10.0301 memory: 68.72GiB(72.34%) tps: 5,040 tflops: 291.86 mfu: 29.51%
[rank0]:[titan] 2025-05-27 20:44:48,141 - root - INFO - step: 20 loss: 8.4547 memory: 90.29GiB(95.03%) tps: 5,579 tflops: 323.12 mfu: 32.67%
[rank0]:[titan] 2025-05-27 20:44:48,150 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:44:50,152 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:44:50,569 - root - INFO - Process group destroyed.
```

With this PR, we can use the following command to debug:

```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=20 --training.gc_debug
```

```
[rank0]:[titan] 2025-05-27 20:49:04,858 - root - INFO - Force GC to perform collection to get the debug information.
[rank0]:[rank0]:W0527 20:49:05.414000 2423031 torch/utils/viz/_cycles.py:498] Reference cycle includes a CUDA Tensor see visualization of cycle /tmp/tmp9oginps8.html
[rank0]:[rank0]:W0527 20:49:05.687000 2423031 torch/utils/viz/_cycles.py:59] CUDA Memory changed during GC, 2147483648 bytes freed.
[rank0]:[titan] 2025-05-27 20:49:07,157 - root - INFO - step: 20 loss: 8.3943 memory: 49.66GiB(52.27%) tps: 3,573 tflops: 206.93 mfu: 20.92%
[rank0]:[titan] 2025-05-27 20:49:07,167 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:49:09,169 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:49:10,198 - root - INFO - Process group destroyed.
```

`warn_tensor_cycles()` shows that 1) there are reference cycles that include CUDA tensors, and 2) 2 GiB of GPU memory is freed when `gc.collect()` is called. The visualization shows that the reference cycle appears to come from activation checkpointing.

<img width="1597" alt="Screenshot 2025-05-27 at 8 52 00 PM" src="https://github.com/user-attachments/assets/2e241baa-16fe-4e87-acce-fb72710babc2" />
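For readers who want to try the underlying technique outside of TorchTitan, below is a minimal, hedged sketch of what the debug mode does: install `warn_tensor_cycles()` and force a `gc.collect()` so that reference cycles holding CUDA tensors get reported. It assumes a CUDA-capable machine; `CycleHolder` and `make_cycle` are illustrative names, not TorchTitan code.

```python
import gc

import torch
from torch.utils.viz._cycles import warn_tensor_cycles

# Install the observer that warns whenever a garbage-collected reference
# cycle contains a CUDA tensor; this is the hook gc_debug installs on rank 0.
warn_tensor_cycles()


class CycleHolder:
    """Plain container used only to build a reference cycle around a tensor."""


def make_cycle() -> None:
    holder = CycleHolder()
    holder.tensor = torch.ones(1024, 1024, device="cuda")  # ~4 MiB kept alive by the cycle
    holder.self_ref = holder  # the object references itself, forming cyclic garbage
    # `holder` goes out of scope here, but the cycle keeps the tensor alive
    # until the cyclic garbage collector runs.


make_cycle()
gc.collect()  # forcing a collection, as gc_debug does every step, surfaces the warning and frees the tensor
```

On a CUDA machine this should print a warning similar to the `Reference cycle includes a CUDA Tensor` lines above, along with an HTML visualization of the cycle.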
1 parent: 594a120 · commit: f8f545c

3 files changed: 25 additions & 7 deletions


torchtitan/config_manager.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -224,6 +224,14 @@ class Training:
     gc_freq: int = 50
     """Python garbage control scheduling interval, in steps"""
 
+    gc_debug: bool = False
+    """
+    Enable GC debugging mode. This will perform gc.collect() at every step to
+    detect if there is a reference cycle that includes a CUDA Tensor.
+    Note that you may want to lower the training steps to avoid generating too
+    many temporary files.
+    """
+
     seed: int | None = None
     """Choose the base RNG seed used for training"""
 
@@ -625,7 +633,6 @@ def parse_args(self, args: list[str] = sys.argv[1:]) -> JobConfig:
         return self.config
 
     def _maybe_load_toml(self, args: list[str]) -> dict[str, Any] | None:
-
         # 1. Check CLI
         valid_keys = {"--job.config-file", "--job.config_file"}
         for i, arg in enumerate(args):
```
torchtitan/tools/utils.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -36,14 +36,23 @@ def get_device_info():
 
 # used to avoid stragglers in garbage collection
 class GarbageCollection:
-    def __init__(self, gc_freq=1000):
+    def __init__(self, gc_freq: int = 1000, debug: bool = False):
         assert gc_freq > 0, "gc_freq must be a positive integer"
         self.gc_freq = gc_freq
+        self.debug = debug
         gc.disable()
         self.collect("Initial GC collection.")
+        if debug:
+            from torch.utils.viz._cycles import warn_tensor_cycles
 
-    def run(self, step_count):
-        if step_count > 1 and step_count % self.gc_freq == 0:
+            if torch.distributed.get_rank() == 0:
+                warn_tensor_cycles()
+
+    def run(self, step_count: int):
+        if self.debug:
+            logger.info("Force GC to perform collection to obtain debug information.")
+            gc.collect()
+        elif step_count > 1 and step_count % self.gc_freq == 0:
             self.collect("Peforming periodical GC collection.")
 
     @staticmethod
```
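To make the control flow above concrete, here is a small hedged sketch of how a training loop drives this handler: with `debug=True` every call to `run()` collects (and, on rank 0, reports CUDA-tensor cycles), while with `debug=False` collection only happens every `gc_freq` steps. It assumes `torch.distributed` is already initialized, since the debug path calls `get_rank()`; the import path is inferred from the file location, and `NUM_STEPS` / `train_one_step` are placeholders, not TorchTitan code.

```python
# Hedged usage sketch (not from the repo): drive GarbageCollection once per step.
# Assumes torch.distributed.init_process_group() has already run, because the
# debug path calls torch.distributed.get_rank().
from torchtitan.tools import utils

NUM_STEPS = 20  # keep debug runs short to limit the temporary HTML reports


def train_one_step(step: int) -> None:
    """Placeholder for the real forward/backward/optimizer work."""


gc_handler = utils.GarbageCollection(gc_freq=50, debug=True)
for step in range(1, NUM_STEPS + 1):
    gc_handler.run(step)  # debug=True: gc.collect() every step; otherwise every gc_freq steps
    train_one_step(step)
```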

torchtitan/train.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -70,9 +70,6 @@ def __init__(self, job_config: JobConfig):
         if job_config.job.print_args:
             logger.info(f"Running with args: {job_config.to_dict()}")
 
-        # take control of garbage collection to avoid stragglers
-        self.gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
-
         device_module, device_type = utils.device_module, utils.device_type
         self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
         # Device has to be set before creating TorchFT manager.
@@ -106,6 +103,11 @@ def __init__(self, job_config: JobConfig):
         if self.ft_manager.enabled:
             dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
 
+        # take control of garbage collection to avoid stragglers
+        self.gc_handler = utils.GarbageCollection(
+            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
+        )
+
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
         dist_utils.set_determinism(
```
