Commit a3406b6

[TRTLLM-5252][fix] Propagate mapping to intermediate layers
Signed-off-by: William Zhang <[email protected]>
1 parent: 3b2dd40

3 files changed: +17 -9 lines changed

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 3 additions & 0 deletions
@@ -475,6 +475,7 @@ def __init__(self, model_config: ModelConfig[Mistral3Config]):
             out_features=hidden_size,
             bias=False,
             dtype=config.torch_dtype,
+            mapping=model_config.mapping,
         )

     @torch.inference_mode()
@@ -539,13 +540,15 @@ def __init__(self, model_config: ModelConfig[Mistral3Config]):
             out_features=config.text_config.hidden_size,
             bias=config.multimodal_projector_bias,
             dtype=dtype,
+            mapping=model_config.mapping,
         )
         self.act = ACT2FN[config.projector_hidden_act]
         self.linear_2 = Linear(
             in_features=config.text_config.hidden_size,
             out_features=config.text_config.hidden_size,
             bias=config.multimodal_projector_bias,
             dtype=dtype,
+            mapping=model_config.mapping,
         )

     @torch.inference_mode()
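
For context, the `mapping` object describes the parallel topology (world size, tensor-parallel size, rank) that parallel layers consult when deciding which shard of their weights each rank owns; building these intermediate `Linear` layers without it left them outside the sharding scheme used by the rest of the model. The sketch below is illustrative only and is not TensorRT-LLM's `Linear` implementation: the class name and the column-sharding rule are assumptions, included to show the kind of decision a layer can only make once a mapping is passed in.

    # Hypothetical column-sharded linear layer (illustration, not TensorRT-LLM's Linear).
    import torch
    from torch import nn

    class ColumnShardedLinear(nn.Module):
        def __init__(self, in_features: int, out_features: int, dtype: torch.dtype, mapping):
            super().__init__()
            tp_size = mapping.tp_size
            assert out_features % tp_size == 0, "out_features must split evenly across TP ranks"
            # Each rank materializes only its slice of the output dimension.
            self.local_out_features = out_features // tp_size
            self.weight = nn.Parameter(
                torch.empty(self.local_out_features, in_features, dtype=dtype)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Produces this rank's shard of the output; the full result is
            # reassembled later (e.g. via an all-gather or a row-parallel layer).
            return x @ self.weight.T

A layer constructed without a mapping has no way of knowing its rank or TP size and must fall back to the full, unsharded weight, which is presumably why this commit threads `model_config.mapping` through the three `Linear` layers shown above.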

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
   - test_e2e.py::test_ptp_quickstart_advanced_bs1
   - test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
+  - unittest/_torch/modeling/test_modeling_pixtral.py::test_tensor_parallelism
 - condition:
     ranges:
       system_gpu_count:

tests/unittest/_torch/modeling/test_modeling_pixtral.py

Lines changed: 13 additions & 9 deletions
@@ -28,8 +28,7 @@
 pytestmark = pytest.mark.threadleak(enabled=False)


-@pytest.fixture
-def pixtral_vision_config():
+def make_pixtral_vision_config():
     # Values taken from:
     # https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/config.json
     return model_config_lib.ModelConfig(
@@ -71,9 +70,10 @@ def init_hf_model(cls, config, dtype, device):

 @torch.no_grad()
 @pytest.mark.usefixtures("set_seed")
-def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
+def test_pixtral_vision_model_vs_hf():
     dtype = torch.bfloat16
     device = torch.device("cuda")
+    pixtral_vision_config = make_pixtral_vision_config()
     pretrained_config = pixtral_vision_config.pretrained_config

     pixtral_model = (
@@ -111,13 +111,14 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config):

 @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
 @torch.no_grad()
-def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
+def test_tensor_parallelism(mpi_pool_executor, tmp_path):
     mapping = mapping_lib.Mapping(world_size=2, tp_size=2)
     if (num_available_devices := torch.cuda.device_count()) < mapping.world_size:
         pytest.skip(f"{num_available_devices=} is less than the requested {mapping.world_size}.")

     dtype = torch.bfloat16
     device = torch.device("cuda")
+    pixtral_vision_config = make_pixtral_vision_config()
     pretrained_config = pixtral_vision_config.pretrained_config

     hf_pixtral_model = init_hf_model(
@@ -157,20 +158,22 @@ def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
     gc.collect()
     torch.cuda.empty_cache()

+    # NOTE: we cannot send `pixtral_vision_config` across the process barrier, as it contains
+    # `weakref` objects, which cannot be pickled. Instead, each worker will recreate it by
+    # calling the `make_pixtral_vision_config` function.
     world_size = mapping.world_size
-    pixtral_vision_config.mapping = mapping
     results = mpi_pool_executor.starmap(
         _run_pixtral_and_compare_against_ref,
         [
             (
-                pixtral_vision_config,
+                mapping_lib.Mapping(tp_size=world_size, world_size=world_size, rank=rank),
                 hf_weights_path,
                 pixel_values,
                 image_sizes,
                 ref_out,
                 num_params,
             )
-            for _ in range(world_size)
+            for rank in range(world_size)
         ],
     )

@@ -179,7 +182,7 @@ def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):


 def _run_pixtral_and_compare_against_ref(
-    pixtral_vision_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
+    mapping: mapping_lib.Mapping,
     hf_weights_path: pathlib.Path,
     pixel_values: torch.Tensor,
     image_sizes: torch.Tensor,
@@ -197,7 +200,8 @@ def _run_pixtral_and_compare_against_ref(
     image_sizes = image_sizes.to("cuda")
     expected_output = expected_output.to("cuda")

-    pixtral_vision_config.mapping.rank = rank
+    pixtral_vision_config = make_pixtral_vision_config()
+    pixtral_vision_config.mapping = mapping
     pixtral_model = (
         modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda")
     )
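
The NOTE added in the hunk above points at a general constraint: the `ModelConfig` returned by `make_pixtral_vision_config` holds `weakref` objects and therefore cannot be pickled, so it cannot be shipped to the MPI pool workers. Only picklable arguments (a per-rank `Mapping`, tensors, the weights path) cross the process boundary, and each worker rebuilds the config locally. The standalone sketch below reproduces that pattern with only the standard library; `Config`, `make_config`, and `worker` are illustrative names, not helpers from this repository.

    # Why configs holding weakrefs are rebuilt inside each worker instead of being sent.
    import pickle
    import weakref
    from concurrent.futures import ProcessPoolExecutor

    class Config:
        def __init__(self) -> None:
            # A weakref attribute makes the whole object unpicklable.
            self.self_ref = weakref.ref(self)

    def make_config() -> Config:
        return Config()

    def worker(rank: int) -> bool:
        config = make_config()  # recreated locally rather than pickled and sent
        return config.self_ref() is config and rank >= 0

    if __name__ == "__main__":
        try:
            pickle.dumps(make_config())
        except TypeError as exc:
            print(f"cannot pickle the config: {exc}")
        with ProcessPoolExecutor(max_workers=2) as pool:
            print(list(pool.map(worker, range(2))))  # expected: [True, True]

Passing a freshly constructed `Mapping(tp_size=world_size, world_size=world_size, rank=rank)` to each worker, as the test now does, follows the same idea: the mapping is small and picklable, and the rank is fixed up front instead of being patched onto a shared config afterwards.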
