Commit 6e85e8a

Tutorial for DebugMode

1 parent 7f8b6dc

1 file changed: +234 additions, 0 deletions
# -*- coding: utf-8 -*-

"""
DebugMode: Recording Dispatched Operations and Numerical Debugging
===================================================================

**Authors:** Pian Pawakapan, Shangdi Yu
"""


######################################################################
# Overview
# --------
#
# ``DebugMode`` (:class:`torch.utils._debug_mode.DebugMode`) is a
# ``TorchDispatchMode`` that intercepts PyTorch runtime calls and emits a
# hierarchical log of operations. It is particularly useful when you need to
# understand *what* actually ran, both in eager mode and under ``torch.compile``,
# or when you need to pinpoint numerical divergence between two runs.
#
# Key capabilities:
#
# * **Runtime logging** – Records dispatched operations and TorchInductor-compiled
#   Triton kernels.
# * **Tensor hashing** – Attaches deterministic hashes to inputs/outputs so you
#   can diff runs and locate numerical divergences.
# * **Dispatch hooks** – Lets you register custom hooks to annotate each call.
######################################################################
# Quick start
# -----------
#
# The snippet below captures a small eager workload and prints the debug string:

import torch
from torch.utils._debug_mode import DebugMode


def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)


with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())
######################################################################
# Getting more metadata
# ---------------------
#
# For most investigations, you'll want to enable stack traces, tensor IDs, and
# tensor hashing. These features provide the metadata needed to correlate
# operations back to model code.
#
# ``DebugMode.log_tensor_hashes`` decorates the log with hashes for every call.
# The ``hash_tensor`` hash function uses ``torch.hash_tensor``, which returns 0
# for tensors whose elements are all the same. The ``norm`` hash function uses
# the L1 norm (``norm`` with ``p=1``).

with (
    DebugMode(
        record_output=True,
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm", "hash_tensor"],
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(debug_mode.debug_string(show_stack_trace=True))
######################################################################
# Interpreting the log
# --------------------
#
# Each line follows ``op(args) -> outputs``. When ``record_ids`` is enabled,
# tensors are suffixed with ``$<id>`` and DTensors are labeled ``dt``.
#
# Indentation generally reflects the dynamic call stack, but it is not
# guaranteed to match the runtime call stack exactly, especially for DTensor
# calls.
#
# When several hash functions are requested, the hash is reported as a tuple,
# e.g. ``'hash': (25.47251951135695, 9216239975761182720)``; each element
# corresponds to one function in the ``hash_fn`` list, in order.
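######################################################################
# As a small, purely illustrative example, the snippet below filters the log
# captured above down to the lines that carry hash annotations. It uses only
# ``debug_string()`` and standard string operations; note that the exact log
# formatting is not a stable interface and may change across PyTorch versions.

for line in debug_mode.debug_string().splitlines():
    # Keep only the entries that were annotated with tensor hashes.
    if "hash" in line:
        print(line)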
######################################################################
# Log Triton kernels
# ------------------
#
# Although Triton kernels are not dispatched operations, DebugMode has custom
# logic that logs their inputs and outputs.
#
# Inductor-generated Triton kernels show up with a ``[triton]`` prefix.
# Pre/post hash annotations report buffer hashes around each kernel call, which
# is helpful when isolating incorrect kernels.


def f(x):
    return torch.mm(torch.relu(x), x.T)


# This example targets a CUDA device, since Inductor emits Triton kernels for GPUs.
x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"],
        hash_inputs=True,
    ),
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())
######################################################################
# Numerical debugging with tensor hashes
# --------------------------------------
#
# If you see numerical divergence between two modes, you can use DebugMode to
# find where the divergence originates.
# In the example below, all tensor hashes match between eager mode and compiled
# mode. If any hash looked different, that call would be where the numerical
# divergence comes from.


def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_fn=["norm"],
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)


inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())
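######################################################################
# To locate a divergence programmatically rather than by eye, you can compare
# the hash annotations from the two logs. The sketch below is purely
# illustrative: the ``extract_hashes`` helper and its regular expression are
# assumptions based on the ``'hash': (...)`` format shown earlier, not a
# stable parsing API, and the two logs may contain different numbers of
# entries when compilation decomposes operations differently.

import re

HASH_RE = re.compile(r"'hash': \(?([0-9eE+.\-]+)")


def extract_hashes(dm):
    # Pull the first numeric hash of each annotated call out of the log text.
    return HASH_RE.findall(dm.debug_string())


eager_hashes = extract_hashes(dm_eager)
compiled_hashes = extract_hashes(dm_compiled)

mismatches = [
    (i, a, b)
    for i, (a, b) in enumerate(zip(eager_hashes, compiled_hashes))
    if a != b
]
if mismatches:
    i, a, b = mismatches[0]
    print(f"First differing hash at entry {i}: eager={a}, compiled={b}")
else:
    print("All compared hash entries match.")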
######################################################################
# Custom dispatch hooks
# ---------------------
#
# Hooks allow you to annotate each call with custom metadata such as GPU memory
# usage. The function passed as ``log_hook`` returns a mapping that is rendered
# inline with the debug string.

MB = 1024 * 1024.0


def memory_hook(func, types, args, kwargs, result):
    # Report current and peak CUDA memory (in MB) for each dispatched call;
    # fall back to zeros on CPU-only machines.
    if torch.cuda.is_available():
        mem = torch.cuda.memory_allocated() / MB
        peak = torch.cuda.max_memory_allocated() / MB
        torch.cuda.reset_peak_memory_stats()
    else:
        mem = peak = 0.0
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}


with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())
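######################################################################
# Hooks are not limited to memory statistics. As another minimal sketch
# (assuming only the ``(func, types, args, kwargs, result)`` hook signature
# used above), the hypothetical ``shape_hook`` below annotates each call with
# the shapes of its tensor outputs:


def shape_hook(func, types, args, kwargs, result):
    # Record the output shape(s) as strings, mirroring the memory_hook format.
    if isinstance(result, torch.Tensor):
        return {"out_shape": str(tuple(result.shape))}
    if isinstance(result, (tuple, list)):
        shapes = [str(tuple(r.shape)) for r in result if isinstance(r, torch.Tensor)]
        return {"out_shapes": ", ".join(shapes)}
    return {}


with DebugMode() as dm, DebugMode.dispatch_hooks(log_hook=shape_hook):
    run_once()

print("DebugMode output with output shapes:")
print(dm.debug_string())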
######################################################################
# Module boundaries
# -----------------
#
# ``record_nn_module=True`` inserts ``[nn.Mod]`` markers that show which
# module executed each set of operations. As of PyTorch 2.10 it only works in
# eager mode, but support for compiled modes is under development.


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))


mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with stack traces and module boundaries:")
print(debug_mode.debug_string(show_stack_trace=True))
######################################################################
# Annotation
# ----------
#
# You can insert annotations into DebugMode logs by calling ``DebugMode._annotate``.

x = torch.randn(8, 8)


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 8)

    def forward(self, x):
        DebugMode._annotate("Foo")
        return self.l1(x)


mod = Foo()
with DebugMode(record_nn_module=True) as debug_mode:
    DebugMode._annotate("forward")
    mod(x)

print("DebugMode output with annotation:")
print(debug_mode.debug_string())
