greg-kwasniewski1
[https://github.com/NVIDIA/issues/6342][feat] Applying sharding transformations from model config

Description

If base_model_tp_plan is present in the model config and ad_config.use_sharding_from_config == True, skip sharding pattern detection and instead apply the sharding specified in the config.
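For clarity, here is a minimal sketch of the dispatch described above. The helper name and the getattr-based access are assumptions for illustration, not the actual auto_deploy API:

```python
def should_shard_from_config(model_config, ad_config) -> bool:
    """Hypothetical helper mirroring the behavior described above: return True
    when the factory-provided TP plan should be applied directly, skipping
    sharding pattern detection."""
    tp_plan = getattr(model_config, "base_model_tp_plan", None)
    return tp_plan is not None and getattr(ad_config, "use_sharding_from_config", False)


# Example of what a Hugging Face base_model_tp_plan entry looks like
# (taken from the Qwen3-MoE config referenced later in this thread):
example_tp_plan = {"layers.*.self_attn.k_proj": "colwise"}
```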

Test Coverage

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py has been updated to cover the new sharding logic.

greg-kwasniewski1 self-assigned this on Jul 24, 2025
greg-kwasniewski1 added the enhancement (New feature or request) label on Jul 24, 2025
"o_proj",
]
if any(attn_name in module_name for attn_name in attn_names):
min_local_shape = head_dim
lucaslie (Collaborator)

I see why you want head_dim here, but I don't think that's a good reason to break the factory/config <> graph transform abstraction.

The name matching is also a fragile way to infer whether min_local_shape is necessary, and it doesn't scale.

This really seems like a corner case that isn't worth addressing or complicating the code for.

If we use the factory sharding config, we should just use it and not build in extra sanity checks. The config should be executed as instructed.

greg-kwasniewski1 (Author)

@lucaslie but then we risk the KV-head problem we had with, e.g., Qwen:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py

(...)
num_key_value_heads=4,
(...)
"layers.*.self_attn.k_proj": "colwise",

So we need to prevent "sub-head" sharding. We can either get head_dim from the config or deduce it from attn_node.meta['val'].shape, but the latter option is arguably even more fragile. Which one do you prefer?
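To illustrate the failure mode with concrete numbers (num_key_value_heads comes from the excerpt above; head_dim=128 is assumed here, and the variable names are mine, not the PR's):

```python
# Qwen3-MoE-style attention shapes.
num_key_value_heads = 4
head_dim = 128  # assumed default for illustration
k_proj_out_features = num_key_value_heads * head_dim  # 512

tp_size = 8  # tensor-parallel world size
rows_per_rank = k_proj_out_features // tp_size  # 64 -> half a KV head per rank

# Naive "colwise" sharding would split each KV head across two ranks.
# Clamping the minimum local shard size to head_dim catches this:
min_local_shape = head_dim
if rows_per_rank < min_local_shape:
    print("refusing to shard k_proj below head granularity")
```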

greg-kwasniewski1 and others added 8 commits August 5, 2025 14:33
Signed-off-by: greg-kwasniewski1 <[email protected]>
@suyoggupta

As discussed offline, let's do the following:

  1. Rebase to TRTLLM-main
  2. Run trtllm-bench perf tests to make sure there are no regressions

ROW = 0 # Split along rows (first dimension)
COLUMN = 1 # Split along columns (second dimension)
# NOTE: The names COLUMN/ROW reflect the hugging face
# base_tp_plan sharding notation, but since we assume Y = W @ X^T,

nit: Y = W^T * X?
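For context (an illustration I'm adding, not code from this PR): a PyTorch nn.Linear stores its weight with shape (out_features, in_features) and computes y = x @ W.T, so Hugging Face's "colwise" plan shards the output-feature dimension, i.e. dim 0 of the stored weight. That is why the ROW/COLUMN names above can look swapped depending on which matrix formulation you write down.

```python
import torch

# Toy shapes for illustration only.
in_features, out_features, tp = 8, 6, 2
W = torch.randn(out_features, in_features)  # nn.Linear stores (out, in)
x = torch.randn(3, in_features)

# "colwise" shards the output features: slice dim 0 of the stored weight,
# compute partial outputs per rank, and concatenate along the feature dim.
shards = torch.chunk(W, tp, dim=0)
y_full = x @ W.T
y_sharded = torch.cat([x @ w.T for w in shards], dim=-1)
assert torch.allclose(y_full, y_sharded)
```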
