Commit d2c1ed9
Enable Changing the # of shards for CW resharding (meta-pytorch#3188)
Summary:
Pull Request resolved: meta-pytorch#3188
Currently Dynamic Sharding assumes the # of shards per embedding table stays the same:
- https://www.internalfb.com/code/fbsource/[6d270632037a1e8bca7f63500dd07fd0b213e572]/fbcode/torchrec/distributed/sharding/dynamic_sharding.py?lines=140
E.g.
- `table_0` originally sharded on ranks: [0, 1]
- Reshard API currently supports moving `table_0` shards to ranks [1, 2].
- Where `the shard` on rank 0 will move to rank 1, and the shard on rank 1 will move to rank 2
We want to support changing the # of shards:
- e.g. table_0 originally on ranks: [0, 1] --> reshard to [0]
- Or reshard to [0, 1, 2, 3]
Here's the unit test you can modify to check if your usecase passes:
- https://www.internalfb.com/code/fbsource/[4d0d74b9f3c441e7aa35ce7102200fa0ca8c95cf]/fbcode/torchrec/distributed/tests/test_dynamic_sharding.py?lines=453-459
- Basically change the new sharding plan to be a different # of ranks than the original sharding plan.
Note: the new total number of ranks for each embedding table should be a factor of the dimension 0 of that embedding table
- e.g. emb_table size: [4, 8], this can only be sharded on 1, 2, or 4 ranks. not 3 ranks
Differential Revision: D782917171 parent fd9d78a commit d2c1ed9
File tree
3 files changed
+348
-105
lines changed- torchrec/distributed
- sharding
- test_utils
- tests
3 files changed
+348
-105
lines changed
0 commit comments