# fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
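The diff below touches `fsdp_post_all_gather`, one of the two hooks FSDP2 looks for on tensor subclasses (the other being `fsdp_pre_all_gather`), following the extension point exercised in the linked pytorch test file. As context, here is a minimal sketch of that contract on a hypothetical wrapper subclass; `ScaledTensorSketch` and its fields are illustrative, not torchao's actual class, and the exact hook signatures have varied across PyTorch versions:

```python
from typing import Any, Optional, Tuple

import torch


class ScaledTensorSketch(torch.Tensor):
    """Hypothetical wrapper tensor carrying a `scaling_type` tag.

    __torch_dispatch__ is omitted for brevity; a real subclass needs it
    to forward ops to the inner `_data` tensor.
    """

    @staticmethod
    def __new__(cls, data: torch.Tensor, scaling_type: str = "dynamic"):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scaling_type: str = "dynamic"):
        self._data = data
        self.scaling_type = scaling_type

    def fsdp_pre_all_gather(self, mesh):
        # Hand FSDP the plain inner tensor(s) to all-gather, plus optional
        # metadata that is passed back to fsdp_post_all_gather.
        return (self._data,), None

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        if out is not None:
            # Training step 1+: FSDP passes back the unsharded param it
            # cached on step 0; nothing new to allocate, so return None.
            return
        # Training step 0: build the unsharded subclass and also return its
        # inner tensors so FSDP can resize/free them after use.
        return ScaledTensorSketch(data, self.scaling_type), (data,)
```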
```diff
@@ -155,14 +172,16 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs
 
-        # For training step 1+, out=unshared param.
+        # For training step 1+, out=unsharded param.
         if out is not None:
             if isinstance(out, ScaledGroupedMMTensor):
                 out_data = out._data
+                out.scaling_type = self.scaling_type
             elif isinstance(out, DTensor) and isinstance(
                 out._local_tensor, ScaledGroupedMMTensor
             ):
                 out_data = out._local_tensor._data
+                out._local_tensor.scaling_type = self.scaling_type
             else:
                 raise RuntimeError(
                     f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}"
```
```diff
@@ -185,6 +204,6 @@ def fsdp_post_all_gather(
             return
 
         # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
```
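To see when each path of the hook runs, here is a hedged usage sketch following the pattern in the linked pytorch test file; it reuses the hypothetical `ScaledTensorSketch` from above and assumes `torch.distributed` is already initialized (e.g. via torchrun):

```python
# fully_shard lived under torch.distributed._composable.fsdp at the
# referenced commit and is re-exported as torch.distributed.fsdp.fully_shard
# in newer releases.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

model = nn.Linear(16, 16, bias=False)
# Swap the plain weight for the subclass (a complete subclass would also
# implement __torch_dispatch__) so FSDP routes the parameter's all-gather
# through fsdp_pre_all_gather / fsdp_post_all_gather.
model.weight = nn.Parameter(ScaledTensorSketch(model.weight.detach()))
fully_shard(model)

# First forward: fsdp_post_all_gather is called with out=None and allocates
# the unsharded subclass. Later forwards: FSDP passes that cached tensor
# back as `out`, which is exactly the path onto which the diff's added
# lines re-stamp scaling_type.
```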