Commit 1c7c38d

add the forge folder
1 parent 34d815c commit 1c7c38d

6 files changed: +723 -0 lines changed
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
## `ForgeEngine`

The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications.

The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to
- Initialize an SPMD distributed training environment
- Construct and scale models via n-D parallelisms and meta-device initialization
- Provide necessary training components and utilities

**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.
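A minimal usage sketch of that pattern (hedged: the `forge` import path is an assumption, and in practice the `ForgeJobConfig` fields would be populated from TOML/CLI as torchtitan does; the script must run under `torchrun` so `LOCAL_RANK`/`WORLD_SIZE` are set):

```python
# Minimal sketch, assuming a `forge` package on the path and a job config
# whose fields are filled in elsewhere (e.g. parsed from a TOML file).
from forge import ForgeEngine, ForgeJobConfig

job_config = ForgeJobConfig()          # assume fields populated via TOML/CLI
engine = ForgeEngine(job_config)       # sets up distributed env, model, optimizers

# The engine exposes the assembled components for a custom training loop:
model_parts = engine.model_parts       # list of (possibly PP-split) nn.Modules
optimizers = engine.optimizers         # OptimizersContainer
lr_schedulers = engine.lr_schedulers   # LRSchedulersContainer
loss_fn = engine.loss_fn               # rescaled for gradient accumulation

engine.close()                         # closes the checkpointer
```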
Additionally, the folder provides a train spec registration method [`register_train_spec`](train_spec.py) that allows users to extend beyond the core set of models and training components available in torchtitan, enabling greater flexibility and customization for specific training requirements.
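As a sketch of what registration might look like (the exact `ForgeTrainSpec` constructor signature is an assumption; the field names mirror the attributes `engine.py` reads from the spec, and all `My*`/`my_*` names are hypothetical user code):

```python
# Hypothetical registration sketch. ForgeTrainSpec's exact signature is an
# assumption; the fields shown are the ones engine.py accesses on the spec.
from forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",                      # looked up via job_config.model.name
        model_cls=MyTransformer,              # hypothetical nn.Module subclass
        model_args={"small": MyModelArgs()},  # keyed by job_config.model.flavor
        parallelize_fn=parallelize_my_model,  # applies TP/FSDP/compile, etc.
        pipelining_fn=None,                   # None => model does not support PP
        build_optimizers_fn=build_my_optimizers,
        build_lr_schedulers_fn=build_my_lr_schedulers,
        build_loss_fn=build_my_loss_fn,
    )
)
```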
The [example_train.py](./example_train.py) demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py) (except for quantization and fault tolerance).
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .engine import ForgeEngine
from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, register_train_spec

__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"]
```
Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Generator

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools import utils

from .job_config import ForgeJobConfig
from .train_spec import ForgeTrainSpec, get_train_spec


class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful):
    # core configs
    job_config: ForgeJobConfig
    parallel_dims: ParallelDims
    train_spec: ForgeTrainSpec

    # swappable training components in ForgeTrainSpec
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    # non-swappable training components
    checkpointer: CheckpointManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    gradient_accumulation_steps: int
    train_context: Generator[None, None, None]
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # Fields in ForgeEngine which are not in original Trainer
    # for dataloading
    dp_degree: int
    dp_rank: int
    # for logging
    model_args: BaseModelArgs
    num_flops_per_token: float
    model_param_count: int
    global_batch_size: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: ForgeJobConfig):
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(job_config)
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            world_size=world_size,
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0
        self.dp_degree, self.dp_rank = dp_degree, dp_rank

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = get_train_spec(job_config.model.name)

        # build model (using meta init)
        self.model_args = model_args = self.train_spec.model_args[
            job_config.model.flavor
        ]
        # set the model args from training job configs
        model_args.update_from_config(job_config)

        with torch.device("meta"):
            model = self.train_spec.model_cls(model_args)

        # calculate model size and flops per token
        (
            self.model_param_count,
            self.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )
        self.global_batch_size = global_batch_size

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, parallel_dims
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )

        self.checkpointer = CheckpointManager(
            dataloader=None,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            checkpoint_config=job_config.checkpoint,
            sd_adapter=self.train_spec.state_dict_adapter,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()
```
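To make the engine's moving parts concrete, here is a rough sketch of a training loop driving the components constructed above (hedged: `batch_iter` is a hypothetical iterable of `(inputs, labels)` tensors; it assumes pipeline parallelism is off and uses `train_context`, `maybe_enable_amp`, `gc_handler.run`, and `checkpointer.save` the way torchtitan's own `Trainer` does; see example_train.py for the real version):

```python
# Rough sketch of one possible custom loop; not the actual example_train.py.
# Assumes no PP, so model_parts holds a single module, and one micro-batch
# per step (a real loop would accumulate gradient_accumulation_steps batches).
engine = ForgeEngine(job_config)
model = engine.model_parts[0]

for step, (inputs, labels) in enumerate(batch_iter, start=1):
    engine.gc_handler.run(step)        # periodic, synchronized garbage collection
    engine.optimizers.zero_grad()
    with engine.train_context():       # loss-parallel / compiled-autograd context
        with engine.maybe_enable_amp:  # autocast if mixed precision is configured
            pred = model(inputs)
            loss = engine.loss_fn(pred, labels)  # pre-rescaled for accumulation
        loss.backward()
    engine.optimizers.step()
    engine.lr_schedulers.step()
    engine.checkpointer.save(step)     # no-op except at checkpoint intervals

engine.close()
```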
