@@ -60,6 +60,25 @@ def initialize_and_test_parameters(
     local_size: Optional[int] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        # Set seed again in each process to ensure consistency
+        torch.manual_seed(42)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(42)
+
+        key = (
+            f"embeddings.{table_name}.weight"
+            if isinstance(embedding_tables, EmbeddingCollection)
+            else f"embedding_bags.{table_name}.weight"
+        )
+
+        # Create the same fixed tensor in each process
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
+
+        # Load the fixed tensor into the embedding_tables to ensure consistency
+        embedding_tables.load_state_dict({key: fixed_tensor})
+
+        # Store the original tensor on CPU for comparison BEFORE creating the model
+        original_tensor = embedding_tables.state_dict()[key].clone().cpu()
 
         module_sharding_plan = construct_module_sharding_plan(
             embedding_tables,
@@ -81,12 +100,6 @@ def initialize_and_test_parameters(
             device=ctx.device,
         )
 
-        key = (
-            f"embeddings.{table_name}.weight"
-            if isinstance(embedding_tables, EmbeddingCollection)
-            else f"embedding_bags.{table_name}.weight"
-        )
-
         if isinstance(model.state_dict()[key], DTensor):
             if ctx.rank == 0:
                 gathered_tensor = torch.empty(model.state_dict()[key].size())
@@ -96,28 +109,26 @@ def initialize_and_test_parameters(
             gathered_tensor = model.state_dict()[key].full_tensor()
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    gathered_tensor,
-                    embedding_tables.state_dict()[key],
+                    gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
                 )
         elif isinstance(model.state_dict()[key], ShardedTensor):
             if ctx.rank == 0:
-                gathered_tensor = torch.empty_like(
-                    embedding_tables.state_dict()[key], device=ctx.device
-                )
+                gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
             else:
                 gathered_tensor = None
 
             model.state_dict()[key].gather(dst=0, out=gathered_tensor)
 
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    none_throws(gathered_tensor).to("cpu"),
-                    embedding_tables.state_dict()[key],
+                    none_throws(gathered_tensor).cpu(),
+                    original_tensor,
+                    rtol=1e-5,
+                    atol=1e-6,
                 )
         elif isinstance(model.state_dict()[key], torch.Tensor):
             torch.testing.assert_close(
-                embedding_tables.state_dict()[key].cpu(),
-                model.state_dict()[key].cpu(),
+                model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
            )
         else:
             raise AssertionError(
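The hunk above switches the parameter checks from exact comparison to tolerance-based comparison. As a minimal standalone sketch (not part of the diff) of what `rtol=1e-5, atol=1e-6` means for `torch.testing.assert_close`: the check passes when `|actual - expected| <= atol + rtol * |expected|` elementwise, which absorbs small numerical drift such as a device round-trip.

```python
import torch

expected = torch.randn(10, 64)
# Simulate tiny drift, e.g. from moving the tensor across devices and back.
actual = expected + 1e-7

# Passes because |actual - expected| <= atol + rtol * |expected| elementwise.
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-6)
```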
@@ -161,6 +172,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding table on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingCollection(
             tables=[
@@ -173,8 +187,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embeddings.{table_name}.weight": torch.randn(10, 64)}
+            {f"embeddings.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
@@ -210,6 +226,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding bag on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingBagCollection(
             tables=[
@@ -222,8 +241,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
+            {f"embedding_bags.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
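Both tests now build the table weights from a dedicated `torch.Generator` instead of the global RNG. A minimal standalone sketch of why this yields identical tensors in every process (the `make_fixed_weight` helper is hypothetical, for illustration only):

```python
import torch

def make_fixed_weight(rows: int = 10, cols: int = 64, seed: int = 42) -> torch.Tensor:
    # A local generator is unaffected by the global RNG state and does not perturb it,
    # so every call, and therefore every spawned process, produces the same values.
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(rows, cols, generator=gen)

# Identical across independent calls, which models identical values across ranks.
assert torch.equal(make_fixed_weight(), make_fixed_weight())
```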