Commit 3c67013
Tutorial for DebugMode
1 parent 7f8b6dc

1 file changed: 265 additions, 0 deletions
# -*- coding: utf-8 -*-

"""
DebugMode: Recording Dispatched Operations and Numerical Debugging
======================================================================

**Authors:** Pian Pawakapan, Shangdi Yu
"""

######################################################################
# Overview
# --------
#
# ``DebugMode`` (:class:`torch.utils._debug_mode.DebugMode`) is a
# ``TorchDispatchMode`` that intercepts PyTorch runtime calls and emits a
# hierarchical log of operations. It is particularly useful when you need to
# understand *what* actually runs, both in eager mode and under ``torch.compile``,
# or when you need to pinpoint numerical divergence between two runs.
#
# Key capabilities:
#
# * **Runtime logging** – Records dispatched operations and TorchInductor-compiled
#   Triton kernels.
# * **Tensor hashing** – Attaches deterministic hashes to inputs/outputs, enabling
#   you to diff runs and locate numerical divergences.
# * **Dispatch hooks** – Allows registration of custom hooks to annotate calls.
#
# .. note::
#
#    This recipe describes a prototype feature. Prototype features are typically
#    at an early stage for feedback and testing, and are subject to change.
#

######################################################################
# Quick start
# -----------
#
# The snippet below captures a small eager workload and prints the debug string:

import torch
from torch.utils._debug_mode import DebugMode

def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)

with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())


######################################################################
# Getting more metadata
# ---------------------
#
# For most investigations, you'll want to enable stack traces, tensor IDs, and
# tensor hashing. These features provide the metadata needed to correlate
# operations back to model code.
#
# ``DebugMode.log_tensor_hashes`` decorates the log with hashes for every call.
# The ``hash_tensor`` hash function uses ``torch.hash_tensor``, which returns 0
# for tensors whose elements are all the same. The ``norm`` hash function uses
# the norm with ``p=1``. With both functions, and especially with ``norm``,
# tensors that are numerically close produce hashes that are close, which makes
# the hashes easy to interpret. The default ``hash_fn`` is ``norm``.

with (
    DebugMode(
        # record_stack_trace is only supported for eager mode in PyTorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"],  # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(debug_mode.debug_string(show_stack_trace=True))

######################################################################
# Each line follows ``op(args) -> outputs``. When ``record_ids`` is enabled,
# tensors are suffixed with ``$<id>`` and DTensors are labeled ``dt``.

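######################################################################
# As a rough illustration of why the default ``norm`` hash is interpretable,
# the sketch below (an illustration of the idea, not DebugMode's exact
# implementation) compares L1 norms of a tensor, a tiny perturbation of it,
# and a heavily scaled copy: numerically close tensors yield close "hashes",
# while a large numerical change moves the value far away.

base = torch.randn(8, 8)
close = base + 1e-6 * torch.randn(8, 8)  # numerically close to ``base``
far = base * 100.0                       # numerically very different

print("L1-norm 'hash' of base: ", base.abs().sum().item())
print("L1-norm 'hash' of close:", close.abs().sum().item())
print("L1-norm 'hash' of far:  ", far.abs().sum().item())
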
######################################################################
# Log Triton kernels
# ------------------
#
# Although Triton kernels are not dispatched, DebugMode has custom logic that
# logs their inputs and outputs.
#
# Inductor-generated Triton kernels show up with a ``[triton]`` prefix.
# Pre/post hash annotations report buffer hashes around each kernel call, which
# is helpful when isolating incorrect kernels.
def f(x):
    return torch.mm(torch.relu(x), x.T)

# This example requires a CUDA device so that Inductor generates Triton kernels.
x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_inputs=True,
    )
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())

######################################################################
# Numerical debugging with tensor hashes
# --------------------------------------
#
# If you see numerical divergence between two modes, you can use DebugMode to
# find where it originates. In the example below, all tensor hashes match
# between eager mode and compiled mode. If any hash differed, that call would
# be the source of the divergence.

def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)

inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())

###############################################################################################
# Now let's look at an example where the tensor hashes differ. We intentionally
# register a wrong decomposition that decomposes ``cos`` into ``sin``, which
# causes numerical divergence.

from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func

def wrong_decomp(x):
    return torch.sin(x)

decomp_table = {}
decomp_table[torch.ops.aten.cos.default] = wrong_decomp

backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table,
)

def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z

x = torch.randn(3, 3)
with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string())
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())

###############################################################################################
# In the eager log we see ``aten::cos``, but in the compiled log we see
# ``aten::sin``, and the output hash differs between eager and compiled mode.
# Diffing the two logs shows that the first numerical divergence appears at the
# ``aten::cos`` call.

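###############################################################################################
# To compare the two logs programmatically rather than by eye, a plain text diff
# of the debug strings works well. The sketch below uses Python's standard
# ``difflib`` (not a DebugMode API); the first differing line with a hash marks
# where the divergence begins.

import difflib

diff = difflib.unified_diff(
    dm_eager.debug_string().splitlines(),
    dm_compiled.debug_string().splitlines(),
    fromfile="eager",
    tofile="compiled",
    lineterm="",
)
print("\n".join(diff))
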
######################################################################
# Custom dispatch hooks
# ---------------------
#
# Hooks allow you to annotate each call with custom metadata, such as GPU memory
# usage. The ``log_hook`` callable is invoked for every dispatched call, and the
# mapping it returns is rendered inline in the debug string.

MB = 1024 * 1024.0

def memory_hook(func, types, args, kwargs, result):
    mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0
    peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}

with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())

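######################################################################
# Hooks are not limited to memory statistics. As another sketch (the hook below
# is illustrative, not part of DebugMode), ``shape_hook`` annotates each call
# whose result is a tensor with its dtype and shape, using the same
# ``(func, types, args, kwargs, result)`` signature as above.

def shape_hook(func, types, args, kwargs, result):
    if isinstance(result, torch.Tensor):
        return {"dtype": str(result.dtype), "shape": str(tuple(result.shape))}
    return {}

with (
    DebugMode() as dm_shapes,
    DebugMode.dispatch_hooks(log_hook=shape_hook),
):
    run_once()

print("DebugMode output with dtype/shape annotations:")
print(dm_shapes.debug_string())
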
######################################################################
# Module boundaries
# -----------------
#
# ``record_nn_module=True`` inserts ``[nn.Mod]`` markers that show which module
# executed each set of operations. As of PyTorch 2.10, this only works in eager
# mode; support for compiled modes is under development.

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))

mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(
    record_nn_module=True, record_stack_trace=True, record_output=False
) as debug_mode:
    _ = mod(inp)

print("DebugMode output with stack traces and module boundaries:")
print(debug_mode.debug_string(show_stack_trace=True))

######################################################################
# Conclusion
# ----------
#
# DebugMode gives you a lightweight, runtime-only view of what PyTorch actually
# executed, whether you are running eager code or compiled graphs. By layering
# tensor hashing, Triton kernel logging, and custom dispatch hooks, you can
# quickly track down numerical differences, which is especially helpful when
# checking bit-wise equivalence between runs.
