@@ -45,27 +45,27 @@ def tmp_config_dir(tmp_path_factory) -> Path:
45
45
class TestTensorParallelism :
46
46
def _get_components (self , config_file_path : Path ) -> Tuple [FSDP2 , DeviceMesh ]:
47
47
class ComponentsInstantiationModel (BaseModel ):
48
- fsdp_model : PydanticFSDP2ModuleType
48
+ model : PydanticFSDP2ModuleType
49
49
device_mesh : PydanticDeviceMeshIFType
50
50
51
51
main_obj = Main (config_file_path )
52
52
components : ComponentsInstantiationModel = main_obj .build_components (
53
53
components_model_type = ComponentsInstantiationModel
54
54
)
55
- return components .fsdp_model , components .device_mesh
55
+ return components .model , components .device_mesh
56
56
57
57
@pytest .mark .parametrize (
58
- "activation_type, fsdp_config_path , tp_config_path, port" ,
58
+ "activation_type, fsdp2_config_path , tp_config_path, port" ,
59
59
[
60
60
(
61
61
"gelu" ,
62
- Path ("tests/fsdp2_parallelization/tp_test_configs/fsdp_config .yaml" ),
62
+ Path ("tests/fsdp2_parallelization/tp_test_configs/fsdp2_config .yaml" ),
63
63
Path ("tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml" ),
64
64
22755 ,
65
65
),
66
66
(
67
67
"swiglu" ,
68
- Path ("tests/fsdp2_parallelization/tp_test_configs/fsdp_config .yaml" ),
68
+ Path ("tests/fsdp2_parallelization/tp_test_configs/fsdp2_config .yaml" ),
69
69
Path ("tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml" ),
70
70
22756 ,
71
71
),
@@ -74,15 +74,15 @@ class ComponentsInstantiationModel(BaseModel):
74
74
def test_tp_sharding (
75
75
self ,
76
76
activation_type : str ,
77
- fsdp_config_path : Path ,
77
+ fsdp2_config_path : Path ,
78
78
tp_config_path : Path ,
79
79
tmp_config_dir : Path ,
80
80
port : int ,
81
81
):
82
82
world_size = 4
83
83
mp .spawn (
84
84
self ._test_tp_sharding_impl ,
85
- args = (activation_type , fsdp_config_path , tp_config_path , world_size , tmp_config_dir , port ),
85
+ args = (activation_type , fsdp2_config_path , tp_config_path , world_size , tmp_config_dir , port ),
86
86
nprocs = world_size ,
87
87
join = True ,
88
88
)
@@ -91,7 +91,7 @@ def _test_tp_sharding_impl(
91
91
self ,
92
92
process_id : int ,
93
93
activation_type : str ,
94
- fsdp_config_path : Path ,
94
+ fsdp2_config_path : Path ,
95
95
tp_config_path : Path ,
96
96
world_size : int ,
97
97
tmp_config_dir : Path ,
@@ -107,7 +107,7 @@ def _test_tp_sharding_impl(
107
107
):
108
108
# Seed before FSDP2 instantiation
109
109
torch .manual_seed (42 )
110
- fsdp2_path = patch_config_file (fsdp_config_path , activation_type , tmp_config_dir )
110
+ fsdp2_path = patch_config_file (fsdp2_config_path , activation_type , tmp_config_dir )
111
111
fsdp2_model , fsdp2_mesh = self ._get_components (fsdp2_path )
112
112
113
113
# Seed again before TP instantiation to match
@@ -160,13 +160,13 @@ def all_named_tensors(model: nn.Module):
160
160
yield from model .named_parameters ()
161
161
yield from model .named_buffers ()
162
162
163
- fsdp_tensors = dict (all_named_tensors (fsdp2_model ))
163
+ fsdp2_tensors = dict (all_named_tensors (fsdp2_model ))
164
164
tp_tensors = dict (all_named_tensors (tp_model ))
165
165
166
- assert fsdp_tensors .keys () == tp_tensors .keys (), "Model structures differ"
166
+ assert fsdp2_tensors .keys () == tp_tensors .keys (), "Model structures differ"
167
167
168
- for name in fsdp_tensors :
169
- a , b = fsdp_tensors [name ], tp_tensors [name ]
168
+ for name in fsdp2_tensors :
169
+ a , b = fsdp2_tensors [name ], tp_tensors [name ]
170
170
171
171
a_mat = a .redistribute (fsdp2_mesh , [Replicate ()]).to_local () if isinstance (a , DTensor ) else a
172
172
b_mat = b .redistribute (tp_mesh , [Replicate (), Replicate ()]).to_local () if isinstance (b , DTensor ) else b
0 commit comments