Commit d577a46

Add torchft to CI
1 parent 27e3ad8 commit d577a46

4 files changed, +206 -0 lines changed

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
name: TorchFT 8 GPU Integration Test

on:
  push:
    branches: [ main ]
    paths:
      - 'torchtitan/components/ft.py'
  pull_request:
    paths:
      - 'torchtitan/components/ft.py'
  schedule:
    # Runs every 6 hours
    - cron: '0 */6 * * *'

concurrency:
  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -l -eo pipefail {0}

jobs:
  build-test:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      runner: linux.g5.48xlarge.nvidia.gpu
      gpu-arch-type: cuda
      gpu-arch-version: "12.6"
      # This image is faster to clone than the default, but it lacks CC needed by triton
      # (1m25s vs 2m37s).
      docker-image: torchtitan-ubuntu-20.04-clang12
      repository: pytorch/torchtitan
      upload-artifact: outputs
      script: |
        set -eux

        # The generic Linux job chooses to use base env, not the one setup by the image
        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
        conda activate "${CONDA_ENV}"

        pip config --user set global.progress_bar off

        python -m pip install torchft-nightly

        mkdir artifacts-to-be-uploaded
        echo "torchft_lighthouse"
        RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 &
        echo "ft_integration_test"
        python ./tests/integration_tests_ft.py artifacts-to-be-uploaded --ngpu 8
        killall -9 torchft_lighthouse
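The same sequence can be reproduced outside CI for local debugging. A minimal sketch in Python, assuming torchft-nightly is installed, 8 GPUs are available, and the repo root is the working directory; this helper is illustrative and not part of the commit:

# Hypothetical local reproduction of the CI script above.
import subprocess

# Start the lighthouse in the background, as the workflow does.
lighthouse = subprocess.Popen(
    [
        "torchft_lighthouse",
        "--min_replicas", "1",
        "--quorum_tick_ms", "100",
        "--join_timeout_ms", "10000",
    ],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)
try:
    # Run the integration tests against the local lighthouse.
    subprocess.run(
        ["python", "./tests/integration_tests_ft.py",
         "artifacts-to-be-uploaded", "--ngpu", "8"],
        check=True,
    )
finally:
    lighthouse.kill()  # same cleanup as `killall -9 torchft_lighthouse`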

tests/integration_tests_ft.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import concurrent.futures
import logging
import os
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


@dataclass
class OverrideDefinitions:
    """
    This class is used to define the override definitions for the integration tests.
    """

    override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
    test_descr: str = "default"
    test_name: str = "default"
    ngpu: int = 4
    model_flavor: str = "debugmodel"

    def __repr__(self):
        return self.test_descr


def build_test_list():
    """
    key is the config file name and value is a list of OverrideDefinitions
    that is used to generate variations of integration tests based on the
    same root config file.
    """
    integration_tests_flavors = defaultdict(list)
    integration_tests_flavors["debug_model.toml"] = [
        OverrideDefinitions(
            [
                ["--training.steps 10", "--checkpoint.enable_checkpoint"],
            ],
            "Default TorchFT integration test",
            "default_torchft",
        )
    ]
    return integration_tests_flavors
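Additional variants can be registered by appending more OverrideDefinitions entries under the same config key. A hypothetical sketch; the flags reuse options from the default test, but this variant is illustrative and not part of the commit:

integration_tests_flavors["debug_model.toml"].append(
    OverrideDefinitions(
        [
            # Two sequential invocations: the second run would pick up the
            # checkpoint written by the first.
            ["--training.steps 10", "--checkpoint.enable_checkpoint"],
            ["--training.steps 20", "--checkpoint.enable_checkpoint"],
        ],
        "TorchFT resume-from-checkpoint test (illustrative)",
        "resume_torchft",
    )
)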
def _run_cmd(cmd):
    return subprocess.run([cmd], text=True, shell=True)


def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
    # run_test supports sequence of tests.
    test_name = test_flavor.test_name
    dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
    model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"

    all_ranks = [",".join(map(str, range(0, 4))), ",".join(map(str, range(4, 8)))]

    for idx, override_arg in enumerate(test_flavor.override_args):
        cmds = []
        for replica_id, ranks in enumerate(all_ranks):
            cmd = (
                f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" '
                + f"NGPU=4 CUDA_VISIBLE_DEVICES={ranks} "
                + f"CONFIG_FILE={full_path} NGPU=4 ./run_train.sh "
                + "--fault_tolerance.enable "
                + f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size=2"
            )

            cmd += " " + dump_folder_arg
            cmd += " " + model_flavor_arg
            if override_arg:
                cmd += " " + " ".join(override_arg)

            logger.info(
                "=====TorchFT Integration test, flavor : "
                f"{test_flavor.test_descr}, command : {cmd}====="
            )
            cmds.append(cmd)

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(_run_cmd, cmd) for cmd in cmds]
            results = [future.result() for future in futures]

        for result in results:
            logger.info(result.stdout)

            if result.returncode == 0:
                continue

            raise Exception(
                f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
            )
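Each entry in override_args launches two 4-GPU replica groups in parallel; the two generated commands differ only in CUDA_VISIBLE_DEVICES and --fault_tolerance.replica_id. A quick illustration of the rank split:

all_ranks = [",".join(map(str, range(0, 4))), ",".join(map(str, range(4, 8)))]
assert all_ranks == ["0,1,2,3", "4,5,6,7"]
# replica_id 0 sees GPUs 0-3, replica_id 1 sees GPUs 4-7; both join the same
# lighthouse (started by the workflow) to form a fault-tolerant group of size 2.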
def run_tests(args):
    integration_tests_flavors = build_test_list()

    if args.ngpu < 8:
        logger.info("Skipping TorchFT integration tests as we need 8 GPUs.")
        return

    for config_file in os.listdir(args.config_dir):
        if not config_file.endswith(".toml"):
            continue

        full_path = os.path.join(args.config_dir, config_file)
        with open(full_path, "rb") as f:
            config = tomllib.load(f)
            is_integration_test = config["job"].get("use_for_integration_test", False)
            if not is_integration_test:
                continue

            for test_flavor in integration_tests_flavors[config_file]:
                if not (args.test == "all" or test_flavor.test_name == args.test):
                    continue

                run_test(test_flavor, full_path, args.output_dir)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir")
    parser.add_argument(
        "--config_dir", default="./torchtitan/models/llama3/train_configs"
    )
    parser.add_argument(
        "--test",
        default="all",
        help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
    )
    parser.add_argument("--ngpu", default=8, type=int)
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    run_tests(args)


if __name__ == "__main__":
    main()

torchtitan/components/checkpoint.py

Lines changed: 1 addition & 0 deletions
@@ -430,6 +430,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
         None
         """

+        # TODO: we are always saving the checkpoint when ft is on? even if enable_checkpoint is off?
         if self.ft_manager:
             self._ft_save(curr_step)
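The TODO notes that _ft_save currently runs whenever ft_manager is set, regardless of the checkpoint toggle. One possible resolution, assuming the manager exposes an enable_checkpoint flag mirroring the config option (hypothetical, not part of this commit):

# Hypothetical guard; `self.enable_checkpoint` is assumed to mirror the
# checkpoint.enable_checkpoint config option.
if self.ft_manager and self.enable_checkpoint:
    self._ft_save(curr_step)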

torchtitan/components/ft.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# TODO: test that changes in this file trigger CI
 import copy
 import importlib
 from contextlib import nullcontext
