
Commit 64fca5d

reviewer feedback

Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent: c3465d0

4 files changed: +16 -8 lines changed

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 5 additions & 5 deletions

@@ -525,12 +525,12 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule):
         )
 
     def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]:
-        batch_size_dyn = Dim.DYNAMIC
-        seq_len_dyn = Dim.DYNAMIC
+        batch_size_dynamic = Dim.DYNAMIC
+        seq_len_dynamic = Dim.DYNAMIC
         return {
-            "input_ids": {0: batch_size_dyn, 1: seq_len_dyn},
-            "inputs_embeds": {0: batch_size_dyn, 1: seq_len_dyn},
-            "position_ids": {0: batch_size_dyn, 1: seq_len_dyn},
+            "input_ids": {0: batch_size_dynamic, 1: seq_len_dynamic},
+            "inputs_embeds": {0: batch_size_dynamic, 1: seq_len_dynamic},
+            "position_ids": {0: batch_size_dynamic, 1: seq_len_dynamic},
         }
 
     @classmethod
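For context, `Dim.DYNAMIC` is `torch.export`'s marker for a dimension whose size should not be specialized. A minimal sketch of how a lookup like the one above is typically consumed, assuming PyTorch 2.5+ (where `Dim.DYNAMIC` is available); the `ToyEmbedder` model is hypothetical, not part of this commit:

```python
import torch
import torch.nn as nn
from torch.export import Dim, export


class ToyEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 32)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.emb(input_ids)


# dim 0 (batch) and dim 1 (sequence) are marked dynamic, mirroring the
# {0: batch_size_dynamic, 1: seq_len_dynamic} entries in the diff above
dynamic_shapes = {"input_ids": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}}

ep = export(
    ToyEmbedder(),
    (torch.zeros(2, 8, dtype=torch.long),),
    dynamic_shapes=dynamic_shapes,
)
print(ep.graph_module.graph)  # batch and sequence dims appear as symbolic sizes
```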

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py

Lines changed: 7 additions & 2 deletions

@@ -88,7 +88,7 @@ def set_exact_signature(mod: nn.Module, kwargs: Dict[str, Any]):
 
     reset_signature = False
     if hasattr(forward_func, "__signature__"):
-        signature_attribute = mod.forward.__signature__
+        signature_attribute = forward_func.__signature__
         reset_signature = True
 
     # construct signature object from kwargs
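The fix reads `__signature__` from the same object the preceding `hasattr` check inspected, rather than from `mod.forward`, which may be a different object (e.g., a bound method wrapping a patched function). A toy illustration, independent of this repo, of why `__signature__` matters: it is the hook `inspect.signature` consults, so it must be saved from and restored on the object that actually carries it.

```python
import inspect


def forward(x, y=1):
    return x + y


# override the reported signature on the function object itself
forward.__signature__ = inspect.Signature(
    [inspect.Parameter("x", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
)

print(inspect.signature(forward))  # (x) -- the override, not (x, y=1)
```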
@@ -139,8 +139,13 @@ def _apply_to_full_model(
     # independent, which would conflict with graph capture logic, i.e., you cannot graph-capture
     # "model" and "model.text_model" for example. However, you can export "model.text_model" and
     # "model.vision_model" separately.
+    def _is_child(child: str, parent: str) -> bool:
+        """Check whether ``child`` is a child of ``parent``."""
+        # covers "a.b" being a parent of "a.b.c", or parent being "", i.e., root (a parent of all!)
+        return parent == "" or child.startswith(f"{parent}.")
+
     sub_keys = [info.submodule_name for info in export_infos]
-    assert all(not k1.startswith(k2) for k1 in sub_keys for k2 in sub_keys if k1 != k2), (
+    assert all(not _is_child(k1, k2) for k1 in sub_keys for k2 in sub_keys if k1 != k2), (
         f"Cannot export submodules of already exported submodules, {sub_keys=}"
     )
 
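The helper fixes a false positive in the old check: a bare `startswith` flags sibling submodules whose names merely share a prefix. A standalone demonstration (the module names below are illustrative):

```python
def _is_child(child: str, parent: str) -> bool:
    # "" is the root module and hence a parent of everything;
    # otherwise require a dotted-path prefix, not just a string prefix
    return parent == "" or child.startswith(f"{parent}.")


# false positive under plain startswith: siblings sharing a name prefix
assert "model.text_model_2".startswith("model.text_model")      # wrong verdict
assert not _is_child("model.text_model_2", "model.text_model")  # correct

# true child relationships are still caught
assert _is_child("model.text_model.embed", "model.text_model")
assert _is_child("model.vision_model", "")  # root is a parent of all
```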

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py

Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ def _apply_to_full_model(
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+    ) -> Tuple[nn.Module, TransformInfo]:
         # Register profiler attn operator
         ALL_ATTENTION_FUNCTIONS.register("ad_profile_mha", fake_profiler_mha)
 
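Widening the return annotation is safe with respect to existing callers: `torch.fx.GraphModule` subclasses `nn.Module`, so every value the transform previously returned still satisfies the new annotation, while plain (un-exported) modules become valid too. A quick standalone check:

```python
import torch.nn as nn
from torch.fx import GraphModule

# GraphModule is-a nn.Module, so Tuple[nn.Module, ...] is a strict widening
assert issubclass(GraphModule, nn.Module)
```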

tensorrt_llm/_torch/auto_deploy/transformations/_graph.py

Lines changed: 3 additions & 0 deletions

@@ -137,6 +137,9 @@ def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None:
     # get device
     device = torch.device(device)
 
+    # move the model to the device
+    mod.to(device)
+
     for _, subgm in reversed(list(named_graphmodules(mod))):
         # recompile graph to update self generated codes in subgraph
         _move_single_gm_to_device(subgm, device)
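The added `mod.to(device)` moves all registered parameters and buffers up front; the per-subgraph pass then handles what `nn.Module.to` cannot see, such as device references baked into FX-generated forward code. A minimal sketch of this two-step pattern under simplified assumptions (the repo's `named_graphmodules` and `_move_single_gm_to_device` do more than the `recompile()` shown here):

```python
import torch
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace


def move_to_device(mod: nn.Module, device: torch.device) -> None:
    # step 1: recursively move parameters/buffers of all submodules
    mod.to(device)
    # step 2: fix up each FX subgraph after the move
    for _, sub in mod.named_modules():
        if isinstance(sub, GraphModule):
            sub.recompile()  # regenerate forward() code from the graph


gm = symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
move_to_device(gm, torch.device("cpu"))
print(next(gm.parameters()).device)  # cpu
```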
