Skip to content

Add GPU RE for MPZCH module tests #3185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 14 additions & 46 deletions torchrec/modules/tests/test_hash_mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_scriptability_lru(self) -> None:
torch.jit.script(mcc_ec)

@unittest.skipIf(
torch.cuda.device_count() < 1,
torch.cuda.device_count() < 2,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-ignore [56]
Expand Down Expand Up @@ -292,7 +292,7 @@ def test_zch_hash_train_to_inf_block_bucketize(
)

@unittest.skipIf(
torch.cuda.device_count() < 1,
torch.cuda.device_count() < 2,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-ignore [56]
Expand Down Expand Up @@ -404,13 +404,13 @@ def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
)

@unittest.skipIf(
torch.cuda.device_count() < 1,
torch.cuda.device_count() < 2,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-ignore [56]
@given(hash_size=st.sampled_from([0, 80]))
@settings(max_examples=5, deadline=None)
def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
def test_zch_hash_train_rescales_one(self, hash_size: int) -> None:
keep_original_indices = True
kjt = KeyedJaggedTensor(
keys=["f"],
Expand Down Expand Up @@ -446,23 +446,20 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
),
)

# start with world_size = 4
world_size = 4
# start with world_size = 2
world_size = 2
block_sizes = torch.tensor(
[(size + world_size - 1) // world_size for size in [hash_size]],
dtype=torch.int64,
device="cuda",
)

m1_1 = m0.rebuild_with_output_id_range((0, 10))
m2_1 = m0.rebuild_with_output_id_range((10, 20))
m3_1 = m0.rebuild_with_output_id_range((20, 30))
m4_1 = m0.rebuild_with_output_id_range((30, 40))
m1_1 = m0.rebuild_with_output_id_range((0, 20))
m2_1 = m0.rebuild_with_output_id_range((20, 40))

# shard, now world size 2!
# start with world_size = 4
# shard, now world size 1!
if hash_size > 0:
world_size = 2
world_size = 1
block_sizes = torch.tensor(
[(size + world_size - 1) // world_size for size in [hash_size]],
dtype=torch.int64,
Expand All @@ -476,7 +473,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
keep_original_indices=keep_original_indices,
output_permute=True,
)
in1_2, in2_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)
in1_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)[0]
else:
bucketized_kjt, permute = bucketize_kjt_before_all2all(
kjt,
Expand All @@ -492,14 +489,8 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
values=torch.cat([kjts[0].values(), kjts[1].values()], dim=0),
lengths=torch.cat([kjts[0].lengths(), kjts[1].lengths()], dim=0),
)
in2_2 = KeyedJaggedTensor(
keys=kjts[2].keys(),
values=torch.cat([kjts[2].values(), kjts[3].values()], dim=0),
lengths=torch.cat([kjts[2].lengths(), kjts[3].lengths()], dim=0),
)

m1_2 = m0.rebuild_with_output_id_range((0, 20))
m2_2 = m0.rebuild_with_output_id_range((20, 40))
m1_2 = m0.rebuild_with_output_id_range((0, 40))
m1_zch_identities = torch.cat(
[
m1_1.state_dict()["_hash_zch_identities"],
Expand All @@ -516,53 +507,30 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
state_dict["_hash_zch_identities"] = m1_zch_identities
state_dict["_hash_zch_metadata"] = m1_zch_metadata
m1_2.load_state_dict(state_dict)

m2_zch_identities = torch.cat(
[
m3_1.state_dict()["_hash_zch_identities"],
m4_1.state_dict()["_hash_zch_identities"],
]
)
m2_zch_metadata = torch.cat(
[
m3_1.state_dict()["_hash_zch_metadata"],
m4_1.state_dict()["_hash_zch_metadata"],
]
)
state_dict = m2_2.state_dict()
state_dict["_hash_zch_identities"] = m2_zch_identities
state_dict["_hash_zch_metadata"] = m2_zch_metadata
m2_2.load_state_dict(state_dict)

_ = m1_2(in1_2.to_dict())
_ = m2_2(in2_2.to_dict())

m0.reset_inference_mode() # just clears out training state
full_zch_identities = torch.cat(
[
m1_2.state_dict()["_hash_zch_identities"],
m2_2.state_dict()["_hash_zch_identities"],
]
)
state_dict = m0.state_dict()
state_dict["_hash_zch_identities"] = full_zch_identities
m0.load_state_dict(state_dict)

# now set all models to eval, and run kjt
m1_2.eval()
m2_2.eval()
assert m0.training is False

inf_input = kjt.to_dict()
inf_output = m0(inf_input)

inf_output = m0(inf_input)
o1_2 = m1_2(in1_2.to_dict())
o2_2 = m2_2(in2_2.to_dict())
self.assertTrue(
torch.allclose(
inf_output["f"].values(),
torch.index_select(
torch.cat([x["f"].values() for x in [o1_2, o2_2]]),
o1_2["f"].values(),
dim=0,
index=cast(torch.Tensor, permute),
),
Expand Down
Loading