Commit 2174f26

wwwjn authored and soumith committed
fsdp_tp_example fsdp1-> fsdp2
1 parent 8c1555d commit 2174f26

File tree

1 file changed (+2, -2 lines)

distributed/tensor_parallelism/fsdp_tp_example.py

Lines changed: 2 additions & 2 deletions
@@ -49,7 +49,7 @@
 from llama2_model import Transformer, ModelArgs

 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed._tensor import Shard, Replicate
 from torch.distributed.tensor.parallel import (
     parallelize_module,
@@ -146,7 +146,7 @@
 )

 # Init FSDP using the dp device mesh
-sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
+sharded_model = fully_shard(model, mesh=dp_mesh)

 rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")
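For context, the change above swaps the FSDP1 wrapper class (FullyShardedDataParallel) for the FSDP2 fully_shard API while keeping the same data-parallel sub-mesh. Below is a minimal sketch of how the resulting FSDP2 + tensor-parallel composition fits together. The 2 x 4 mesh shape, the toy nn.Sequential model, and the layer-wise sharding plan are illustrative assumptions and not taken from the example's llama2 Transformer; the snippet also assumes it is launched with torchrun so 8 GPUs and the process group environment are available.

# Minimal FSDP2 + TP sketch (assumptions: 8 GPUs as a 2 x 4 mesh, toy model).
# Launch with: torchrun --nproc_per_node=8 fsdp2_tp_sketch.py  (hypothetical filename)
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
torch.manual_seed(0)  # same input on every rank, since TP treats it as replicated

# 2D mesh: outer "dp" dim for FSDP2 sharding, inner "tp" dim for tensor parallelism.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# Toy feed-forward block standing in for the example's llama2 Transformer.
model = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512)).cuda()

# Tensor parallelism first, on the tp sub-mesh: column-wise then row-wise so the
# intermediate activation stays sharded between the two Linear layers.
model = parallelize_module(
    model,
    tp_mesh,
    {"0": ColwiseParallel(), "2": RowwiseParallel()},
)

# Then shard parameters and gradients with FSDP2 on the dp sub-mesh. This is the
# call the commit switches to, in place of
# FSDP(model, device_mesh=dp_mesh, use_orig_params=True).
sharded_model = fully_shard(model, mesh=dp_mesh)

out = sharded_model(torch.randn(16, 512, device="cuda"))
out.sum().backward()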
