
Commit dcd37e5

chore: Merge branch 'main' into tensor_parallelism

2 parents: 8b378ac + f6f663b

File tree: 10 files changed, +146 −29 lines

src/modalities/config/config.py

Lines changed: 3 additions & 3 deletions
@@ -338,15 +338,15 @@ class FullACParams(BaseModel):
     pass

 class SelectiveLayerACParams(BaseModel):
-    ac_freq: int
+    ac_freq: Annotated[int, Field(strict=True, ge=1)]

 class SelectiveOpACParams(BaseModel):
     save_ops_keys: list[str]

     ac_variant: ActivationCheckpointingVariants
     layers_fqn: str
-    model: PydanticPytorchModuleType | PydanticFSDP1ModuleType
-    ac_fun_params: Optional[FullACParams | SelectiveLayerACParams | SelectiveOpACParams] = None
+    model: PydanticPytorchModuleType
+    ac_fun_params: FullACParams | SelectiveLayerACParams | SelectiveOpACParams


 class RawAppStateConfig(BaseModel):
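For context (not part of this commit), a minimal sketch of what the tightened ac_freq annotation enforces under Pydantic v2; the class mirrors the one in the diff, the example calls are purely illustrative:

from typing import Annotated

from pydantic import BaseModel, Field, ValidationError


class SelectiveLayerACParams(BaseModel):
    # strict=True rejects floats/strings, ge=1 rejects non-positive frequencies
    ac_freq: Annotated[int, Field(strict=True, ge=1)]


SelectiveLayerACParams(ac_freq=2)  # valid: checkpoint every 2nd layer

try:
    SelectiveLayerACParams(ac_freq=0)  # rejected: violates ge=1
except ValidationError as e:
    print(e)

try:
    SelectiveLayerACParams(ac_freq="2")  # rejected: strict=True disables string coercion
except ValidationError as e:
    print(e)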

src/modalities/models/model_factory.py

Lines changed: 7 additions & 5 deletions
@@ -42,7 +42,7 @@
 from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
 from modalities.training.activation_checkpointing.activation_checkpointing import (
     ActivationCheckpointing,
-    apply_activation_checkpointing_inplace,
+    apply_activation_checkpointing_fsdp1_inplace,
 )
 from modalities.training.activation_checkpointing.activation_checkpointing_variants import (
     ActivationCheckpointingVariants,

@@ -265,19 +265,19 @@ def get_activation_checkpointed_fsdp1_model_(model: FSDP1, activation_checkpoint
         """
         if len(activation_checkpointing_modules) > 0:
             if isinstance(model, FSDP1):
-                apply_activation_checkpointing_inplace(
+                apply_activation_checkpointing_fsdp1_inplace(
                     model=model,
                     activation_checkpointing_modules=activation_checkpointing_modules,
                 )
             else:
                 raise ValueError(
                     "Activation checkpointing can only be applied to FSDP1-wrapped models! "
-                    f"Current model type: {type(model)}"
+                    f"Current model type: {type(model)}."
                 )
         return model

     @staticmethod
-    def get_activation_checkpointed_model_(
+    def get_activation_checkpointed_fsdp2_model_(
         ac_variant: ActivationCheckpointingVariants,
         layers_fqn: str,
         model: nn.Module,

@@ -288,7 +288,9 @@ def get_activation_checkpointed_model_(
         ),
     ) -> nn.Module:
         """FSDP2 variant for applying activation checkpointing to the given model (in-place operation).
-        When using FSDP2, we always first apply activation checkpointing to the model and then wrap it with FSDP2.
+
+        Important: When using FSDP2, we always first apply activation checkpointing to the model
+        and then wrap it with FSDP2.

         Args:
             ac_variant (ActivationCheckpointingVariants): The activation checkpointing variant to use.
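To make the ordering requirement concrete (not part of this commit), a minimal sketch assuming a recent PyTorch with composable FSDP2: wrap the target blocks with activation checkpointing first, then shard them. The Block class and shard_with_ac helper are made up for illustration, and a process group must already be initialized before calling the helper.

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.fsdp import fully_shard  # older releases expose fully_shard under torch.distributed._composable.fsdp


class Block(nn.Module):  # hypothetical stand-in for a transformer block
    def __init__(self) -> None:
        super().__init__()
        self.ff = nn.Linear(64, 64)

    def forward(self, x):
        return self.ff(x)


def shard_with_ac(model: nn.Sequential) -> nn.Sequential:
    # 1) Apply activation checkpointing to each block first (in-place replacement).
    for idx, block in enumerate(model):
        model[idx] = checkpoint_wrapper(block, preserve_rng_state=False)
    # 2) Only then wrap with FSDP2: per-block sharding plus a root wrap.
    for block in model:
        fully_shard(block)
    fully_shard(model)
    return model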

src/modalities/registry/components.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ class ComponentEntity:
     ComponentEntity(
         "model",
         "activation_checkpointed",
-        ModelFactory.get_activation_checkpointed_model_,
+        ModelFactory.get_activation_checkpointed_fsdp2_model_,
         ActivationCheckpointedModelConfig,
     ),
     ComponentEntity("model", "compiled", ModelFactory.get_compiled_model, CompiledModelConfig),

src/modalities/training/activation_checkpointing/activation_checkpointing.py

Lines changed: 11 additions & 7 deletions
@@ -23,12 +23,15 @@ def is_module_to_apply_activation_checkpointing(
     return isinstance(submodule, tuple(activation_checkpointing_modules))


-def apply_activation_checkpointing_inplace(model: nn.Module, activation_checkpointing_modules: list[str]):
+def apply_activation_checkpointing_fsdp1_inplace(model: FSDP1, activation_checkpointing_modules: list[str]):
     activation_checkpointing_module_types = [
         get_module_class_from_name(model, m) for m in activation_checkpointing_modules
     ]
-    if not isinstance(model, (FSDP1)):
-        raise ValueError("activation checkpointing can only be applied to FSDP1 wrapped models!")
+    if not isinstance(model, FSDP1):
+        raise ValueError(
+            "This activation checkpointing component can only be applied to FSDP1 wrapped models. "
+            "Use the respective FSDP2 component for FSDP2 models."
+        )
     non_reentrant_wrapper = partial(ptd_checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, debug=False)

     apply_activation_checkpointing(

@@ -76,7 +79,7 @@ class ActivationCheckpointing:
         # for low precision training, it's useful to always save
         # the result of max, since the absolute maximum is
         # used to compute the scaling factor for quantization.
-        "torch.ops.aten.max.default": ops.aten.max.default,
+        "ops.aten.max.default": ops.aten.max.default,
     }

     @staticmethod

@@ -147,19 +150,20 @@ def apply_activation_checkpointing_(

     @staticmethod
     def _apply_full_ac(module: nn.Module) -> nn.Module:
-        module_saced = ptd_checkpoint_wrapper(module, preserve_rng_state=False)
-        return module_saced
+        module_aced = ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+        return module_aced

     @staticmethod
     def _apply_selective_op_ac(module: nn.Module, save_ops_keys: list[str]) -> nn.Module:
-        def _get_custom_policy(meta, save_ops_set: Set):  # closure to capture meta
+        def _get_custom_policy(meta: dict[str, int], save_ops_set: Set):  # closure to capture meta
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
                 mm_count_key = f"{mode}_mm_count"
                 if func == torch.ops.aten.mm.default:
                     meta[mm_count_key] += 1
                 # Saves output of all compute ops in save_ops_set, except every second mm
                 # NOTE: we should make this configurable and not hide it in the code
+                # To make this completely configurable, we would have to store the checkpointing frequency of every OP.
                 to_save = func in save_ops_set and not (
                     func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                 )
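For context (not part of this commit), a minimal sketch of how the shortened SAVE_LIST keys could be resolved from config strings into aten ops; only the key format ("ops.aten..." without the "torch." prefix) is taken from the diff, the mapping fragment and helper below are illustrative:

import torch
from torch import ops

# Illustrative fragment in the style of the SAVE_LIST shown above:
# the config-facing key drops the "torch." prefix, the value is the resolved aten overload.
SAVE_LIST = {
    "ops.aten.mm.default": ops.aten.mm.default,
    "ops.aten.max.default": ops.aten.max.default,
}


def resolve_save_ops(save_ops_keys: list[str]) -> set:
    # Translate config keys (e.g. "ops.aten.mm.default") into the concrete ops whose outputs are saved.
    return {SAVE_LIST[key] for key in save_ops_keys}


assert torch.ops.aten.mm.default in resolve_save_ops(["ops.aten.mm.default"])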

src/modalities/util.py

Lines changed: 14 additions & 2 deletions
@@ -65,6 +65,18 @@ def get_experiment_id_from_config(config_file_path: Optional[Path], hash_length:
 def get_synced_string(
     string_to_be_synced: str, from_rank: int = 0, max_string_byte_length: Optional[int] = 1024
 ) -> str:
+    """Broadcast a string from one rank to all other ranks in the distributed setup.
+
+    Args:
+        string_to_be_synced (str): The string to be synced across ranks.
+        from_rank (int, optional): The rank that generates the string. Defaults to 0.
+        max_string_byte_length (Optional[int], optional): Maximum byte length of the string to be synced.
+            Defaults to 1024.
+    Returns:
+        str: The synced string, decoded from the byte array.
+    Raises:
+        ValueError: If the string exceeds the maximum byte length.
+    """
     rank = dist.get_rank()
     if rank == from_rank:
         # Generate a unique folder name

@@ -112,9 +124,9 @@ def get_synced_experiment_id_of_run(
     Returns:
         str: The experiment ID.
     """
-    experimenet_id = get_experiment_id_from_config(config_file_path, hash_length)
+    experiment_id = get_experiment_id_from_config(config_file_path, hash_length)
     experiment_id_synced = get_synced_string(
-        string_to_be_synced=experimenet_id,
+        string_to_be_synced=experiment_id,
         from_rank=0,
         max_string_byte_length=max_experiment_id_byte_length,
     )
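As background for the new docstring (not part of this commit), a minimal sketch of the usual broadcast-a-string recipe with torch.distributed: encode on the source rank into a fixed-size byte tensor, broadcast, decode everywhere. The function name and details are illustrative; a process group must already be initialized, and CUDA tensors are needed with the NCCL backend.

import torch
import torch.distributed as dist


def broadcast_string(value: str, from_rank: int = 0, max_bytes: int = 1024) -> str:
    # Fixed-size buffer so every rank allocates the same tensor shape.
    buffer = torch.zeros(max_bytes, dtype=torch.uint8)
    if dist.get_rank() == from_rank:
        encoded = value.encode("utf-8")
        if len(encoded) > max_bytes:
            raise ValueError(f"String exceeds the maximum byte length of {max_bytes}.")
        buffer[: len(encoded)] = torch.tensor(list(encoded), dtype=torch.uint8)
    # Collective call: non-source ranks receive the bytes from from_rank.
    dist.broadcast(buffer, src=from_rank)
    # Drop the zero padding and decode (assumes the payload itself contains no NUL bytes).
    return bytes(buffer[buffer != 0].tolist()).decode("utf-8")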

tests/training/config_activation_checkpointing.yaml

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@ full_activation_checkpointed_model:
       instance_key: model_raw
       pass_type: BY_REFERENCE
     layers_fqn: transformer.h
+    ac_fun_params: {}

 selective_layer_activation_checkpointed_model:
   component_key: model

@@ -31,7 +32,7 @@ selective_op_activation_checkpointed_model:
     layers_fqn: transformer.h
     ac_fun_params:
       save_ops_keys:
-        - torch.ops.aten.mm.default
+        - ops.aten.mm.default

 model_raw:
   component_key: model

tests/training/config_activation_checkpointing_fsdp1.yaml

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ test_model:
       instance_key: wrapped_model
       pass_type: BY_REFERENCE
     layers_fqn: transformer.h
+    ac_fun_params: {}

 wrapped_model:
   component_key: model

tests/training/config_activation_checkpointing_fsdp2.yaml

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ activation_checkpointed_model:
       instance_key: model_raw
       pass_type: BY_REFERENCE
     layers_fqn: transformer.h
+    ac_fun_params: {}

 model_raw:
   component_key: model

tests/training/test_activation_checkpointing.py

Lines changed: 104 additions & 8 deletions
@@ -2,7 +2,10 @@
 from pathlib import Path

 import pytest
+import torch
 import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.nn.functional as F
 from pydantic import BaseModel
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper


@@ -15,6 +18,10 @@
 working_dir = Path(os.path.dirname(__file__))


+class RawModel(BaseModel):
+    model_raw: PydanticPytorchModuleType
+
+
 class ActivationCheckpointingInstantiationModel(BaseModel):
     test_model: PydanticPytorchModuleType


@@ -31,6 +38,10 @@ class SelectiveOpActivationCheckpointingInstantiationModel(BaseModel):
     selective_op_activation_checkpointed_model: PydanticPytorchModuleType


+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2,
+    reason="This test requires more than one GPU",
+)
 @pytest.mark.parametrize(
     "rdvz_port, world_size, relative_config_path",
     [

@@ -50,7 +61,6 @@ def test_full_activation_checkpointing_FSDP1_legacy(world_size: int, rdvz_port:
 def _test_full_activation_checkpointing_FSDP1_legacy_thread(
     process_id: int, rdvz_port: int, world_size: int, relative_config_path: str
 ):
-    working_dir = Path(os.path.dirname(__file__))
     config_file_path = working_dir / relative_config_path

     with MultiProcessingCudaEnv(

@@ -77,6 +87,10 @@ def _test_full_activation_checkpointing_FSDP1_legacy_thread(
     )


+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2,
+    reason="This test requires more than one GPU",
+)
 @pytest.mark.parametrize(
     "rdvz_port, world_size, relative_config_path",
     [

@@ -96,7 +110,6 @@ def test_full_activation_checkpointing_FSDPX(world_size: int, rdvz_port: int, re
 def _test_full_activation_checkpointing_FSDPX_thread(
     process_id: int, rdvz_port: int, world_size: int, relative_config_path: str
 ):
-    working_dir = Path(os.path.dirname(__file__))
     config_file_path = working_dir / relative_config_path

     with MultiProcessingCudaEnv(

@@ -130,8 +143,7 @@ def _test_full_activation_checkpointing_FSDPX_thread(
         ("config_activation_checkpointing.yaml"),
     ],
 )
-def test_full_activation_checkpointing(relative_config_path: str):
-    working_dir = Path(os.path.dirname(__file__))
+def test_fsdp2_full_activation_checkpointing(relative_config_path: str):
     config_file_path = working_dir / relative_config_path

     main = Main(config_file_path, experiment_id="-1")

@@ -152,8 +164,7 @@ def test_full_activation_checkpointing(relative_config_path: str):
         ("config_activation_checkpointing.yaml"),
     ],
 )
-def test_selective_layer_activation_checkpointing(relative_config_path: str):
-    working_dir = Path(os.path.dirname(__file__))
+def test_fsdp2_selective_layer_activation_checkpointing(relative_config_path: str):
     config_file_path = working_dir / relative_config_path

     main = Main(config_file_path, experiment_id="-1")

@@ -174,8 +185,7 @@ def test_selective_layer_activation_checkpointing(relative_config_path: str):
         ("config_activation_checkpointing.yaml"),
     ],
 )
-def test_selective_op_activation_checkpointing(relative_config_path: str):
-    working_dir = Path(os.path.dirname(__file__))
+def test_fsdp2_selective_op_activation_checkpointing(relative_config_path: str):
     config_file_path = working_dir / relative_config_path

     main = Main(config_file_path, experiment_id="-1")

@@ -189,3 +199,89 @@ def test_selective_op_activation_checkpointing(relative_config_path: str):
             assert isinstance(module, CheckpointWrapper)
         else:
             assert not isinstance(module, CheckpointWrapper)
+
+
+# end-to-end equivalence test in terms of loss
+
+
+@pytest.mark.parametrize(
+    "relative_config_path",
+    [
+        ("config_activation_checkpointing.yaml"),
+    ],
+)
+def test_fsdp2_activation_checkpointing_end2end(relative_config_path: str):
+    def forward_and_backward(model: nn.Module, input_ids: torch.Tensor) -> float:
+        target = input_ids[:, 1:]  # batch_size, seq_len - 1
+        input_ids = input_ids[:, :-1]  # batch_size, seq_len - 1
+        input_dict = {"input_ids": input_ids}
+        logits = model(input_dict)["logits"]  # batch_size, seq_len - 1, vocab_size
+
+        loss = F.cross_entropy(
+            logits.reshape(-1, logits.size(-1)),  # batch_size * (seq_len - 1), vocab_size
+            target.reshape(-1),  # batch_size * (seq_len - 1)
+            reduction="mean",
+        )
+        loss_val = loss.item()
+        loss.backward()
+        return loss_val
+
+    def check_grads_equal(model1, model2, label):
+        for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
+            if p1.grad is not None and p2.grad is not None:
+                # we cannot check the FQNs as AC renames the parameters;
+                # instead we check for weight equivalence
+                torch.testing.assert_close(p1, p2, rtol=1e-5, atol=1e-7, msg=f"Parameter mismatch in {n1} ({label})")
+                torch.testing.assert_close(
+                    p1.grad, p2.grad, rtol=1e-5, atol=1e-7, msg=f"Gradient mismatch in {n1} ({label})"
+                )
+
+    batch_size = 2
+    seq_len = 256
+    vocab_size = 50304
+
+    # build the models with different activation checkpointing variants but equivalent weights
+    config_file_path = working_dir / relative_config_path
+    main = Main(config_file_path, experiment_id="-1")
+
+    torch.manual_seed(42)
+    model_raw = main.build_components(components_model_type=RawModel).model_raw.to("cuda")
+
+    torch.manual_seed(42)
+    model_fac = main.build_components(
+        components_model_type=FullActivationCheckpointingInstantiationModel
+    ).full_activation_checkpointed_model.to("cuda")
+
+    torch.manual_seed(42)
+    model_sel_layer = main.build_components(
+        components_model_type=SelectiveLayerActivationCheckpointingInstantiationModel
+    ).selective_layer_activation_checkpointed_model.to("cuda")
+
+    torch.manual_seed(42)
+    model_sel_op = main.build_components(
+        components_model_type=SelectiveOpActivationCheckpointingInstantiationModel
+    ).selective_op_activation_checkpointed_model.to("cuda")
+
+    # Ensure all models have a different reference
+    models = [model_raw, model_fac, model_sel_layer, model_sel_op]
+    assert len(set(id(m) for m in models)) == len(models)
+
+    # Dummy LLM token input
+    # we use a sequence length of seq_len + 1 as the last token will only be used for loss calculation
+    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len + 1), device="cuda")

+    # Run forward+backward
+    loss_raw = forward_and_backward(model_raw, input_ids)
+    loss_fac = forward_and_backward(model_fac, input_ids)
+    loss_sel_layer = forward_and_backward(model_sel_layer, input_ids)
+    loss_sel_op = forward_and_backward(model_sel_op, input_ids)
+
+    # Compare losses
+    torch.testing.assert_close(torch.tensor(loss_fac), torch.tensor(loss_raw), msg="FAC loss mismatch")
+    torch.testing.assert_close(torch.tensor(loss_sel_layer), torch.tensor(loss_raw), msg="Sel layer AC loss mismatch")
+    torch.testing.assert_close(torch.tensor(loss_sel_op), torch.tensor(loss_raw), msg="Sel op AC loss mismatch")
+
+    # Compare gradients
+    check_grads_equal(model_raw, model_fac, "fac")
+    check_grads_equal(model_raw, model_sel_layer, "sel_layer")
+    check_grads_equal(model_raw, model_sel_op, "sel_op")

tutorials/instruction_tuning/configs/small_train_instruct_model_fsdp2_config.yaml

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@ settings:
   enforce_last_step_evaluated: false
   enforce_last_step_checkpointed: false
   step_profile:
-    gradient_accumulation_steps: 2
-    local_train_micro_batch_size: 2
+    gradient_accumulation_steps: 4
+    local_train_micro_batch_size: 1
     sequence_length: 8192 # Qwen2.5 would have 32768
   training_target:
     # had to hack here: Value error, Not enough tokens in the dataset. Actual: 57434112, Expected: >=57442304
