Skip to content

Commit 34d815c

Browse files
authored
[refactor] split JobConfig and ConfigManager into two files (#1442)
This PR creates a new folder `torchtitan/config` to host `job_config.py` and `manager.py`, for the reasons below: - Both are complicated enough to worth their own files. - The convention in torchtitan to extend custom `JobConfig` is to create a file under model folder called `job_config.py` (see https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/job_config.py). This PR makes the origin `JobConfig` consistent with that convention. - (minor) Creating a more succinct `torchtitan.config` namespace is more readable than importing from `torchtitan.config_manager`.
1 parent 171a883 commit 34d815c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+374
-355
lines changed

.github/workflows/integration_test_8gpu.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ jobs:
4646
USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
4747
4848
mkdir artifacts-to-be-uploaded
49-
python ./tests/integration_tests.py artifacts-to-be-uploaded --ngpu 8
49+
python -m tests.integration_tests artifacts-to-be-uploaded --ngpu 8

.github/workflows/integration_test_8gpu_h100.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,4 @@ jobs:
4747
USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
4848
4949
mkdir artifacts-to-be-uploaded
50-
python ./tests/integration_tests_h100.py artifacts-to-be-uploaded --ngpu 8
50+
python -m tests.integration_tests_h100 artifacts-to-be-uploaded --ngpu 8

.github/workflows/integration_test_8gpu_torchft.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ jobs:
4949
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 &
5050
echo "ft_integration_test"
5151
# Getting error - Cuda failure 217 'peer access is not supported between these two devices'
52-
python ./tests/integration_tests_ft.py artifacts-to-be-uploaded --ngpu 8
52+
python -m tests.integration_tests_ft artifacts-to-be-uploaded --ngpu 8
5353
# pkill -9 torchft_lighthouse

scripts/estimate/estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from torchtitan.components.lr_scheduler import build_lr_schedulers
1919
from torchtitan.components.optimizer import build_optimizers
20-
from torchtitan.config_manager import ConfigManager, JobConfig
20+
from torchtitan.config import ConfigManager, JobConfig
2121
from torchtitan.distributed import ParallelDims, utils as dist_utils
2222
from torchtitan.protocols.model_converter import build_model_converters
2323
from torchtitan.protocols.train_spec import get_train_spec

scripts/generate/test_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from torchtitan.components.checkpoint import excluded_parameters_for_model_only
2828
from torchtitan.components.metrics import build_device_memory_monitor
29-
from torchtitan.config_manager import ConfigManager
29+
from torchtitan.config import ConfigManager
3030
from torchtitan.distributed import ParallelDims, utils as dist_utils
3131
from torchtitan.protocols.train_spec import get_train_spec
3232
from torchtitan.tools import utils

tests/integration_tests_ft.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import os
1111
import subprocess
1212
from collections import defaultdict
13-
from dataclasses import dataclass
14-
from typing import Sequence
13+
14+
from tests.integration_tests import OverrideDefinitions
1515

1616
logging.basicConfig(level=logging.INFO)
1717
logger = logging.getLogger(__name__)
@@ -22,22 +22,6 @@
2222
import tomli as tomllib
2323

2424

25-
@dataclass
26-
class OverrideDefinitions:
27-
"""
28-
This class is used to define the override definitions for the integration tests.
29-
"""
30-
31-
override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
32-
test_descr: str = "default"
33-
test_name: str = "default"
34-
ngpu: int = 4
35-
model_flavor: str = "debugmodel"
36-
37-
def __repr__(self):
38-
return self.test_descr
39-
40-
4125
def build_test_list():
4226
"""
4327
key is the config file name and value is a list of OverrideDefinitions
@@ -52,6 +36,7 @@ def build_test_list():
5236
],
5337
"Default TorchFT integration test",
5438
"default_torchft",
39+
ngpu=8,
5540
)
5641
]
5742
return integration_tests_flavors
@@ -65,7 +50,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
6550
# run_test supports sequence of tests.
6651
test_name = test_flavor.test_name
6752
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
68-
model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
6953

7054
# Use all 8 GPUs in a single replica
7155
# TODO: Use two replica groups
@@ -79,14 +63,13 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
7963
for replica_id, ranks in enumerate(all_ranks):
8064
cmd = (
8165
f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" '
82-
+ f"CUDA_VISIBLE_DEVICES={ranks}"
83-
+ f"CONFIG_FILE={full_path} NGPU={len(ranks)} ./run_train.sh "
66+
+ f"CUDA_VISIBLE_DEVICES={ranks} "
67+
+ f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} ./run_train.sh "
8468
+ "--fault_tolerance.enable "
85-
+ f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size={len(all_ranks)}"
69+
+ f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size={test_flavor.ngpu}"
8670
)
8771

8872
cmd += " " + dump_folder_arg
89-
cmd += " " + model_flavor_arg
9073
if override_arg:
9174
cmd += " " + " ".join(override_arg)
9275

tests/integration_tests_h100.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import os
1010
import subprocess
1111
from collections import defaultdict
12-
from dataclasses import dataclass
13-
from typing import Sequence
12+
13+
from .integration_tests import OverrideDefinitions
1414

1515
logging.basicConfig(level=logging.INFO)
1616
logger = logging.getLogger(__name__)
@@ -21,21 +21,6 @@
2121
import tomli as tomllib
2222

2323

24-
@dataclass
25-
class OverrideDefinitions:
26-
"""
27-
This class is used to define the override definitions for the integration tests.
28-
"""
29-
30-
override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
31-
test_descr: str = "default"
32-
test_name: str = "default"
33-
ngpu: int = 4
34-
35-
def __repr__(self):
36-
return self.test_descr
37-
38-
3924
def build_test_list():
4025
"""
4126
key is the config file name and value is a list of OverrideDefinitions

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import torch.nn as nn
1111
from torch.utils.flop_counter import FlopCounterMode
1212

13-
from torchtitan.config_manager import ActivationCheckpoint as ACConfig
13+
from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
1414
from torchtitan.models.llama3.infra.parallelize import apply_ac
1515

1616

17-
class TestModule(nn.Module):
17+
class ToyModule(nn.Module):
1818
def __init__(self):
1919
super().__init__()
2020
self.layers = nn.ModuleDict({"0": TransformerBlock()})
@@ -56,12 +56,12 @@ def get_bw_flops(model_fn):
5656
return mode.get_total_flops() / (512**3 * 2)
5757

5858
# 1. No AC
59-
model_no_ac = TestModule()
59+
model_no_ac = ToyModule()
6060
flops_no_ac = get_bw_flops(model_no_ac)
6161

6262
# 2. SAC
6363
# Per-op SAC's policy is to save every other mm
64-
model_selective_ac = TestModule()
64+
model_selective_ac = ToyModule()
6565
ac_config_no_force = ACConfig(
6666
mode="selective",
6767
selective_ac_option="op",
@@ -72,7 +72,7 @@ def get_bw_flops(model_fn):
7272

7373
# 3. Per-op SAC with force recompute "moe.router.gate"
7474
# This leads to two mms being recomputed since they share the same shape!
75-
model_with_force_first = TestModule()
75+
model_with_force_first = ToyModule()
7676
ac_config_with_force_first = ACConfig(
7777
mode="selective",
7878
selective_ac_option="op",
@@ -82,7 +82,7 @@ def get_bw_flops(model_fn):
8282
flops_with_force_first = get_bw_flops(model_with_force_first)
8383

8484
# 4. Per-op SAC with force recompute "output"
85-
model_with_force_last = TestModule()
85+
model_with_force_last = ToyModule()
8686
ac_config_with_force_last = ACConfig(
8787
mode="selective",
8888
selective_ac_option="op",
@@ -92,7 +92,7 @@ def get_bw_flops(model_fn):
9292
flops_with_force_last = get_bw_flops(model_with_force_last)
9393

9494
# 5. Full AC
95-
model_with_full_ac = TestModule()
95+
model_with_full_ac = ToyModule()
9696
ac_config_full_ac = ACConfig(
9797
mode="full",
9898
)
@@ -122,12 +122,12 @@ def get_act_mem(model_fn):
122122
return act_mem
123123

124124
# 1. No AC
125-
model_no_ac = TestModule().cuda()
125+
model_no_ac = ToyModule().cuda()
126126
mem_no_ac = get_act_mem(model_no_ac)
127127

128128
# 2. SAC
129129
# Per-op SAC's policy is to save every other mm
130-
model_selective_ac = TestModule().cuda()
130+
model_selective_ac = ToyModule().cuda()
131131
ac_config_no_force = ACConfig(
132132
mode="selective",
133133
selective_ac_option="op",
@@ -138,7 +138,7 @@ def get_act_mem(model_fn):
138138

139139
# 3. Per-op SAC with force recompute "moe.router.gate"
140140
# This leads to two mms being recomputed since they share the same shape!
141-
model_with_force_first = TestModule().cuda()
141+
model_with_force_first = ToyModule().cuda()
142142
ac_config_with_force_first = ACConfig(
143143
mode="selective",
144144
selective_ac_option="op",
@@ -148,7 +148,7 @@ def get_act_mem(model_fn):
148148
mem_with_force_first = get_act_mem(model_with_force_first)
149149

150150
# 4. Per-op SAC with force recompute "output"
151-
model_with_force_last = TestModule().cuda()
151+
model_with_force_last = ToyModule().cuda()
152152
ac_config_with_force_last = ACConfig(
153153
mode="selective",
154154
selective_ac_option="op",
@@ -158,7 +158,7 @@ def get_act_mem(model_fn):
158158
mem_with_force_last = get_act_mem(model_with_force_last)
159159

160160
# 5. Full AC
161-
model_with_full_ac = TestModule().cuda()
161+
model_with_full_ac = ToyModule().cuda()
162162
ac_config_full_ac = ACConfig(
163163
mode="full",
164164
)
@@ -175,9 +175,9 @@ def get_act_mem(model_fn):
175175
# the size of the other two mms.
176176

177177
def test_correctness(self):
178-
model_no_ac = TestModule()
178+
model_no_ac = ToyModule()
179179

180-
model_selective_ac = TestModule()
180+
model_selective_ac = ToyModule()
181181
model_selective_ac.load_state_dict(model_no_ac.state_dict())
182182
apply_ac(
183183
model_selective_ac,
@@ -187,7 +187,7 @@ def test_correctness(self):
187187
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
188188
),
189189
)
190-
model_force_first = TestModule()
190+
model_force_first = ToyModule()
191191
model_force_first.load_state_dict(model_no_ac.state_dict())
192192
apply_ac(
193193
model_force_first,
@@ -198,7 +198,7 @@ def test_correctness(self):
198198
),
199199
)
200200

201-
model_force_last = TestModule()
201+
model_force_last = ToyModule()
202202
model_force_last.load_state_dict(model_no_ac.state_dict())
203203
apply_ac(
204204
model_force_last,

tests/unit_tests/test_checkpoint.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch.nn as nn
1717
from torch.utils.data import DataLoader
1818
from torchtitan.components.checkpoint import CheckpointManager
19-
from torchtitan.config_manager import Checkpoint as CheckpointConfig
19+
from torchtitan.config.job_config import Checkpoint as CheckpointConfig
2020

2121

2222
class FakeOptimizersContainer:
@@ -176,6 +176,7 @@ def test_save_load_restores_state(self, mock_load, mock_save, mock_rank):
176176
lr_schedulers=self.lr_schedulers,
177177
states=self.states,
178178
checkpoint_config=self.job_config.checkpoint,
179+
sd_adapter=None,
179180
base_folder=self.job_config.job.dump_folder,
180181
ft_manager=self.ft_manager,
181182
)
@@ -209,6 +210,7 @@ def test_save_and_purge_keeps_last_k_checkpoints(
209210
lr_schedulers=self.lr_schedulers,
210211
states=self.states,
211212
checkpoint_config=self.job_config.checkpoint,
213+
sd_adapter=None,
212214
base_folder=self.job_config.job.dump_folder,
213215
ft_manager=self.ft_manager,
214216
)
@@ -250,6 +252,7 @@ def test_nonzero_rank_does_not_purge_or_save(self, mock_load, mock_save, mock_ra
250252
lr_schedulers=self.lr_schedulers,
251253
states=self.states,
252254
checkpoint_config=self.job_config.checkpoint,
255+
sd_adapter=None,
253256
base_folder=self.job_config.job.dump_folder,
254257
ft_manager=self.ft_manager,
255258
)
@@ -273,6 +276,7 @@ def test_load_returns_false_when_no_checkpoint_folder(self):
273276
lr_schedulers=self.lr_schedulers,
274277
states=self.states,
275278
checkpoint_config=self.job_config.checkpoint,
279+
sd_adapter=None,
276280
base_folder=self.job_config.job.dump_folder,
277281
ft_manager=self.ft_manager,
278282
)
@@ -297,6 +301,7 @@ def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank):
297301
lr_schedulers=self.lr_schedulers,
298302
states=self.states,
299303
checkpoint_config=self.job_config.checkpoint,
304+
sd_adapter=None,
300305
base_folder=self.job_config.job.dump_folder,
301306
ft_manager=self.ft_manager,
302307
)
@@ -327,6 +332,7 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
327332
lr_schedulers=self.lr_schedulers,
328333
states=self.states,
329334
checkpoint_config=self.job_config.checkpoint,
335+
sd_adapter=None,
330336
base_folder=self.job_config.job.dump_folder,
331337
ft_manager=self.ft_manager,
332338
)
@@ -361,6 +367,7 @@ def test_last_save_model_only_and_initial_load_model_only(
361367
lr_schedulers=self.lr_schedulers,
362368
states=self.states,
363369
checkpoint_config=self.job_config.checkpoint,
370+
sd_adapter=None,
364371
base_folder=self.job_config.job.dump_folder,
365372
ft_manager=self.ft_manager,
366373
)
@@ -381,6 +388,7 @@ def test_last_save_model_only_and_initial_load_model_only(
381388
lr_schedulers=self.lr_schedulers,
382389
states=self.states,
383390
checkpoint_config=self.job_config.checkpoint,
391+
sd_adapter=None,
384392
base_folder=self.job_config.job.dump_folder,
385393
ft_manager=self.ft_manager,
386394
)
@@ -423,6 +431,7 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
423431
lr_schedulers=self.lr_schedulers,
424432
states=states,
425433
checkpoint_config=checkpoint_config,
434+
sd_adapter=None,
426435
base_folder=self.job_config.job.dump_folder,
427436
ft_manager=self.ft_manager,
428437
)
@@ -468,6 +477,7 @@ def test_ft_async_save_calls_async_wait(
468477
lr_schedulers=self.lr_schedulers,
469478
states=self.states,
470479
checkpoint_config=checkpoint_config,
480+
sd_adapter=None,
471481
base_folder=self.job_config.job.dump_folder,
472482
ft_manager=self.ft_manager,
473483
)
@@ -504,6 +514,7 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
504514
lr_schedulers=self.lr_schedulers,
505515
states=self.states,
506516
checkpoint_config=self.job_config.checkpoint,
517+
sd_adapter=None,
507518
base_folder=self.job_config.job.dump_folder,
508519
ft_manager=self.ft_manager,
509520
)
@@ -530,6 +541,7 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
530541
lr_schedulers=self.lr_schedulers,
531542
states=self.states,
532543
checkpoint_config=self.job_config.checkpoint,
544+
sd_adapter=None,
533545
base_folder=self.job_config.job.dump_folder,
534546
ft_manager=self.ft_manager,
535547
)
@@ -576,6 +588,7 @@ def __init__(self):
576588
lr_schedulers=self.lr_schedulers,
577589
states=self.states,
578590
checkpoint_config=self.job_config.checkpoint,
591+
sd_adapter=None,
579592
base_folder=self.job_config.job.dump_folder,
580593
ft_manager=self.ft_manager,
581594
)
@@ -626,6 +639,7 @@ def fake_load(state_dict: dict, checkpoint_id=None):
626639
lr_schedulers=self.lr_schedulers,
627640
states=self.states,
628641
checkpoint_config=self.job_config.checkpoint,
642+
sd_adapter=None,
629643
base_folder=self.job_config.job.dump_folder,
630644
ft_manager=self.ft_manager,
631645
)

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from datasets import load_dataset
1111
from torchtitan.components.tokenizer import HuggingFaceTokenizer
12-
from torchtitan.config_manager import ConfigManager
12+
from torchtitan.config import ConfigManager
1313
from torchtitan.datasets.hf_datasets import build_hf_dataloader, DatasetConfig, DATASETS
1414

1515

0 commit comments

Comments
 (0)