Skip to content

Commit 8b378ac

Browse files
committed
chore: improved naming consistenty in TP test
1 parent 472c1b0 commit 8b378ac

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

tests/fsdp2_parallelization/test_tensor_parallelism.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,27 @@ def tmp_config_dir(tmp_path_factory) -> Path:
4545
class TestTensorParallelism:
4646
def _get_components(self, config_file_path: Path) -> Tuple[FSDP2, DeviceMesh]:
4747
class ComponentsInstantiationModel(BaseModel):
48-
fsdp_model: PydanticFSDP2ModuleType
48+
model: PydanticFSDP2ModuleType
4949
device_mesh: PydanticDeviceMeshIFType
5050

5151
main_obj = Main(config_file_path)
5252
components: ComponentsInstantiationModel = main_obj.build_components(
5353
components_model_type=ComponentsInstantiationModel
5454
)
55-
return components.fsdp_model, components.device_mesh
55+
return components.model, components.device_mesh
5656

5757
@pytest.mark.parametrize(
58-
"activation_type, fsdp_config_path, tp_config_path, port",
58+
"activation_type, fsdp2_config_path, tp_config_path, port",
5959
[
6060
(
6161
"gelu",
62-
Path("tests/fsdp2_parallelization/tp_test_configs/fsdp_config.yaml"),
62+
Path("tests/fsdp2_parallelization/tp_test_configs/fsdp2_config.yaml"),
6363
Path("tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml"),
6464
22755,
6565
),
6666
(
6767
"swiglu",
68-
Path("tests/fsdp2_parallelization/tp_test_configs/fsdp_config.yaml"),
68+
Path("tests/fsdp2_parallelization/tp_test_configs/fsdp2_config.yaml"),
6969
Path("tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml"),
7070
22756,
7171
),
@@ -74,15 +74,15 @@ class ComponentsInstantiationModel(BaseModel):
7474
def test_tp_sharding(
7575
self,
7676
activation_type: str,
77-
fsdp_config_path: Path,
77+
fsdp2_config_path: Path,
7878
tp_config_path: Path,
7979
tmp_config_dir: Path,
8080
port: int,
8181
):
8282
world_size = 4
8383
mp.spawn(
8484
self._test_tp_sharding_impl,
85-
args=(activation_type, fsdp_config_path, tp_config_path, world_size, tmp_config_dir, port),
85+
args=(activation_type, fsdp2_config_path, tp_config_path, world_size, tmp_config_dir, port),
8686
nprocs=world_size,
8787
join=True,
8888
)
@@ -91,7 +91,7 @@ def _test_tp_sharding_impl(
9191
self,
9292
process_id: int,
9393
activation_type: str,
94-
fsdp_config_path: Path,
94+
fsdp2_config_path: Path,
9595
tp_config_path: Path,
9696
world_size: int,
9797
tmp_config_dir: Path,
@@ -107,7 +107,7 @@ def _test_tp_sharding_impl(
107107
):
108108
# Seed before FSDP2 instantiation
109109
torch.manual_seed(42)
110-
fsdp2_path = patch_config_file(fsdp_config_path, activation_type, tmp_config_dir)
110+
fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir)
111111
fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path)
112112

113113
# Seed again before TP instantiation to match
@@ -160,13 +160,13 @@ def all_named_tensors(model: nn.Module):
160160
yield from model.named_parameters()
161161
yield from model.named_buffers()
162162

163-
fsdp_tensors = dict(all_named_tensors(fsdp2_model))
163+
fsdp2_tensors = dict(all_named_tensors(fsdp2_model))
164164
tp_tensors = dict(all_named_tensors(tp_model))
165165

166-
assert fsdp_tensors.keys() == tp_tensors.keys(), "Model structures differ"
166+
assert fsdp2_tensors.keys() == tp_tensors.keys(), "Model structures differ"
167167

168-
for name in fsdp_tensors:
169-
a, b = fsdp_tensors[name], tp_tensors[name]
168+
for name in fsdp2_tensors:
169+
a, b = fsdp2_tensors[name], tp_tensors[name]
170170

171171
a_mat = a.redistribute(fsdp2_mesh, [Replicate()]).to_local() if isinstance(a, DTensor) else a
172172
b_mat = b.redistribute(tp_mesh, [Replicate(), Replicate()]).to_local() if isinstance(b, DTensor) else b

tests/fsdp2_parallelization/tp_test_configs/fsdp_config.yaml renamed to tests/fsdp2_parallelization/tp_test_configs/fsdp2_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
fsdp_model:
1+
model:
22
component_key: model
33
variant_key: fsdp2_wrapped
44
config:

tests/fsdp2_parallelization/tp_test_configs/tp_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
fsdp_model:
1+
model:
22
component_key: model
33
variant_key: fsdp2_wrapped
44
config:

0 commit comments

Comments
 (0)