 pytestmark = pytest.mark.threadleak(enabled=False)
 
 
-@pytest.fixture
-def pixtral_vision_config():
+def make_pixtral_vision_config():
     # Values taken from:
     # https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/config.json
     return model_config_lib.ModelConfig(
@@ -71,9 +70,10 @@ def init_hf_model(cls, config, dtype, device):
 
 @torch.no_grad()
 @pytest.mark.usefixtures("set_seed")
-def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
+def test_pixtral_vision_model_vs_hf():
     dtype = torch.bfloat16
     device = torch.device("cuda")
+    pixtral_vision_config = make_pixtral_vision_config()
     pretrained_config = pixtral_vision_config.pretrained_config
 
     pixtral_model = (
@@ -111,13 +111,14 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
 
 @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
 @torch.no_grad()
-def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
+def test_tensor_parallelism(mpi_pool_executor, tmp_path):
     mapping = mapping_lib.Mapping(world_size=2, tp_size=2)
     if (num_available_devices := torch.cuda.device_count()) < mapping.world_size:
         pytest.skip(f"{num_available_devices=} is less than the requested {mapping.world_size}.")
 
     dtype = torch.bfloat16
     device = torch.device("cuda")
+    pixtral_vision_config = make_pixtral_vision_config()
     pretrained_config = pixtral_vision_config.pretrained_config
 
     hf_pixtral_model = init_hf_model(
@@ -157,20 +158,22 @@ def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
     gc.collect()
     torch.cuda.empty_cache()
 
+    # NOTE: we cannot send `pixtral_vision_config` across the process barrier, as it contains
+    # `weakref` objects, which cannot be pickled. Instead, each worker will recreate it by
+    # calling the `make_pixtral_vision_config` function.
     world_size = mapping.world_size
-    pixtral_vision_config.mapping = mapping
     results = mpi_pool_executor.starmap(
         _run_pixtral_and_compare_against_ref,
         [
             (
-                pixtral_vision_config,
+                mapping_lib.Mapping(tp_size=world_size, world_size=world_size, rank=rank),
                 hf_weights_path,
                 pixel_values,
                 image_sizes,
                 ref_out,
                 num_params,
             )
-            for _ in range(world_size)
+            for rank in range(world_size)
         ],
     )
 
@@ -179,7 +182,7 @@ def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
 
 
 def _run_pixtral_and_compare_against_ref(
-    pixtral_vision_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
+    mapping: mapping_lib.Mapping,
     hf_weights_path: pathlib.Path,
     pixel_values: torch.Tensor,
     image_sizes: torch.Tensor,
@@ -197,7 +200,8 @@ def _run_pixtral_and_compare_against_ref(
     image_sizes = image_sizes.to("cuda")
     expected_output = expected_output.to("cuda")
 
-    pixtral_vision_config.mapping.rank = rank
+    pixtral_vision_config = make_pixtral_vision_config()
+    pixtral_vision_config.mapping = mapping
     pixtral_model = (
         modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda")
     )
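The NOTE in the hunk above is the heart of the change: the model config holds weakref members, weakref objects cannot be pickled, so the config cannot be shipped to the worker processes; each rank instead receives only picklable arguments (a per-rank Mapping) and rebuilds the config locally. Below is a minimal, self-contained sketch of that pattern. FakeConfig, make_fake_config, and the standard-library multiprocessing pool are illustrative stand-ins for the real ModelConfig, make_pixtral_vision_config, and the test's mpi_pool_executor fixture.

import multiprocessing as mp
import pickle
import weakref


class FakeConfig:
    """Illustrative stand-in for the real ModelConfig; it holds a weakref, so pickle rejects it."""

    def __init__(self) -> None:
        self.parent = weakref.ref(self)


def make_fake_config() -> FakeConfig:
    # Plays the role of make_pixtral_vision_config(): cheap to call again inside each worker.
    return FakeConfig()


def worker(rank: int, world_size: int) -> int:
    # Only plain, picklable arguments crossed the process boundary; the unpicklable
    # config is recreated locally, mirroring what each test worker does.
    config = make_fake_config()
    assert config.parent() is config
    return rank


if __name__ == "__main__":
    try:
        pickle.dumps(make_fake_config())
    except TypeError as exc:
        # Exact wording varies by Python version, e.g. "cannot pickle 'weakref' object".
        print(f"config is not picklable: {exc}")

    world_size = 2
    with mp.get_context("spawn").Pool(world_size) as pool:
        print(pool.starmap(worker, [(rank, world_size) for rank in range(world_size)]))  # [0, 1]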