Skip to content

Commit 5ac1269

Browse files
kausv
authored and facebook-github-bot committed
Add GPU RE for MPZCH module tests
Summary: Tests run on CPU, and unittest then skips them when no GPU is available. I also changed the tests to make sure only 2 GPUs are needed instead of 4.

Differential Revision: D78169070
1 parent d797031 commit 5ac1269

File tree

1 file changed

+14
-46
lines changed

1 file changed

+14
-46
lines changed

torchrec/modules/tests/test_hash_mc_modules.py

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def test_scriptability_lru(self) -> None:
213213
torch.jit.script(mcc_ec)
214214

215215
@unittest.skipIf(
216-
torch.cuda.device_count() < 1,
216+
torch.cuda.device_count() < 2,
217217
"Not enough GPUs, this test requires at least one GPUs",
218218
)
219219
# pyre-ignore [56]
@@ -292,7 +292,7 @@ def test_zch_hash_train_to_inf_block_bucketize(
292292
)
293293

294294
@unittest.skipIf(
295-
torch.cuda.device_count() < 1,
295+
torch.cuda.device_count() < 2,
296296
"Not enough GPUs, this test requires at least one GPUs",
297297
)
298298
# pyre-ignore [56]
@@ -404,13 +404,13 @@ def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
404404
)
405405

406406
@unittest.skipIf(
407-
torch.cuda.device_count() < 1,
407+
torch.cuda.device_count() < 2,
408408
"Not enough GPUs, this test requires at least one GPUs",
409409
)
410410
# pyre-ignore [56]
411411
@given(hash_size=st.sampled_from([0, 80]))
412412
@settings(max_examples=5, deadline=None)
413-
def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
413+
def test_zch_hash_train_rescales_one(self, hash_size: int) -> None:
414414
keep_original_indices = True
415415
kjt = KeyedJaggedTensor(
416416
keys=["f"],
@@ -446,23 +446,20 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
446446
),
447447
)
448448

449-
# start with world_size = 4
450-
world_size = 4
449+
# start with world_size = 2
450+
world_size = 2
451451
block_sizes = torch.tensor(
452452
[(size + world_size - 1) // world_size for size in [hash_size]],
453453
dtype=torch.int64,
454454
device="cuda",
455455
)
456456

457-
m1_1 = m0.rebuild_with_output_id_range((0, 10))
458-
m2_1 = m0.rebuild_with_output_id_range((10, 20))
459-
m3_1 = m0.rebuild_with_output_id_range((20, 30))
460-
m4_1 = m0.rebuild_with_output_id_range((30, 40))
457+
m1_1 = m0.rebuild_with_output_id_range((0, 20))
458+
m2_1 = m0.rebuild_with_output_id_range((20, 40))
461459

462-
# shard, now world size 2!
463-
# start with world_size = 4
460+
# shard, now world size 1!
464461
if hash_size > 0:
465-
world_size = 2
462+
world_size = 1
466463
block_sizes = torch.tensor(
467464
[(size + world_size - 1) // world_size for size in [hash_size]],
468465
dtype=torch.int64,
@@ -476,7 +473,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
476473
keep_original_indices=keep_original_indices,
477474
output_permute=True,
478475
)
479-
in1_2, in2_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)
476+
in1_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)[0]
480477
else:
481478
bucketized_kjt, permute = bucketize_kjt_before_all2all(
482479
kjt,
@@ -492,14 +489,8 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
492489
values=torch.cat([kjts[0].values(), kjts[1].values()], dim=0),
493490
lengths=torch.cat([kjts[0].lengths(), kjts[1].lengths()], dim=0),
494491
)
495-
in2_2 = KeyedJaggedTensor(
496-
keys=kjts[2].keys(),
497-
values=torch.cat([kjts[2].values(), kjts[3].values()], dim=0),
498-
lengths=torch.cat([kjts[2].lengths(), kjts[3].lengths()], dim=0),
499-
)
500492

501-
m1_2 = m0.rebuild_with_output_id_range((0, 20))
502-
m2_2 = m0.rebuild_with_output_id_range((20, 40))
493+
m1_2 = m0.rebuild_with_output_id_range((0, 40))
503494
m1_zch_identities = torch.cat(
504495
[
505496
m1_1.state_dict()["_hash_zch_identities"],
@@ -516,53 +507,30 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
516507
state_dict["_hash_zch_identities"] = m1_zch_identities
517508
state_dict["_hash_zch_metadata"] = m1_zch_metadata
518509
m1_2.load_state_dict(state_dict)
519-
520-
m2_zch_identities = torch.cat(
521-
[
522-
m3_1.state_dict()["_hash_zch_identities"],
523-
m4_1.state_dict()["_hash_zch_identities"],
524-
]
525-
)
526-
m2_zch_metadata = torch.cat(
527-
[
528-
m3_1.state_dict()["_hash_zch_metadata"],
529-
m4_1.state_dict()["_hash_zch_metadata"],
530-
]
531-
)
532-
state_dict = m2_2.state_dict()
533-
state_dict["_hash_zch_identities"] = m2_zch_identities
534-
state_dict["_hash_zch_metadata"] = m2_zch_metadata
535-
m2_2.load_state_dict(state_dict)
536-
537510
_ = m1_2(in1_2.to_dict())
538-
_ = m2_2(in2_2.to_dict())
539511

540512
m0.reset_inference_mode() # just clears out training state
541513
full_zch_identities = torch.cat(
542514
[
543515
m1_2.state_dict()["_hash_zch_identities"],
544-
m2_2.state_dict()["_hash_zch_identities"],
545516
]
546517
)
547518
state_dict = m0.state_dict()
548519
state_dict["_hash_zch_identities"] = full_zch_identities
549520
m0.load_state_dict(state_dict)
550521

551-
# now set all models to eval, and run kjt
552522
m1_2.eval()
553-
m2_2.eval()
554523
assert m0.training is False
555524

556525
inf_input = kjt.to_dict()
557-
inf_output = m0(inf_input)
558526

527+
inf_output = m0(inf_input)
559528
o1_2 = m1_2(in1_2.to_dict())
560-
o2_2 = m2_2(in2_2.to_dict())
561529
self.assertTrue(
562530
torch.allclose(
563531
inf_output["f"].values(),
564532
torch.index_select(
565-
torch.cat([x["f"].values() for x in [o1_2, o2_2]]),
533+
o1_2["f"].values(),
566534
dim=0,
567535
index=cast(torch.Tensor, permute),
568536
),

0 commit comments

Comments
 (0)