
Commit 755434d

fsdp1 -> fsdp2 (#3464)
1 parent d5cd489 commit 755434d

File tree

1 file changed: +2 -2 lines changed


intermediate_source/TP_tutorial.rst

Lines changed: 2 additions & 2 deletions
@@ -333,7 +333,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
 
     from torch.distributed.device_mesh import init_device_mesh
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
 
     # i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
     mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -347,7 +347,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
     # apply Tensor Parallel intra-host on tp_mesh
    model_tp = parallelize_module(model, tp_mesh, tp_plan)
     # apply FSDP inter-host on dp_mesh
-    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
+    model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
 
 
 This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.
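For context, below is a minimal, self-contained sketch of the FSDP2-based 2-D setup this change points the tutorial at. It assumes 64 GPUs (8-way data parallel x 8-way tensor parallel) launched under torchrun, and uses a toy two-layer model in place of the tutorial's Llama model; the submodule names in tp_plan ("0", "1") refer to that toy model only and are not from the tutorial.

    # Sketch of the 2-D parallelism flow after this commit (FSDP2 via fully_shard).
    # Assumes 64 GPUs: 8-way data parallel (inter-host) x 8-way tensor parallel (intra-host).
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
    from torch.distributed.fsdp import fully_shard  # FSDP2 entry point

    # 2-D mesh is [dp, tp]; naming the dims lets us slice out sub-meshes.
    mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
    dp_mesh = mesh_2d["dp"]
    tp_mesh = mesh_2d["tp"]

    # Toy stand-in for the tutorial's Llama model: two linear layers.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).cuda()

    # Shard the first linear column-wise and the second row-wise on tp_mesh.
    tp_plan = {"0": ColwiseParallel(), "1": RowwiseParallel()}

    # Apply Tensor Parallel intra-host on tp_mesh.
    model_tp = parallelize_module(model, tp_mesh, tp_plan)

    # Apply FSDP (fully_shard) inter-host on dp_mesh; unlike the FSDP1 wrapper,
    # fully_shard composes with the DTensor-based TP and takes no use_orig_params flag.
    model_2d = fully_shard(model_tp, mesh=dp_mesh)

The API difference the diff captures: FSDP1 wraps the module (FullyShardedDataParallel(model_tp, ...)) and needed use_orig_params=True to compose with Tensor Parallel, while FSDP2's fully_shard installs sharding on the module in place and returns it, taking the data-parallel mesh via the mesh keyword.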
