diff --git a/intermediate_source/TP_tutorial.rst b/intermediate_source/TP_tutorial.rst
index 91e64a8748..4108e72b02 100644
--- a/intermediate_source/TP_tutorial.rst
+++ b/intermediate_source/TP_tutorial.rst
@@ -333,7 +333,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
 
     from torch.distributed.device_mesh import init_device_mesh
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
 
     # i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
     mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -347,7 +347,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
 
     # apply Tensor Parallel intra-host on tp_mesh
     model_tp = parallelize_module(model, tp_mesh, tp_plan)
     # apply FSDP inter-host on dp_mesh
-    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
+    model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
 
 This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.
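
To illustrate the FSDP2-style composition this patch switches to, here is a minimal, self-contained sketch of combining ``parallelize_module`` on a TP submesh with ``fully_shard`` on a DP submesh. The ``ToyModel`` class, the 2x2 mesh shape, and the torchrun launch are illustrative assumptions, not part of the tutorial or the Llama model it targets; adjust the mesh shape to your GPU count.

    # Sketch only: assumes a torchrun launch with world_size == 4 (dp=2, tp=2).
    import os

    import torch
    import torch.nn as nn

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyModel(nn.Module):
        """Hypothetical two-layer model, sharded column-wise then row-wise on tp_mesh."""

        def __init__(self):
            super().__init__()
            self.in_proj = nn.Linear(16, 32)
            self.out_proj = nn.Linear(32, 16)

        def forward(self, x):
            return self.out_proj(torch.relu(self.in_proj(x)))


    # Bind each rank to its local GPU (torchrun sets LOCAL_RANK).
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 2-D mesh: outer dim for data parallel (FSDP), inner dim for tensor parallel.
    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    tp_mesh = mesh_2d["tp"]  # intra-host submesh
    dp_mesh = mesh_2d["dp"]  # inter-host submesh

    model = ToyModel().cuda()

    # Apply Tensor Parallel first, on the tp submesh.
    model = parallelize_module(
        model,
        tp_mesh,
        {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
    )

    # Then shard parameters with FSDP2 (fully_shard) on the dp submesh.
    model = fully_shard(model, mesh=dp_mesh)

    # One forward/backward step to show the composed model trains as usual.
    out = model(torch.randn(8, 16, device="cuda"))
    out.sum().backward()

Unlike the removed ``FullyShardedDataParallel`` wrapper, ``fully_shard`` modifies the module in place and needs no ``use_orig_params`` flag, which is why the patched line passes only the ``mesh`` argument.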