1818 EmbeddingTowerSharder ,
1919)
2020from torchrec .distributed .embedding_types import EmbeddingComputeKernel
21- from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
21+ from torchrec .distributed .embeddingbag import (
22+ EmbeddingBagCollection ,
23+ EmbeddingBagCollectionSharder ,
24+ )
2225from torchrec .distributed .mc_embeddingbag import (
2326 ManagedCollisionEmbeddingBagCollectionSharder ,
2427)
4548 [[17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [11 , 80 ]],
4649]
4750
# Expected shard sizes for the bucketed row-wise sharding tests.
# One row per table (table_0 .. table_3), eight [rows, columns] pairs per row.
# The row counts of each table sum to 100 / 110 / 120 / 130 and the column
# count is 20 / 40 / 60 / 80, matching the num_embeddings and embedding_dim of
# the bucketed table configs built in setUp. Every row count is a whole
# multiple of num_embeddings / total_num_buckets (10, 11, 12, 13); presumably
# shards are cut on bucket boundaries -- confirm against the sharding code.
EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
    # table_0
    [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
    # table_1
    [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
    # table_2
    [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
    # table_3
    [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
]
57+
# Expected shard offsets for row-wise sharding without buckets.
# One row per table, eight [row_offset, column_offset] pairs per row. The row
# offset advances by a fixed step per table (13 / 14 / 15 / 17) and the column
# offset is always 0.
EXPECTED_RW_SHARD_OFFSETS = [
    # table_0
    [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
    # table_1
    [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
    # table_2
    [[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]],
    # table_3
    [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
]
5464
# Expected shard offsets for the bucketed row-wise sharding tests.
# One row per table, eight [row_offset, column_offset] pairs per row. Each row
# offset is the running total of the preceding shards' row counts in
# EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS (two double-size shards first, then six
# single-size shards); the column offset is always 0.
EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
    # table_0
    [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
    # table_1
    [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
    # table_2
    [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
    # table_3
    [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
]
71+
5572
5673def get_expected_cache_aux_size (rows : int ) -> int :
5774 # 0.2 is the hardcoded cache load factor assumed in this test
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
101118 ],
102119]
103120
# Expected per-shard storage for the virtual-table row-wise test with buckets.
# One inner list per table, eight shards each. Within a table every shard has
# the same HBM figure and zero DDR, so the table is generated from a single HBM
# value per table. A comprehension (not list multiplication) is used so that
# each entry is its own Storage instance, as in a hand-written literal.
EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
    [Storage(hbm=table_hbm, ddr=0) for _ in range(8)]
    for table_hbm in (
        165888,  # table_0
        1001472,  # table_1
        1003520,  # table_2
        2648064,  # table_3
    )
]
104163
105164EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
106165 [
@@ -145,6 +204,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
145204 ],
146205]
147206
# Expected per-shard storage for the UVM-caching row-wise test with buckets.
# One inner list per table, eight shards each: the first two shards share one
# (hbm, ddr) pair and the remaining six share another. The ddr figure equals
# shard rows * embedding_dim * 4 for the sizes listed in
# EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS (e.g. 20 * 20 * 4 == 1600); the factor
# of 4 is presumably the byte width of an fp32 element -- confirm.
# A comprehension is used so each entry is its own Storage instance.
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
    [
        Storage(hbm=shard_hbm, ddr=shard_ddr)
        for shard_hbm, shard_ddr in [double_bucket] * 2 + [single_bucket] * 6
    ]
    for double_bucket, single_bucket in (
        ((166352, 1600), (166120, 800)),  # table_0
        ((1002335, 3520), (1001904, 1760)),  # table_1
        ((1004845, 5760), (1004183, 2880)),  # table_2
        ((2649916, 8320), (2648990, 4160)),  # table_3
    )
]
148249
149250EXPECTED_TWRW_SHARD_SIZES = [
150251 [[25 , 20 ], [25 , 20 ], [25 , 20 ], [25 , 20 ]],
@@ -248,6 +349,16 @@ def compute_kernels(
248349 return [EmbeddingComputeKernel .FUSED .value ]
249350
250351
class VirtualTableRWSharder(EmbeddingBagCollectionSharder):
    """Embedding-bag sharder limited to row-wise sharding on the DRAM virtual-table kernel."""

    def sharding_types(self, compute_device_type: str) -> List[str]:
        """Return row-wise as the only sharding type; the device type is ignored."""
        row_wise = ShardingType.ROW_WISE.value
        return [row_wise]

    def compute_kernels(
        self, sharding_type: str, compute_device_type: str
    ) -> List[str]:
        """Return the DRAM virtual-table kernel as the only choice; both arguments are ignored."""
        virtual_table_kernel = EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value
        return [virtual_table_kernel]
360+
361+
251362class UVMCachingRWSharder (EmbeddingBagCollectionSharder ):
252363 def sharding_types (self , compute_device_type : str ) -> List [str ]:
253364 return [ShardingType .ROW_WISE .value ]
@@ -357,6 +468,27 @@ def setUp(self) -> None:
357468 min_partition = 40 , pooling_factors = [2 , 1 , 3 , 7 ]
358469 ),
359470 }
471+ self ._virtual_table_constraints = {
472+ "table_0" : ParameterConstraints (
473+ min_partition = 20 ,
474+ compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
475+ ),
476+ "table_1" : ParameterConstraints (
477+ min_partition = 20 ,
478+ pooling_factors = [1 , 3 , 5 ],
479+ compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
480+ ),
481+ "table_2" : ParameterConstraints (
482+ min_partition = 20 ,
483+ pooling_factors = [8 , 2 ],
484+ compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
485+ ),
486+ "table_3" : ParameterConstraints (
487+ min_partition = 40 ,
488+ pooling_factors = [2 , 1 , 3 , 7 ],
489+ compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
490+ ),
491+ }
360492 self .num_tables = 4
361493 tables = [
362494 EmbeddingBagConfig (
@@ -367,6 +499,17 @@ def setUp(self) -> None:
367499 )
368500 for i in range (self .num_tables )
369501 ]
502+ tables_with_buckets = [
503+ EmbeddingBagConfig (
504+ num_embeddings = 100 + i * 10 ,
505+ embedding_dim = 20 + i * 20 ,
506+ name = "table_" + str (i ),
507+ feature_names = ["feature_" + str (i )],
508+ total_num_buckets = 10 ,
509+ use_virtual_table = True ,
510+ )
511+ for i in range (self .num_tables )
512+ ]
370513 weighted_tables = [
371514 EmbeddingBagConfig (
372515 num_embeddings = (i + 1 ) * 10 ,
@@ -377,6 +520,9 @@ def setUp(self) -> None:
377520 for i in range (4 )
378521 ]
379522 self .model = TestSparseNN (tables = tables , weighted_tables = [])
523+ self .model_with_buckets = EmbeddingBagCollection (
524+ tables = tables_with_buckets ,
525+ )
380526 self .enumerator = EmbeddingEnumerator (
381527 topology = Topology (
382528 world_size = self .world_size ,
@@ -386,6 +532,15 @@ def setUp(self) -> None:
386532 batch_size = self .batch_size ,
387533 constraints = self .constraints ,
388534 )
535+ self .virtual_table_enumerator = EmbeddingEnumerator (
536+ topology = Topology (
537+ world_size = self .world_size ,
538+ compute_device = self .compute_device ,
539+ local_world_size = self .local_world_size ,
540+ ),
541+ batch_size = self .batch_size ,
542+ constraints = self ._virtual_table_constraints ,
543+ )
389544 self .tower_model = TestTowerSparseNN (
390545 tables = tables , weighted_tables = weighted_tables
391546 )
@@ -514,6 +669,26 @@ def test_rw_sharding(self) -> None:
514669 EXPECTED_RW_SHARD_STORAGE [i ],
515670 )
516671
def test_virtual_table_rw_sharding_with_buckets(self) -> None:
    """Check row-wise enumeration of bucketed virtual tables.

    Enumerates ``self.model_with_buckets`` with ``VirtualTableRWSharder``
    through ``self.virtual_table_enumerator`` and compares every sharding
    option's type, shard sizes, shard offsets and shard storage with the
    ``*_WITH_BUCKETS`` expectation tables, indexed by option position.
    """
    sharding_options = self.virtual_table_enumerator.enumerate(
        self.model_with_buckets,
        [cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())],
    )

    # All remaining assertions live inside the loop, so without this check an
    # empty (or short) enumeration result would let the test pass vacuously.
    self.assertEqual(
        len(sharding_options),
        len(EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS),
        "unexpected number of sharding options enumerated",
    )

    for i, sharding_option in enumerate(sharding_options):
        shards = sharding_option.shards
        self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
        self.assertEqual(
            [shard.size for shard in shards],
            EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
        )
        self.assertEqual(
            [shard.offset for shard in shards],
            EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
        )
        self.assertEqual(
            [shard.storage for shard in shards],
            EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i],
        )
691+
517692 def test_uvm_caching_rw_sharding (self ) -> None :
518693 sharding_options = self .enumerator .enumerate (
519694 self .model ,
@@ -535,6 +710,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
535710 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE [i ],
536711 )
537712
def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
    """Check row-wise UVM-caching enumeration of the bucketed tables.

    Enumerates ``self.model_with_buckets`` with ``UVMCachingRWSharder``
    through ``self.enumerator`` and compares every sharding option's type,
    shard sizes, shard offsets and shard storage with the ``*_WITH_BUCKETS``
    expectation tables, indexed by option position.
    """
    sharding_options = self.enumerator.enumerate(
        self.model_with_buckets,
        [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
    )

    # All remaining assertions live inside the loop, so without this check an
    # empty (or short) enumeration result would let the test pass vacuously.
    self.assertEqual(
        len(sharding_options),
        len(EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS),
        "unexpected number of sharding options enumerated",
    )

    for i, sharding_option in enumerate(sharding_options):
        shards = sharding_option.shards
        self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
        self.assertEqual(
            [shard.size for shard in shards],
            EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
        )
        self.assertEqual(
            [shard.offset for shard in shards],
            EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
        )
        self.assertEqual(
            [shard.storage for shard in shards],
            EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
        )
732+
538733 def test_twrw_sharding (self ) -> None :
539734 sharding_options = self .enumerator .enumerate (
540735 self .model , [cast (ModuleSharder [torch .nn .Module ], TWRWSharder ())]
0 commit comments