[feat] TP Sharding read from the model config (fixes #6342) #117
base: feat/ad-2025-07-22
Conversation
tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py (outdated review threads, resolved)
"o_proj", | ||
] | ||
if any(attn_name in module_name for attn_name in attn_names): | ||
min_local_shape = head_dim |
I see why you want head_dim here, but I don't think that's a good reason to break the factory/config <> graph transform abstraction.
The name matching is also a fragile way to infer whether min_local_shape is necessary, and it doesn't scale.
This really seems like a corner case that isn't worth addressing or complicating the code for.
If we use the factory sharding config, we should just use it and not build in extra sanity checks. The config should be executed as instructed.
@lucaslie but then we risk the KV-head problem we had with, e.g., Qwen:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py
(...)
num_key_value_heads=4,
(...)
"layers.*.self_attn.k_proj": "colwise",
So we need to prevent "sub-head" sharding. We can either get it from the config, or deduce it from attn_node.meta['val'].shape, but the latter is arguably even more fragile. Which one do you prefer?
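For concreteness, a minimal sketch of the "get it from the config" option (the field names follow the Hugging Face config linked above; the helper itself is hypothetical and not part of this PR):

# Hypothetical helper, not code from this PR: derive the minimum local shard
# size from the Hugging Face config instead of matching module names.
def min_local_shape_from_config(config, world_size: int) -> int:
    # head_dim is either given explicitly or derived from hidden_size.
    head_dim = getattr(config, "head_dim", None) or (
        config.hidden_size // config.num_attention_heads
    )
    # A "colwise" shard of k_proj/v_proj holds num_key_value_heads * head_dim
    # output features in total; refusing shards smaller than head_dim keeps
    # every KV head on a single rank.
    if config.num_key_value_heads % world_size != 0:
        raise ValueError(
            f"num_key_value_heads={config.num_key_value_heads} not divisible "
            f"by world_size={world_size}; colwise sharding would split a KV head"
        )
    return head_dim

With the Qwen3-MoE values above (num_key_value_heads=4), a tp_size of 8 would trip this check instead of silently slicing a KV head across ranks.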
Signed-off-by: greg-kwasniewski1 <[email protected]>
Force-pushed from 355b46b to 872e572
as discussed offline, let's do the following:
ROW = 0  # Split along rows (first dimension)
COLUMN = 1  # Split along columns (second dimension)
# NOTE: The names COLUMN/ROW reflect the hugging face
# base_tp_plan sharding notation, but since we assume Y = W @ X^T,
nit: Y = W^T * X?
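To make the convention question concrete, a small self-contained demo (illustrative only, not code from this PR; it assumes torch.nn.Linear semantics, i.e. y = x @ W^T with W stored as (out_features, in_features)). HF's "colwise" plan shards the output features, i.e. dim 0 of the stored weight; "rowwise" shards the input features, i.e. dim 1:

import torch

W = torch.arange(24, dtype=torch.float32).reshape(6, 4)  # (out=6, in=4)
x = torch.randn(2, 4)

# "colwise": each rank keeps a slice of the output features (rows of W);
# the local outputs are concatenated along the feature dimension.
W_rank0, W_rank1 = W.chunk(2, dim=0)
y_colwise = torch.cat([x @ W_rank0.T, x @ W_rank1.T], dim=-1)

# "rowwise": each rank keeps a slice of the input features (columns of W);
# the partial outputs are summed (an all-reduce in a real TP setup).
W_r0, W_r1 = W.chunk(2, dim=1)
x0, x1 = x.chunk(2, dim=-1)
y_rowwise = x0 @ W_r0.T + x1 @ W_r1.T

assert torch.allclose(y_colwise, x @ W.T)
assert torch.allclose(y_rowwise, x @ W.T, atol=1e-6)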
[#6342][feat] Applying sharding transformations from model config
Description
If base_model_tp_plan is present in the model config and ad_config.use_sharding_from_config == True, skip sharding-pattern detection and instead apply the sharding from the config.

Test Coverage

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py has been updated to test the new sharding logic.
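As an illustration of the described behavior (not the PR's actual implementation; the matching helper below is hypothetical, while the plan entries mirror the Hugging Face base_model_tp_plan notation quoted earlier):

import fnmatch
from typing import Optional

# Example plan in Hugging Face base_model_tp_plan notation.
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
}

def sharding_action_for(module_name: str, tp_plan: dict) -> Optional[str]:
    """Return the plan entry matching module_name, or None to leave it unsharded."""
    for pattern, action in tp_plan.items():
        if fnmatch.fnmatch(module_name, pattern):
            return action
    return None

print(sharding_action_for("layers.3.self_attn.k_proj", base_model_tp_plan))  # colwise
print(sharding_action_for("layers.3.mlp.up_proj", base_model_tp_plan))       # None

When ad_config.use_sharding_from_config is set and such a plan exists, the transform would apply these entries directly rather than running the graph-based pattern detection.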