Skip to content

Commit 802e869

Browse files
authored
[Tracing] Reinstate ignore functionality (#1423)
## Purpose ## * Fix issues with e2e testing and examples where some recipe still target vision decoder blocks in the no_split_params, and therefore trace the vision sections ## Changes ## * Add back module ignoring functionality ## Testing ## * Ran qwen2_5_vl example and demonstrated that tracing that works nows with default no_split_params Signed-off-by: Kyle Sayers <[email protected]>
1 parent ebb2b50 commit 802e869

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,16 @@ def trace_subgraphs(
8080
:param sample_input: inputs whose values will change during execution but whose
8181
__len__, __bool__, and __contains__ values are assumed constant across batches
8282
:param sequential_targets: list of patterns matching sequential targets
83-
:param ignore: TODO: unused, in the future will specify functions and methods to
84-
skip during tracing
83+
:param ignore: modules to ignore during tracing, in the future will specify
84+
functions and methods to skip during tracing
8585
:return: a list of Subgraphs in order of execution
8686
"""
8787
# find modules
8888
sequential_targets = match_modules(model, sequential_targets)
89+
ignore = match_modules(model, ignore)
8990

9091
# initialize arguments
91-
tracer = get_tracer(model, sequential_targets)
92+
tracer = get_tracer(model, sequential_targets, ignore)
9293
concrete_args = populate_concrete_args(model, sample_input)
9394

9495
# trace
@@ -118,7 +119,9 @@ def trace_subgraphs(
118119
return subgraphs
119120

120121

121-
def get_tracer(model: Module, sequential_targets: Set[Module]) -> HFTracer:
122+
def get_tracer(
123+
model: Module, sequential_targets: Set[Module], ignore: Set[Module]
124+
) -> HFTracer:
122125
"""
123126
Get a tracer specialized for the given model. The resulting tracer will not trace
124127
inside of sequential targets, nor any modules which are not call graph ancestors of
@@ -129,6 +132,8 @@ def get_tracer(model: Module, sequential_targets: Set[Module]) -> HFTracer:
129132
130133
:param model: model being traced
131134
:param sequential_targets: modules which are sequential targets
135+
:param ignore: modules to ignore during tracing, in the future will specify
136+
functions and methods to skip during tracing
132137
"""
133138
sequential_ancestors = get_sequential_ancestors(model, sequential_targets)
134139
offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m))
@@ -155,7 +160,11 @@ def create_arg(self, a: Any) -> Argument:
155160
return super().create_arg(a)
156161

157162
def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
158-
return module not in sequential_ancestors or module in offloaded_modules
163+
return (
164+
module not in sequential_ancestors
165+
or module in offloaded_modules
166+
or module in ignore
167+
)
159168

160169
def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
161170
if isinstance(root, Module):

0 commit comments

Comments
 (0)