1 parent 8c1555d commit 2174f26
distributed/tensor_parallelism/fsdp_tp_example.py
@@ -49,7 +49,7 @@
 from llama2_model import Transformer, ModelArgs

 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed._tensor import Shard, Replicate
 from torch.distributed.tensor.parallel import (
     parallelize_module,
@@ -146,7 +146,7 @@
 )

 # Init FSDP using the dp device mesh
-sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
+sharded_model = fully_shard(model, mesh=dp_mesh)

 rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")
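For context, below is a minimal sketch of how the fully_shard (FSDP2) entry point that this commit migrates to is typically used, assuming a torchrun launch. The toy model, mesh shape, and hyperparameters are illustrative stand-ins and are not taken from the example script, which builds a 2D (dp, tp) mesh and combines fully_shard with tensor parallelism.

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

def main():
    # Assumes a torchrun launch, e.g. `torchrun --nproc_per_node=4 sketch.py`.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = int(os.environ["WORLD_SIZE"])

    # 1D data-parallel mesh over all ranks; the full example instead slices a
    # "dp" sub-mesh out of a 2D (dp, tp) mesh and applies tensor parallelism
    # on the other dimension.
    dp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))

    # Toy model standing in for the llama2 Transformer used in the example.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()

    # fully_shard shards parameters in place over the mesh and returns the same
    # module; FSDP1 flags such as use_orig_params are no longer needed because
    # FSDP2 keeps the original parameters (as DTensors) by design.
    sharded_model = fully_shard(model, mesh=dp_mesh)

    optim = torch.optim.AdamW(sharded_model.parameters(), lr=1e-3)
    out = sharded_model(torch.randn(4, 16, device="cuda"))
    out.sum().backward()
    optim.step()

if __name__ == "__main__":
    main()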