We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 332b8b4 commit 75e7a0c — Copy full SHA for 75e7a0c
torchrec/distributed/sharding/tw_sharding.py
@@ -135,12 +135,8 @@ def _shard(
135
dtensor_metadata = None
136
if self._env.output_dtensor:
137
dtensor_metadata = DTensorMetadata(
138
- mesh=(
139
- self._env.device_mesh["replicate"] # pyre-ignore[16]
140
- if self._is_2D_parallel
141
- else self._env.device_mesh
142
- ),
143
- placements=(Replicate(),),
+ mesh=self._env.device_mesh,
+ placements=(Replicate(),) * (self._env.device_mesh.ndim), # pyre-ignore[16]
144
size=(
145
info.embedding_config.num_embeddings,
146
info.embedding_config.embedding_dim,
0 commit comments