diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 362240d88..553a1c9b4 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -954,7 +954,17 @@ def _initialize_torch_state(self) -> None: # noqa
                     (
                         [
                             # assuming virtual table only supports rw sharding for now
-                            0 if dim == 0 else dim_size
+                            # When the backend returns the whole row, we need to respect dim(1),
+                            # otherwise we will see a shard dim exceeded tensor dim error
+                            (
+                                0
+                                if dim == 0
+                                else (
+                                    local_shards[0].metadata.shard_sizes[1]
+                                    if dim == 1
+                                    else dim_size
+                                )
+                            )
                             for dim, dim_size in enumerate(
                                 self._name_to_table_size[table_name]
                             )
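
Below is a minimal, self-contained sketch (plain Python, not torchrec code) of what the new comprehension computes; the table size and local shard sizes are hypothetical stand-ins for self._name_to_table_size[table_name] and local_shards[0].metadata.shard_sizes.

# Hypothetical example: a virtual table with logical size [100, 8] whose local
# row-wise shard stores whole rows of width 10 (assumed values for illustration).
table_size = [100, 8]          # stand-in for self._name_to_table_size[table_name]
local_shard_sizes = [25, 10]   # stand-in for local_shards[0].metadata.shard_sizes

sizes = [
    # dim 0 is forced to 0, as in the existing virtual-table handling
    0
    if dim == 0
    # dim 1 follows the actual shard width, so the shard cannot exceed the tensor dim
    else (local_shard_sizes[1] if dim == 1 else dim_size)
    for dim, dim_size in enumerate(table_size)
]
print(sizes)  # [0, 10]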