
Commit 90c4075

[Tracing] Better runtime error messages (#1307)
## Purpose ##
* Add better exception messages when encountering tracing errors

## Example ##
* Below is an example of a potential tracing runtime error (this particular error was forced for demonstration purposes)

````
Traceback (most recent call last):
  File "/home/kyle/llm-compressor/src/llmcompressor/pipelines/sequential/helpers.py", line 45, in forward
    outputs = forward_fn(*args, **kwargs)
  File "<string>", line 12, in forward
TypeError: iter(v, w): v must be callable

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kyle/llm-compressor/src/llmcompressor/modifiers/quantization/gptq/base.py", line 234, in on_initialize
    run_sequential(
  File "/home/kyle/llm-compressor/src/llmcompressor/pipelines/sequential/pipeline.py", line 67, in run_pipeline
    subgraph.forward(model, **inputs)
  File "/home/kyle/llm-compressor/src/llmcompressor/pipelines/sequential/helpers.py", line 47, in forward
    raise RuntimeError(
RuntimeError: Raised an exception during execution of the following code:
```
1 
2 
3 
4 def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
5     model_rotary_emb_inv_freq = self.model.rotary_emb.inv_freq
6     getitem_10 = model_rotary_emb_inv_freq[(None, slice(None, None, None), None)];  model_rotary_emb_inv_freq = None
7     model_embed_tokens = self.model.embed_tokens(input_ids);  input_ids = None
8     size_3 = attention_mask.size();  size_3 = None
9     dim = attention_mask.dim()
10     size_6 = attention_mask.size()
11     getitem_8 = attention_mask[(slice(None, None, None), None, None, slice(None, None, None))]
12     iter_6 = iter(attention_mask, 'device');  attention_mask = None
13     float_1 = getitem_10.float();  getitem_10 = None
14     size = model_embed_tokens.size()
15     iter_1 = iter(model_embed_tokens, 'device')
```
````

## Changes ##
* Move the forward call inside the `Subgraph` class and wrap it in order to catch and propagate exceptions

---------

Signed-off-by: Kyle Sayers <[email protected]>
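For illustration, the following is a minimal, self-contained sketch of the catch-and-annotate pattern this change introduces. The toy `forward_src` string and its deliberate `iter` misuse are hypothetical, chosen to mirror the forced error above; `add_line_numbers` matches the helper added in the diff below.

```python
# Hypothetical reproduction of the new error-reporting behavior: execute
# generated code and, if it fails, re-raise with the numbered source attached.
forward_src = "def forward(self, x):\n    return iter(x, 'device')\n"


def add_line_numbers(text: str) -> str:
    # prefix each line of the generated source with its 1-indexed line number
    return "\n".join(f"{i + 1} {line}" for i, line in enumerate(text.splitlines()))


namespace: dict = {}
exec(forward_src, namespace)
forward_fn = namespace["forward"]

try:
    forward_fn(None, [1, 2, 3])  # iter(list, sentinel) raises TypeError
except Exception as exception:
    # chain the original exception so the root cause stays in the traceback
    raise RuntimeError(
        "Raised an exception during execution of the following code:\n"
        f"```\n{add_line_numbers(forward_src)}\n```"
    ) from exception
```

Running this prints both the original `TypeError` and the annotated `RuntimeError`, chained via `__cause__`, just as in the example above.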
1 parent 1f2b796 commit 90c4075

File tree

2 files changed: +33 -12 lines changed

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 31 additions & 7 deletions
```diff
@@ -1,11 +1,12 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union
 
 from compressed_tensors import has_offloaded_params
 from compressed_tensors.quantization import find_name_or_class_matches
 from torch.fx import Graph, GraphModule, Node
+from torch.fx.graph import PythonCode
 from torch.fx.proxy import Argument
 from torch.nn import Module
 from transformers import PreTrainedModel
@@ -32,16 +33,33 @@ class Subgraph:
     graph: Graph
     input_names: Set[str]
     consumed_names: Set[str]
+    _code: Optional[PythonCode] = None
 
-    def compile_forward(self) -> Callable[[Any], Any]:
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:
         """
-        Generate and compile code for executing this subgraph
+        Execute the operations within the subgraph
 
-        :return: function which, when called, executes this subgraph
+        :param \\*args: argument inputs to subgraph forward function
+        :param \\**kwargs: keyword inputs to subgraph forward function
+        :return: keyword outputs of subgraph forward function (non-consumed variables)
         """
-        code = self.graph.python_code("self")
-        exec(code.src, code.globals)
-        return code.globals.get("forward")
+        if self._code is None:
+            self._code = self.graph.python_code("self")
+            exec(self._code.src, self._code.globals)
+
+        forward_fn = self._code.globals.get("forward")
+
+        try:
+            outputs = forward_fn(*args, **kwargs)
+        except Exception as exception:
+            raise RuntimeError(
+                "Raised an exception during execution of the following code:\n"
+                f"```\n{add_line_numbers(self._code.src)}\n```\n"
+                "This is likely due to a violation of shape assumptions made when "
+                "tracing"
+            ) from exception
+
+        return outputs
 
 
 def trace_subgraphs(
@@ -376,3 +394,9 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
         for name, module in model.named_modules()
         if find_name_or_class_matches(name, module, target_names)
     )
+
+
+def add_line_numbers(text: str) -> str:
+    lines = text.splitlines()
+    numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)]
+    return "\n".join(numbered_lines)
```
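As a quick illustration (hypothetical snippet, not part of the diff), `add_line_numbers` simply prefixes each line of the generated source, so the numbers in the error message line up with the `File "<string>", line N` frames in the traceback:

```python
src = "def forward(self, x):\n    return x.size()\n"
print(add_line_numbers(src))
# 1 def forward(self, x):
# 2     return x.size()
```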

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -61,13 +61,10 @@ def run_pipeline(
         calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
         prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
 
-        # compile subgraph forward function
-        forward_function = subgraph.compile_forward()
-
         # do a preliminary pass to trigger modifier hooks
         for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
             inputs = intermediates.fetch(batch_index, subgraph.input_names)
-            forward_function(model, **inputs)
+            subgraph.forward(model, **inputs)
 
         # TODO: replace with a lifecycle event
         if callback_modifier:
@@ -78,7 +75,7 @@ def run_pipeline(
         with HooksMixin.disable_hooks():
             for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
                 inputs = intermediates.fetch(batch_index, subgraph.input_names)
-                output = forward_function(model, **inputs)
+                output = subgraph.forward(model, **inputs)
 
                 if subgraph_index < num_subgraphs - 1:
                     intermediates.update(batch_index, output)
```
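With compilation now cached on the dataclass, call sites simply invoke `subgraph.forward(...)` once per batch: the first call compiles and caches `self._code`, and later calls reuse it. A self-contained sketch against the updated `helpers.py` (the `Tiny` module and the tensors are hypothetical):

```python
import torch
from torch.fx import symbolic_trace

from llmcompressor.pipelines.sequential.helpers import Subgraph


class Tiny(torch.nn.Module):
    def forward(self, x):
        return {"y": x + 1}


model = Tiny()
traced = symbolic_trace(model)
subgraph = Subgraph(graph=traced.graph, input_names={"x"}, consumed_names=set())

out = subgraph.forward(model, x=torch.ones(2))   # first call compiles and caches
out = subgraph.forward(model, x=torch.zeros(2))  # subsequent calls reuse self._code
```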
