@@ -213,7 +213,7 @@ def test_scriptability_lru(self) -> None:
         torch.jit.script(mcc_ec)

     @unittest.skipIf(
-        torch.cuda.device_count() < 1,
-        "Not enough GPUs, this test requires at least one GPUs",
+        torch.cuda.device_count() < 2,
+        "Not enough GPUs, this test requires at least two GPUs",
     )
     # pyre-ignore [56]
@@ -292,7 +292,7 @@ def test_zch_hash_train_to_inf_block_bucketize(
         )

     @unittest.skipIf(
-        torch.cuda.device_count() < 1,
-        "Not enough GPUs, this test requires at least one GPUs",
+        torch.cuda.device_count() < 2,
+        "Not enough GPUs, this test requires at least two GPUs",
     )
     # pyre-ignore [56]
@@ -404,13 +404,13 @@ def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
         )

     @unittest.skipIf(
-        torch.cuda.device_count() < 1,
-        "Not enough GPUs, this test requires at least one GPUs",
+        torch.cuda.device_count() < 2,
+        "Not enough GPUs, this test requires at least two GPUs",
     )
     # pyre-ignore [56]
     @given(hash_size=st.sampled_from([0, 80]))
     @settings(max_examples=5, deadline=None)
-    def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
+    def test_zch_hash_train_rescales_one(self, hash_size: int) -> None:
         keep_original_indices = True
         kjt = KeyedJaggedTensor(
             keys=["f"],
@@ -446,23 +446,20 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
             ),
         )

-        # start with world_size = 4
-        world_size = 4
+        # start with world_size = 2
+        world_size = 2
         block_sizes = torch.tensor(
             [(size + world_size - 1) // world_size for size in [hash_size]],
             dtype=torch.int64,
             device="cuda",
         )

-        m1_1 = m0.rebuild_with_output_id_range((0, 10))
-        m2_1 = m0.rebuild_with_output_id_range((10, 20))
-        m3_1 = m0.rebuild_with_output_id_range((20, 30))
-        m4_1 = m0.rebuild_with_output_id_range((30, 40))
+        m1_1 = m0.rebuild_with_output_id_range((0, 20))
+        m2_1 = m0.rebuild_with_output_id_range((20, 40))

-        # shard, now world size 2!
-        # start with world_size = 4
+        # shard, now world size 1!
         if hash_size > 0:
-            world_size = 2
+            world_size = 1
             block_sizes = torch.tensor(
                 [(size + world_size - 1) // world_size for size in [hash_size]],
                 dtype=torch.int64,
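Note on the block-size arithmetic in this hunk: each rank's bucket width is the ceiling division of the hash size by the world size, so the per-rank block grows as the test rescales from world_size 2 down to 1 (and for hash_size 0 the `if hash_size > 0` guard skips bucketization entirely). A standalone sketch of the arithmetic only; the helper name is illustrative, not part of the test:

    # Ceiling division, as in [(size + world_size - 1) // world_size for size in [hash_size]].
    def block_size(hash_size: int, world_size: int) -> int:
        return (hash_size + world_size - 1) // world_size

    assert block_size(80, 2) == 40  # world_size 2: ranks own ids [0, 40) and [40, 80)
    assert block_size(80, 1) == 80  # world_size 1: a single rank owns the whole id space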
@@ -476,7 +473,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
                 keep_original_indices=keep_original_indices,
                 output_permute=True,
             )
-            in1_2, in2_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)
+            in1_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)[0]
         else:
             bucketized_kjt, permute = bucketize_kjt_before_all2all(
                 kjt,
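For context on the split above: block bucketization routes each feature id to rank id // block_size, so with world_size back down to 1 the bucketized KJT has a single shard and only in1_2 survives. A rough plain-PyTorch sketch of that routing (illustrative values; not the actual bucketize_kjt_before_all2all implementation):

    import torch

    ids = torch.tensor([3, 41, 29, 77, 14], dtype=torch.int64)
    block_size = torch.tensor(40, dtype=torch.int64)  # hash_size 80, world_size 2
    ranks = ids // block_size       # destination rank for each id
    local_ids = ids % block_size    # offset within that rank's block
    print(ranks.tolist())      # [0, 1, 0, 1, 0]
    print(local_ids.tolist())  # [3, 1, 29, 37, 14]

With keep_original_indices=True, as the test passes, the flag name suggests the original global ids are kept rather than these local offsets.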
@@ -492,14 +489,8 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
                 values=torch.cat([kjts[0].values(), kjts[1].values()], dim=0),
                 lengths=torch.cat([kjts[0].lengths(), kjts[1].lengths()], dim=0),
             )
-            in2_2 = KeyedJaggedTensor(
-                keys=kjts[2].keys(),
-                values=torch.cat([kjts[2].values(), kjts[3].values()], dim=0),
-                lengths=torch.cat([kjts[2].lengths(), kjts[3].lengths()], dim=0),
-            )

-        m1_2 = m0.rebuild_with_output_id_range((0, 20))
-        m2_2 = m0.rebuild_with_output_id_range((20, 40))
+        m1_2 = m0.rebuild_with_output_id_range((0, 40))
         m1_zch_identities = torch.cat(
             [
                 m1_1.state_dict()["_hash_zch_identities"],
@@ -516,53 +507,30 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
         state_dict["_hash_zch_identities"] = m1_zch_identities
         state_dict["_hash_zch_metadata"] = m1_zch_metadata
         m1_2.load_state_dict(state_dict)
-
-        m2_zch_identities = torch.cat(
-            [
-                m3_1.state_dict()["_hash_zch_identities"],
-                m4_1.state_dict()["_hash_zch_identities"],
-            ]
-        )
-        m2_zch_metadata = torch.cat(
-            [
-                m3_1.state_dict()["_hash_zch_metadata"],
-                m4_1.state_dict()["_hash_zch_metadata"],
-            ]
-        )
-        state_dict = m2_2.state_dict()
-        state_dict["_hash_zch_identities"] = m2_zch_identities
-        state_dict["_hash_zch_metadata"] = m2_zch_metadata
-        m2_2.load_state_dict(state_dict)
-
         _ = m1_2(in1_2.to_dict())
-        _ = m2_2(in2_2.to_dict())

         m0.reset_inference_mode()  # just clears out training state
         full_zch_identities = torch.cat(
             [
                 m1_2.state_dict()["_hash_zch_identities"],
-                m2_2.state_dict()["_hash_zch_identities"],
             ]
         )
         state_dict = m0.state_dict()
         state_dict["_hash_zch_identities"] = full_zch_identities
         m0.load_state_dict(state_dict)

-        # now set all models to eval, and run kjt
         m1_2.eval()
-        m2_2.eval()
         assert m0.training is False

         inf_input = kjt.to_dict()
-        inf_output = m0(inf_input)

+        inf_output = m0(inf_input)
         o1_2 = m1_2(in1_2.to_dict())
-        o2_2 = m2_2(in2_2.to_dict())
         self.assertTrue(
             torch.allclose(
                 inf_output["f"].values(),
                 torch.index_select(
-                    torch.cat([x["f"].values() for x in [o1_2, o2_2]]),
+                    o1_2["f"].values(),
                     dim=0,
                     index=cast(torch.Tensor, permute),
                 ),
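The rescale path exercised here follows one pattern throughout: rebuild shards for the new output id ranges, concatenate the per-shard _hash_zch_identities (and _hash_zch_metadata) buffers in range order, then load the result into the wider module. A simplified stand-in with a plain nn.Module buffer, rather than the managed-collision module under test, showing just the state_dict surgery:

    import torch
    from torch import nn

    class Shard(nn.Module):
        # Stand-in for a ZCH shard: one identities row per output id it owns.
        def __init__(self, num_ids: int) -> None:
            super().__init__()
            self.register_buffer(
                "_hash_zch_identities", torch.full((num_ids, 1), -1, dtype=torch.int64)
            )

    m1_1, m2_1 = Shard(20), Shard(20)  # trained shards for ranges (0, 20) and (20, 40)
    m1_2 = Shard(40)                   # rebuilt shard for (0, 40) after rescaling to world_size 1
    state_dict = m1_2.state_dict()
    state_dict["_hash_zch_identities"] = torch.cat(
        [
            m1_1.state_dict()["_hash_zch_identities"],
            m2_1.state_dict()["_hash_zch_identities"],
        ]
    )
    m1_2.load_state_dict(state_dict)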