From 6c1ce4adf8f26e6a08a267831b40da468dc0dc76 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 14 Jul 2025 21:46:03 -0700 Subject: [PATCH] refactor FTManager --- scripts/estimate/estimation.py | 4 +- torchtitan/components/ft.py | 136 ++++++++------------- torchtitan/components/optimizer.py | 6 +- torchtitan/distributed/utils.py | 2 +- torchtitan/experiments/llama4/optimizer.py | 2 +- torchtitan/protocols/train_spec.py | 2 +- torchtitan/train.py | 30 ++--- 7 files changed, 66 insertions(+), 116 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index aade405b4..cec91fdcd 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -15,7 +15,6 @@ from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker from torch.testing._internal.distributed.fake_pg import FakeStore -from torchtitan.components.ft import init_ft_manager from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.config_manager import ConfigManager, JobConfig @@ -116,8 +115,7 @@ def estimate_memory(job_config: JobConfig): model.train() # build optimizer after applying parallelisms to the model - ft_manager = init_ft_manager(job_config) - optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager) + optimizers = build_optimizers([model], job_config, parallel_dims) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index 946fc4638..60e4d2f80 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -4,19 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import importlib from contextlib import nullcontext from typing import ContextManager, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist -import torch.distributed._functional_collectives as funcol from torch.distributed._composable.fsdp.fully_shard import FSDPModule -from torch.distributed.device_mesh import DeviceMesh from torch.distributed.distributed_c10d import ReduceOp -from torch.distributed.tensor import DTensor -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import FaultTolerance as FTConfig if importlib.util.find_spec("torchft") is not None: import torchft as ft @@ -32,14 +28,32 @@ class FTManager: def __init__( self, - manager: Optional["ft.Manager"], - group_size: int = 1, - replica_id: int = 0, + ft_config: FTConfig, ) -> None: - self._manager = manager - self.group_size = group_size - self.replica_id = replica_id - if has_torchft and manager is not None: + if not ft_config.enable: + self._manager = None + return + + if not has_torchft: + raise ImportError("torchft is not installed. 
Please install it.") + + pg = ft.ProcessGroupNCCL() + + # If the training method is specific, then the quorum should be synchronous + self.use_async_quorum = ft_config.semi_sync_method is None + + self._manager = ft.Manager( + pg=pg, + min_replica_size=ft_config.min_replica_size, + load_state_dict=None, + state_dict=None, + use_async_quorum=self.use_async_quorum, + replica_id=f"torchtitan_ft_{ft_config.replica_id}", + ) + self.group_size = ft_config.group_size + self.replica_id = ft_config.replica_id + + if self.use_async_quorum: self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager) self.replicate_pg.register("dp_replicate") @@ -53,85 +67,37 @@ def manager(self) -> "ft.Manager": return self._manager def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]: - return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank - - def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None: - def all_reduce_hook(output): - dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG) - - def apply_set_all_reduce_hook(m): - if isinstance(m, FSDPModule): - m.set_all_reduce_hook(all_reduce_hook) - - for part in model_parts: - part.apply(apply_set_all_reduce_hook) - - -def init_ft_manager(job: JobConfig) -> FTManager: - """Initialize the FT manager if TorchFT is enabled. - - Args: - job (JobConfig): The job configuration. - - Returns: - FTManager: A wrapper around TorchFT.Manager - """ - if not job.fault_tolerance.enable: - return FTManager(None) - - if not has_torchft: - raise ImportError("torchft is not installed. Please install it.") - - if job.fault_tolerance.min_replica_size < 1: - raise ValueError("At least one FT replica is required.") - - pg = ft.ProcessGroupNCCL() + if self.enabled: + return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank + else: + return dp_degree, dp_rank - # If the training method is specific, then the quorum should be synchronous - use_async_quorum = job.fault_tolerance.semi_sync_method is None + def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None: + if self.enabled and self.use_async_quorum: - return FTManager( - ft.Manager( - pg=pg, - min_replica_size=job.fault_tolerance.min_replica_size, - load_state_dict=None, - state_dict=None, - use_async_quorum=use_async_quorum, - replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}", - ), - group_size=job.fault_tolerance.group_size, - replica_id=job.fault_tolerance.replica_id, - ) - - -def ft_dist_reduce( - x: torch.Tensor, reduceOp: str, mesh: DeviceMesh -) -> tuple[torch.Tensor, str, DeviceMesh]: - if has_torchft and isinstance(mesh, ft.device_mesh._FlattenDeviceMesh): - x = funcol.all_reduce( - x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg - ) - return x, reduceOp, mesh.managed_mesh.mesh - return x, reduceOp, mesh + def all_reduce_hook(output): + dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG) + def apply_set_all_reduce_hook(m): + if isinstance(m, FSDPModule): + m.set_all_reduce_hook(all_reduce_hook) -def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor: - if has_torchft: - mesh = total_norm._spec.mesh - if isinstance(mesh, ft.device_mesh.ManagedDeviceMesh): - # The gradients along the replicated dim has already been reduced. 
- # So we don't need another reducution beforing removing the - # replicate dimension - local_tensor = total_norm.to_local() - placements = list(copy.copy(total_norm._spec.placements)) - placements.pop(mesh.replicate_dim) - return DTensor.from_local(local_tensor, mesh.mesh, placements) + for model_part in model_parts: + model_part.apply(apply_set_all_reduce_hook) - return total_norm + @property + def loss_sync_pg( + self, + ) -> Optional["ft.process_group.ManagedProcessGroup"]: + if self.enabled and self.use_async_quorum: + return self.replicate_pg + else: + # skip loss sync when using semi-sync training + return None def maybe_semi_sync_training( - config: JobConfig, + ft_config: FTConfig, ft_manager: FTManager, model_parts: list[torch.nn.Module], optimizer: torch.optim.Optimizer, @@ -139,10 +105,8 @@ def maybe_semi_sync_training( """ If TorchFT is enabled and the config is set, use semi_sync_method """ - ft_config = config.fault_tolerance semi_sync_method = ft_config.semi_sync_method - torchft_enabled = ft_config.enable - if torchft_enabled and semi_sync_method is not None: + if ft_config.enable and semi_sync_method is not None: from torchft import local_sgd assert ( diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d2ff514cf..ee87888d7 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -243,7 +243,7 @@ def build_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - ft_manager: FTManager, + ft_manager: FTManager | None = None, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. @@ -273,7 +273,7 @@ def build_optimizers( raise NotImplementedError( "Optimizers in backward is not supported with Pipeline Parallel." ) - if ft_manager.enabled: + if ft_manager and ft_manager.enabled: raise NotImplementedError( "TorchFT is not supported with optimizers in backward." ) @@ -313,7 +313,7 @@ def build_optimizers( model_parts, optimizer_cls, optimizer_kwargs ) - if ft_manager.enabled: + if ft_manager and ft_manager.enabled: return FTOptimizersContainer( model_parts, optimizer_cls, diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 58c5df0ca..e25794a24 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -29,7 +29,7 @@ def _dist_reduce( x: torch.Tensor, reduceOp: str, mesh: DeviceMesh, - extra_pg: dist.ProcessGroup | None = None, + extra_pg: dist.ProcessGroup | None, ) -> float: """Perform distributed reduction on a tensor. 
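# Minimal usage sketch of the refactored FTManager API, mirroring the train.py
# hunks later in this patch. The function arguments (job_config, parallel_dims,
# model_parts, loss, dp_degree, dp_rank) are stand-ins for the corresponding
# Trainer fields, and the dist_utils import is assumed to match train.py.
from torchtitan.components.ft import FTManager
from torchtitan.distributed import utils as dist_utils


def ft_aware_setup_and_loss_sync(
    job_config, parallel_dims, model_parts, loss, dp_degree, dp_rank
):
    # FTManager is now built directly from the fault_tolerance config and acts
    # as a no-op wrapper when fault tolerance is disabled.
    ft_manager = FTManager(job_config.fault_tolerance)

    # With FT disabled this returns (dp_degree, dp_rank) unchanged; with FT
    # enabled it folds the replica group into the dataloader's DP info.
    dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)

    # Registers the FSDP all-reduce hook only when FT is enabled and the quorum
    # is asynchronous (no semi-sync method configured); otherwise it is a no-op.
    ft_manager.maybe_set_all_reduce_hook(model_parts)

    if parallel_dims.dp_cp_enabled:
        # loss_sync_pg is the managed replicate process group, or None when FT
        # is disabled or a semi-sync method is in use.
        ft_pg = ft_manager.loss_sync_pg
        global_avg_loss = dist_utils.dist_mean(
            loss.detach(), parallel_dims.world_mesh["dp_cp"], ft_pg
        )
    else:
        global_avg_loss = loss.detach().item()

    return global_avg_loss, dp_degree, dp_rank
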
diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py index 11870f5fe..3b20f6b1d 100644 --- a/torchtitan/experiments/llama4/optimizer.py +++ b/torchtitan/experiments/llama4/optimizer.py @@ -48,7 +48,7 @@ def build_llama4_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - ft_manager: FTManager, + ft_manager: FTManager | None = None, ) -> OptimizersContainer: optimizers = build_optimizers( model_parts=model_parts, diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 3ee870771..0e376b2f6 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -78,7 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, ParallelDims, FTManager], + [list[nn.Module], JobConfig, ParallelDims, FTManager | None], OptimizersContainer, ] LRSchedulersBuilder: TypeAlias = Callable[ diff --git a/torchtitan/train.py b/torchtitan/train.py index ea7a8e2ef..b48292808 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,10 +13,10 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record -import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration +from torchtitan.components.ft import FTManager, maybe_semi_sync_training from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.components.metrics import ( build_metrics_processor, @@ -50,7 +50,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # non-swappable training components checkpointer: CheckpointManager - ft_manager: ft.FTManager + ft_manager: FTManager # runtime utilities device: torch.device @@ -104,11 +104,8 @@ def __init__(self, job_config: JobConfig): else: dp_degree, dp_rank = 1, 0 - self.ft_manager = ft.init_ft_manager(job_config) - # If TorchFT is enabled, the dp_rank and dp_degree, which are used for - # dataloader must be changed. 
- if self.ft_manager.enabled: - dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + self.ft_manager = FTManager(job_config.fault_tolerance) + dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( @@ -259,11 +256,7 @@ def __init__(self, job_config: JobConfig): self.model_parts = [model] - if ( - self.ft_manager.enabled - and job_config.fault_tolerance.semi_sync_method is None - ): - self.ft_manager.set_all_reduce_hook(self.model_parts) + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = self.metrics_processor.device_memory_monitor @@ -474,14 +467,9 @@ def train_step( if not self.metrics_processor.should_log(self.step): return - if parallel_dims.dp_cp_enabled or self.ft_manager.enabled: + if parallel_dims.dp_cp_enabled: loss = loss.detach() - # Skip ft manager communication when using semi sync training - use_ft_pg = ( - self.ft_manager.enabled - and self.job_config.fault_tolerance.semi_sync_method is None - ) - ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None + ft_pg = self.ft_manager.loss_sync_pg global_avg_loss, global_max_loss = ( dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), @@ -508,8 +496,8 @@ def train(self): maybe_enable_memory_snapshot( job_config, global_step=self.step ) as memory_profiler, - ft.maybe_semi_sync_training( - job_config, + maybe_semi_sync_training( + job_config.fault_tolerance, ft_manager=self.ft_manager, model_parts=self.model_parts, optimizer=self.optimizers,