4 changes: 1 addition & 3 deletions scripts/estimate/estimation.py
@@ -15,7 +15,6 @@
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.components.ft import init_ft_manager
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.config_manager import ConfigManager, JobConfig
@@ -116,8 +115,7 @@ def estimate_memory(job_config: JobConfig):
model.train()

# build optimizer after applying parallelisms to the model
ft_manager = init_ft_manager(job_config)
optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)
optimizers = build_optimizers([model], job_config, parallel_dims)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
136 changes: 50 additions & 86 deletions torchtitan/components/ft.py
@@ -4,19 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import importlib
from contextlib import nullcontext
from typing import ContextManager, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor
from torchtitan.config_manager import JobConfig
from torchtitan.config_manager import FaultTolerance as FTConfig

if importlib.util.find_spec("torchft") is not None:
import torchft as ft
@@ -32,14 +28,32 @@
class FTManager:
def __init__(
self,
manager: Optional["ft.Manager"],
group_size: int = 1,
replica_id: int = 0,
ft_config: FTConfig,
) -> None:
self._manager = manager
self.group_size = group_size
self.replica_id = replica_id
if has_torchft and manager is not None:
if not ft_config.enable:
self._manager = None
return

if not has_torchft:
raise ImportError("torchft is not installed. Please install it.")

pg = ft.ProcessGroupNCCL()

# If the training method is specified, then the quorum should be synchronous
self.use_async_quorum = ft_config.semi_sync_method is None

self._manager = ft.Manager(
pg=pg,
min_replica_size=ft_config.min_replica_size,
load_state_dict=None,
state_dict=None,
use_async_quorum=self.use_async_quorum,
replica_id=f"torchtitan_ft_{ft_config.replica_id}",
)
self.group_size = ft_config.group_size
self.replica_id = ft_config.replica_id

if self.use_async_quorum:
self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager)
self.replicate_pg.register("dp_replicate")

@@ -53,96 +67,46 @@ def manager(self) -> "ft.Manager":
return self._manager

def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank

def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
def all_reduce_hook(output):
dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)

def apply_set_all_reduce_hook(m):
if isinstance(m, FSDPModule):
m.set_all_reduce_hook(all_reduce_hook)

for part in model_parts:
part.apply(apply_set_all_reduce_hook)


def init_ft_manager(job: JobConfig) -> FTManager:
"""Initialize the FT manager if TorchFT is enabled.

Args:
job (JobConfig): The job configuration.

Returns:
FTManager: A wrapper around TorchFT.Manager
"""
if not job.fault_tolerance.enable:
return FTManager(None)

if not has_torchft:
raise ImportError("torchft is not installed. Please install it.")

if job.fault_tolerance.min_replica_size < 1:
raise ValueError("At least one FT replica is required.")

pg = ft.ProcessGroupNCCL()
if self.enabled:
return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
else:
return dp_degree, dp_rank

# If the training method is specified, then the quorum should be synchronous
use_async_quorum = job.fault_tolerance.semi_sync_method is None
def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
if self.enabled and self.use_async_quorum:

return FTManager(
ft.Manager(
pg=pg,
min_replica_size=job.fault_tolerance.min_replica_size,
load_state_dict=None,
state_dict=None,
use_async_quorum=use_async_quorum,
replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
),
group_size=job.fault_tolerance.group_size,
replica_id=job.fault_tolerance.replica_id,
)


def ft_dist_reduce(
x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
) -> tuple[torch.Tensor, str, DeviceMesh]:
if has_torchft and isinstance(mesh, ft.device_mesh._FlattenDeviceMesh):
x = funcol.all_reduce(
x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
)
return x, reduceOp, mesh.managed_mesh.mesh
return x, reduceOp, mesh
def all_reduce_hook(output):
dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)

def apply_set_all_reduce_hook(m):
if isinstance(m, FSDPModule):
m.set_all_reduce_hook(all_reduce_hook)

def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
if has_torchft:
mesh = total_norm._spec.mesh
if isinstance(mesh, ft.device_mesh.ManagedDeviceMesh):
# The gradients along the replicated dim have already been reduced,
# so we don't need another reduction before removing the
# replicate dimension.
local_tensor = total_norm.to_local()
placements = list(copy.copy(total_norm._spec.placements))
placements.pop(mesh.replicate_dim)
return DTensor.from_local(local_tensor, mesh.mesh, placements)
for model_part in model_parts:
model_part.apply(apply_set_all_reduce_hook)

return total_norm
@property
def loss_sync_pg(
self,
) -> Optional["ft.process_group.ManagedProcessGroup"]:
if self.enabled and self.use_async_quorum:
return self.replicate_pg
else:
# skip loss sync when using semi-sync training
return None


def maybe_semi_sync_training(
config: JobConfig,
ft_config: FTConfig,
ft_manager: FTManager,
model_parts: list[torch.nn.Module],
optimizer: torch.optim.Optimizer,
) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
"""
If TorchFT is enabled and the config is set, use semi_sync_method
"""
ft_config = config.fault_tolerance
semi_sync_method = ft_config.semi_sync_method
torchft_enabled = ft_config.enable
if torchft_enabled and semi_sync_method is not None:
if ft_config.enable and semi_sync_method is not None:
from torchft import local_sgd

assert (
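Taken together, these ft.py changes fold the old init_ft_manager helper into the FTManager constructor, which now takes the FaultTolerance config directly and degrades to a no-op wrapper when fault tolerance is disabled. A minimal usage sketch of the new surface, based only on the names in this diff and assuming the setup objects come from the caller (e.g. Trainer.__init__):

```python
# Sketch only: intended use of the refactored FTManager, per this diff.
import torch.nn as nn

from torchtitan.components.ft import FTManager
from torchtitan.config_manager import JobConfig


def setup_fault_tolerance(
    job_config: JobConfig,
    model_parts: list[nn.Module],
    dp_degree: int,
    dp_rank: int,
) -> tuple[FTManager, int, int]:
    # The constructor takes the FaultTolerance config directly and returns a
    # no-op wrapper when fault tolerance is disabled.
    ft_manager = FTManager(job_config.fault_tolerance)

    # Safe to call unconditionally: returns the inputs unchanged when disabled.
    dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)

    # No-op unless fault tolerance is enabled with an asynchronous quorum.
    ft_manager.maybe_set_all_reduce_hook(model_parts)

    # Process group for the loss all-reduce; None when disabled or semi-sync.
    _ = ft_manager.loss_sync_pg

    return ft_manager, dp_degree, dp_rank
```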
6 changes: 3 additions & 3 deletions torchtitan/components/optimizer.py
@@ -243,7 +243,7 @@ def build_optimizers(
model_parts: list[nn.Module],
job_config: JobConfig,
parallel_dims: ParallelDims,
ft_manager: FTManager,
ft_manager: FTManager | None = None,
) -> OptimizersContainer:
"""Create a OptimizersContainer for the given model parts and job config.

@@ -273,7 +273,7 @@ def build_optimizers(
raise NotImplementedError(
"Optimizers in backward is not supported with Pipeline Parallel."
)
if ft_manager.enabled:
if ft_manager and ft_manager.enabled:
raise NotImplementedError(
"TorchFT is not supported with optimizers in backward."
)
@@ -313,7 +313,7 @@ def build_optimizers(
model_parts, optimizer_cls, optimizer_kwargs
)

if ft_manager.enabled:
if ft_manager and ft_manager.enabled:
return FTOptimizersContainer(
model_parts,
optimizer_cls,
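Because ft_manager is now optional and guarded by `ft_manager and ft_manager.enabled`, call sites outside the fault-tolerance path can simply drop the argument. A hedged sketch of both call shapes (the wrapper function and the ParallelDims import path are assumptions, not part of this PR):

```python
# Sketch of the two call shapes allowed by the new optional parameter.
from torch import nn

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims  # import path assumed


def make_optimizers(
    model: nn.Module,
    job_config: JobConfig,
    parallel_dims: ParallelDims,
    ft_manager: FTManager | None = None,
) -> OptimizersContainer:
    if ft_manager is None:
        # Without fault tolerance (e.g. the estimation script above).
        return build_optimizers([model], job_config, parallel_dims)
    # With an enabled FTManager, build_optimizers takes the
    # FTOptimizersContainer path.
    return build_optimizers([model], job_config, parallel_dims, ft_manager)
```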
2 changes: 1 addition & 1 deletion torchtitan/distributed/utils.py
@@ -29,7 +29,7 @@ def _dist_reduce(
x: torch.Tensor,
reduceOp: str,
mesh: DeviceMesh,
extra_pg: dist.ProcessGroup | None = None,
extra_pg: dist.ProcessGroup | None,
) -> float:
"""Perform distributed reduction on a tensor.

2 changes: 1 addition & 1 deletion torchtitan/experiments/llama4/optimizer.py
@@ -48,7 +48,7 @@ def build_llama4_optimizers(
model_parts: list[nn.Module],
job_config: JobConfig,
parallel_dims: ParallelDims,
ft_manager: FTManager,
ft_manager: FTManager | None = None,
) -> OptimizersContainer:
optimizers = build_optimizers(
model_parts=model_parts,
2 changes: 1 addition & 1 deletion torchtitan/protocols/train_spec.py
@@ -78,7 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer]
MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
OptimizersBuilder: TypeAlias = Callable[
[list[nn.Module], JobConfig, ParallelDims, FTManager],
[list[nn.Module], JobConfig, ParallelDims, FTManager | None],
OptimizersContainer,
]
LRSchedulersBuilder: TypeAlias = Callable[
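Custom builders registered through a TrainSpec should match the widened OptimizersBuilder alias, i.e. accept an optional FTManager, mirroring the build_llama4_optimizers change above. A sketch of a spec-compatible builder (the name and the plain delegation are hypothetical):

```python
# Hypothetical builder conforming to the updated OptimizersBuilder alias.
from torch import nn

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims  # import path assumed


def build_my_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    parallel_dims: ParallelDims,
    ft_manager: FTManager | None = None,
) -> OptimizersContainer:
    # A TrainSpec can register this in place of build_optimizers;
    # ft_manager may legitimately be None now.
    return build_optimizers(model_parts, job_config, parallel_dims, ft_manager)
```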
30 changes: 9 additions & 21 deletions torchtitan/train.py
@@ -13,10 +13,10 @@
import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderStopIteration
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.components.metrics import (
build_metrics_processor,
@@ -50,7 +50,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):

# non-swappable training components
checkpointer: CheckpointManager
ft_manager: ft.FTManager
ft_manager: FTManager

# runtime utilities
device: torch.device
@@ -104,11 +104,8 @@ def __init__(self, job_config: JobConfig):
else:
dp_degree, dp_rank = 1, 0

self.ft_manager = ft.init_ft_manager(job_config)
# If TorchFT is enabled, the dp_rank and dp_degree, which are used for
# dataloader must be changed.
if self.ft_manager.enabled:
dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
self.ft_manager = FTManager(job_config.fault_tolerance)
dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
Review comment (Member):
It looks like dp_degree and dp_rank are needed for the dataloader. If we move the dataloader initialization after the model initialization, then I think we can also move maybe_set_all_reduce_hook here to consolidate all the torchft code together.

Can look into it in a follow-up PR.


# take control of garbage collection to avoid stragglers
self.gc_handler = utils.GarbageCollection(
@@ -259,11 +256,7 @@ def __init__(self, job_config: JobConfig):

self.model_parts = [model]

if (
self.ft_manager.enabled
and job_config.fault_tolerance.semi_sync_method is None
):
self.ft_manager.set_all_reduce_hook(self.model_parts)
self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)

# initialize device memory monitor and get peak flops for MFU calculation
device_memory_monitor = self.metrics_processor.device_memory_monitor
@@ -474,14 +467,9 @@ def train_step(
if not self.metrics_processor.should_log(self.step):
return

if parallel_dims.dp_cp_enabled or self.ft_manager.enabled:
if parallel_dims.dp_cp_enabled:
loss = loss.detach()
# Skip ft manager communication when using semi sync training
use_ft_pg = (
self.ft_manager.enabled
and self.job_config.fault_tolerance.semi_sync_method is None
)
ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None
ft_pg = self.ft_manager.loss_sync_pg
global_avg_loss, global_max_loss = (
dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
@@ -508,8 +496,8 @@ def train(self):
maybe_enable_memory_snapshot(
job_config, global_step=self.step
) as memory_profiler,
ft.maybe_semi_sync_training(
job_config,
maybe_semi_sync_training(
job_config.fault_tolerance,
ft_manager=self.ft_manager,
model_parts=self.model_parts,
optimizer=self.optimizers,
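In the training loop, maybe_semi_sync_training now receives the FaultTolerance config instead of the whole JobConfig. A simplified sketch of the new call shape, with `trainer` standing in for the Trainer instance and the loop body elided:

```python
# Simplified sketch of the new call shape in Trainer.train(), per this diff.
from torchtitan.components.ft import maybe_semi_sync_training


def run_with_semi_sync(trainer) -> None:
    with maybe_semi_sync_training(
        trainer.job_config.fault_tolerance,  # FTConfig, not the whole JobConfig
        ft_manager=trainer.ft_manager,
        model_parts=trainer.model_parts,
        optimizer=trainer.optimizers,
    ):
        # If semi_sync_method is set, DiLoCo/LocalSGD wraps the steps run here;
        # otherwise the context manager is expected to be a no-op (nullcontext).
        ...
```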