@@ -2,7 +2,10 @@
from pathlib import Path

import pytest
+ import torch
import torch.multiprocessing as mp
+ import torch.nn as nn
+ import torch.nn.functional as F
from pydantic import BaseModel
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper

@@ -15,6 +18,10 @@
working_dir = Path(os.path.dirname(__file__))


+ class RawModel(BaseModel):
+     model_raw: PydanticPytorchModuleType
+
+
class ActivationCheckpointingInstantiationModel(BaseModel):
    test_model: PydanticPytorchModuleType

@@ -31,6 +38,10 @@ class SelectiveOpActivationCheckpointingInstantiationModel(BaseModel):
    selective_op_activation_checkpointed_model: PydanticPytorchModuleType


+ @pytest.mark.skipif(
+     torch.cuda.device_count() < 2,
+     reason="This test requires more than one GPU",
+ )
@pytest.mark.parametrize(
    "rdvz_port, world_size, relative_config_path",
    [
@@ -50,7 +61,6 @@ def test_full_activation_checkpointing_FSDP1_legacy(world_size: int, rdvz_port:
def _test_full_activation_checkpointing_FSDP1_legacy_thread(
    process_id: int, rdvz_port: int, world_size: int, relative_config_path: str
):
-     working_dir = Path(os.path.dirname(__file__))
    config_file_path = working_dir / relative_config_path

    with MultiProcessingCudaEnv(
@@ -77,6 +87,10 @@ def _test_full_activation_checkpointing_FSDP1_legacy_thread(
        )


+ @pytest.mark.skipif(
+     torch.cuda.device_count() < 2,
+     reason="This test requires more than one GPU",
+ )
@pytest.mark.parametrize(
    "rdvz_port, world_size, relative_config_path",
    [
@@ -96,7 +110,6 @@ def test_full_activation_checkpointing_FSDPX(world_size: int, rdvz_port: int, re
def _test_full_activation_checkpointing_FSDPX_thread(
    process_id: int, rdvz_port: int, world_size: int, relative_config_path: str
):
-     working_dir = Path(os.path.dirname(__file__))
    config_file_path = working_dir / relative_config_path

    with MultiProcessingCudaEnv(
@@ -130,8 +143,7 @@ def _test_full_activation_checkpointing_FSDPX_thread(
        ("config_activation_checkpointing.yaml"),
    ],
)
- def test_full_activation_checkpointing(relative_config_path: str):
-     working_dir = Path(os.path.dirname(__file__))
+ def test_fsdp2_full_activation_checkpointing(relative_config_path: str):
    config_file_path = working_dir / relative_config_path

    main = Main(config_file_path, experiment_id="-1")
@@ -152,8 +164,7 @@ def test_full_activation_checkpointing(relative_config_path: str):
        ("config_activation_checkpointing.yaml"),
    ],
)
- def test_selective_layer_activation_checkpointing(relative_config_path: str):
-     working_dir = Path(os.path.dirname(__file__))
+ def test_fsdp2_selective_layer_activation_checkpointing(relative_config_path: str):
    config_file_path = working_dir / relative_config_path

    main = Main(config_file_path, experiment_id="-1")
@@ -174,8 +185,7 @@ def test_selective_layer_activation_checkpointing(relative_config_path: str):
        ("config_activation_checkpointing.yaml"),
    ],
)
- def test_selective_op_activation_checkpointing(relative_config_path: str):
-     working_dir = Path(os.path.dirname(__file__))
+ def test_fsdp2_selective_op_activation_checkpointing(relative_config_path: str):
    config_file_path = working_dir / relative_config_path

    main = Main(config_file_path, experiment_id="-1")
@@ -189,3 +199,89 @@ def test_selective_op_activation_checkpointing(relative_config_path: str):
            assert isinstance(module, CheckpointWrapper)
        else:
            assert not isinstance(module, CheckpointWrapper)
+
+
+ # end to end equivalence test in terms of loss
+
+
+ @pytest.mark.parametrize(
+     "relative_config_path",
+     [
+         ("config_activation_checkpointing.yaml"),
+     ],
+ )
+ def test_fsdp2_activation_checkpointing_end2end(relative_config_path: str):
+     def forward_and_backward(model: nn.Module, input_ids: torch.Tensor) -> float:
+         target = input_ids[:, 1:]  # batch_size, seq_len - 1
+         input_ids = input_ids[:, :-1]  # batch_size, seq_len - 1
+         input_dict = {"input_ids": input_ids}
+         logits = model(input_dict)["logits"]  # batch_size, seq_len - 1, vocab_size
+
+         loss = F.cross_entropy(
+             logits.reshape(-1, logits.size(-1)),  # batch_size * (seq_len - 1), vocab_size
+             target.reshape(-1),  # batch_size * (seq_len - 1)
+             reduction="mean",
+         )
+         loss_val = loss.item()
+         loss.backward()
+         return loss_val
+
+     def check_grads_equal(model1, model2, label):
+         for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
+             if p1.grad is not None and p2.grad is not None:
+                 # we cannot check the FQNs as AC renames the parameters.
+                 # instead, we check for weight and gradient equivalence
+                 torch.testing.assert_close(p1, p2, rtol=1e-5, atol=1e-7, msg=f"Parameter mismatch in {n1} ({label})")
+                 torch.testing.assert_close(
+                     p1.grad, p2.grad, rtol=1e-5, atol=1e-7, msg=f"Gradient mismatch in {n1} ({label})"
+                 )
+
+     batch_size = 2
+     seq_len = 256
+     vocab_size = 50304
+
+     # build the models with different activation checkpointing variants but equivalent weights
+     config_file_path = working_dir / relative_config_path
+     main = Main(config_file_path, experiment_id="-1")
+
+     torch.manual_seed(42)
+     model_raw = main.build_components(components_model_type=RawModel).model_raw.to("cuda")
+
+     torch.manual_seed(42)
+     model_fac = main.build_components(
+         components_model_type=FullActivationCheckpointingInstantiationModel
+     ).full_activation_checkpointed_model.to("cuda")
+
+     torch.manual_seed(42)
+     model_sel_layer = main.build_components(
+         components_model_type=SelectiveLayerActivationCheckpointingInstantiationModel
+     ).selective_layer_activation_checkpointed_model.to("cuda")
+
+     torch.manual_seed(42)
+     model_sel_op = main.build_components(
+         components_model_type=SelectiveOpActivationCheckpointingInstantiationModel
+     ).selective_op_activation_checkpointed_model.to("cuda")
+
+     # Ensure all models have a different reference
+     models = [model_raw, model_fac, model_sel_layer, model_sel_op]
+     assert len(set(id(m) for m in models)) == len(models)
+
+     # Dummy LLM token input
+     # we use a sequence length of seq_len + 1, as the last token is only used for the loss calculation
+     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len + 1), device="cuda")
+
+     # Run forward+backward
+     loss_raw = forward_and_backward(model_raw, input_ids)
+     loss_fac = forward_and_backward(model_fac, input_ids)
+     loss_sel_layer = forward_and_backward(model_sel_layer, input_ids)
+     loss_sel_op = forward_and_backward(model_sel_op, input_ids)
+
+     # Compare losses
+     torch.testing.assert_close(torch.tensor(loss_fac), torch.tensor(loss_raw), msg="FAC loss mismatch")
+     torch.testing.assert_close(torch.tensor(loss_sel_layer), torch.tensor(loss_raw), msg="Sel layer AC loss mismatch")
+     torch.testing.assert_close(torch.tensor(loss_sel_op), torch.tensor(loss_raw), msg="Sel op AC loss mismatch")
+
+     # Compare gradients
+     check_grads_equal(model_raw, model_fac, "fac")
+     check_grads_equal(model_raw, model_sel_layer, "sel_layer")
+     check_grads_equal(model_raw, model_sel_op, "sel_op")