
Commit 75e7a0c

iamzainhuda authored and facebook-github-bot committed
fix DTensor placements for table wise sharding (#3214)
Summary: TSIA

Reviewed By: wconstab, XilunWu

Differential Revision: D78594015
1 parent 332b8b4 commit 75e7a0c

File tree: 1 file changed, +2 −6 lines


torchrec/distributed/sharding/tw_sharding.py

Lines changed: 2 additions & 6 deletions
@@ -135,12 +135,8 @@ def _shard(
             dtensor_metadata = None
             if self._env.output_dtensor:
                 dtensor_metadata = DTensorMetadata(
-                    mesh=(
-                        self._env.device_mesh["replicate"]  # pyre-ignore[16]
-                        if self._is_2D_parallel
-                        else self._env.device_mesh
-                    ),
-                    placements=(Replicate(),),
+                    mesh=self._env.device_mesh,
+                    placements=(Replicate(),) * (self._env.device_mesh.ndim),  # pyre-ignore[16]
                     size=(
                         info.embedding_config.num_embeddings,
                         info.embedding_config.embedding_dim,
