Commit a7bb57c

add manual bucketing pass

1 parent 8659543 · commit a7bb57c

6 files changed: +90 -27 lines

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 9 additions & 7 deletions
```diff
@@ -51,13 +51,15 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
 
 2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
    - "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.
-
-
-   users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
-
-   ```bash
-   --compile.model_backend_override "aot_eager_autobucketing"
-   ```
+   ```bash
+   --compile.backend "aot_eager" --compile.model_backend_override "aot_eager_autobucketing"
+   ```
+
+3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
+   - "aot_eager_manualbucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend.
+   ```bash
+   --compile.backend "aot_eager" --compile.model_backend_override "aot_eager_manualbucketing"
+   ```
 
 ### Citation
 
```
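
A note for readers skimming the diff: the "user FQN inputs" mentioned in the manual-optimization item are nested lists of module fully-qualified names, where each top-level entry (a single FQN or a sub-list of FQNs) is treated as one bucket. A hypothetical plan for a small Llama-style model, mirroring the one built in `llama3/parallelize.py` later in this commit, might look like:

```python
# Hypothetical illustration only -- not part of this commit.
# Each top-level entry is one bucket: the token embedding alone, the final
# norm + output projection together, and one bucket per transformer block.
bucket_plan: list[list[str] | str] = [
    "tok_embeddings",
    ["norm", "output"],
    "layers.0",
    "layers.1",
    "layers.2",
]
```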

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 37 additions & 1 deletion
```diff
@@ -8,11 +8,20 @@
 
 import torch
 
+from .job_config import Compile as CompileConfig
 
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+
+def get_compile_backend(
+    compile_config: CompileConfig, bucket_module_name: list[list[str] | str]
+) -> Union[str, callable]:
     # return the compile backends used in SimpleFSDP training
     # Step1: check if backend_name is inside available torch.compile backends
     # Step2: check if the backend_name has been registered as a customized backend
+    backend_name = (
+        getattr(compile_config, "model_backend_override", None)
+        or compile_config.backend
+    )
+
     available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
     if backend_name in available_torch_backend:
         return backend_name
@@ -43,6 +52,33 @@ def aten_autobucketing_reordering_pass(
             bw_compiler=aten_autobucketing_reordering_pass,
             keep_inference_input_mutations=True,
         )
+    elif backend_name == "aot_eager_blockbucketing":
+        # Perform manual optimization in aten fx-level and execute code in aot_eager backend
+        # The manualbucketing logic is here:
+        from functools import partial
+
+        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+        from torch._inductor.fx_passes.overlap_manual_scheduling import (
+            manual_overlap_bucketing,
+        )
+
+        torch._inductor.config.allow_buffer_reuse = False
+        manual_overlap_bucketing = partial(
+            manual_overlap_bucketing,
+            module_bucket_plans=bucket_module_name,
+        )
+
+        def aten_manualbucketing_reordering_pass(
+            gm: torch.fx.GraphModule, example_inputs: Any
+        ) -> torch.fx.GraphModule:
+            manual_overlap_bucketing(gm)
+            return gm
+
+        backend = aot_autograd_backend(
+            fw_compiler=aten_manualbucketing_reordering_pass,
+            bw_compiler=aten_manualbucketing_reordering_pass,
+            keep_inference_input_mutations=True,
+        )
     else:
         raise AssertionError(f"Unsupported customized backend: {backend_name}")
 
```
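
Both bucketing backends in this file share the same shape: an aten fx-graph pass wrapped with `aot_autograd` and handed to `torch.compile` as a custom backend. The following is a minimal, self-contained sketch of that pattern with a no-op pass, as an illustration on stock PyTorch rather than the commit's bucketing logic:

```python
# Minimal sketch of the aot_autograd custom-backend pattern used above.
# The pass is a no-op here; the real backends bucket and reorder collectives.
from typing import Any

import torch
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend


def noop_reordering_pass(
    gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
    # A real pass would rewrite gm.graph before returning it.
    return gm


toy_backend = aot_autograd_backend(
    fw_compiler=noop_reordering_pass,
    bw_compiler=noop_reordering_pass,
    keep_inference_input_mutations=True,
)

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=toy_backend, fullgraph=True)
print(compiled(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```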

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -178,6 +178,8 @@ def parallelize_deepseekv3(
     if job_config.compile.enable:
         torch._inductor.config.reorder_for_peak_memory = False
         torch._dynamo.config.capture_scalar_outputs = True
-        model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
+        model = torch.compile(
+            model, backend=get_compile_backend(job_config.compile), fullgraph=True
+        )
 
     return model
```
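
Call sites like this one now pass the whole `Compile` config and let `get_compile_backend` resolve the backend name. Below is a stand-in sketch of that resolution order, mirroring the expression added in `backend.py` above; the `Compile` dataclass in the sketch is a simplified stand-in, not torchtitan's:

```python
# Stand-in sketch: the experiment-specific override wins, otherwise the stock
# --compile.backend value is used (same expression as in backend.py above).
from dataclasses import dataclass


@dataclass
class Compile:
    backend: str = "aot_eager"
    model_backend_override: str | None = None


def resolve_backend_name(compile_config: Compile) -> str:
    return (
        getattr(compile_config, "model_backend_override", None)
        or compile_config.backend
    )


print(resolve_backend_name(Compile()))  # aot_eager
print(resolve_backend_name(Compile(model_backend_override="aot_eager_blockbucketing")))  # aot_eager_blockbucketing
```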

torchtitan/experiments/simple_fsdp/job_config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@
 @dataclass
 class Compile:
     model_backend_override: str | None = None
-    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
+    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """
 
 
 @dataclass
```

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 27 additions & 5 deletions
```diff
@@ -33,6 +33,29 @@
 }
 
 
+def get_bucket_module_name(model) -> list[list[str] | str]:
+    module_list = [
+        model.tok_embeddings,
+        [model.norm, model.output],
+    ]
+    for layer_id, transformer_block in model.layers.items():
+        module_list.append(transformer_block)
+
+    def convert_to_fqn_list(modules, mapping):
+        """Convert a (possibly nested) list of modules to FQN strings."""
+        result = []
+        for item in modules:
+            if isinstance(item, list):
+                result.append(convert_to_fqn_list(item, mapping))
+            else:
+                result.append(mapping.get(item, None))
+        return result
+
+    module_to_name = {m: n for n, m in model.named_modules()}
+    module_fqns = convert_to_fqn_list(module_list, module_to_name)
+    return module_fqns
+
+
 def parallelize_llama(
     model: nn.Module,
     parallel_dims: ParallelDims,
@@ -140,13 +163,12 @@ def parallelize_llama(
 
     if job_config.compile.enable and "model" in job_config.compile.components:
         torch._inductor.config.reorder_for_peak_memory = False
-        backend = (
-            getattr(job_config.compile, "model_backend_override", None)
-            or job_config.compile.backend
-        )
+
         model = torch.compile(
             model,
-            backend=get_compile_backend(backend),
+            backend=get_compile_backend(
+                job_config.compile, get_bucket_module_name(model)
+            ),
             fullgraph=True,
         )
 
```
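To make the return value of `get_bucket_module_name` concrete, here is a toy illustration of the module-to-FQN conversion via `named_modules()`; `ToyLlama` is a stand-in with the attribute names the function expects, not torchtitan's model class:

```python
# Toy illustration of the module -> FQN conversion behind get_bucket_module_name.
import torch.nn as nn


class ToyLlama(nn.Module):
    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.tok_embeddings = nn.Embedding(16, 8)
        self.layers = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(n_layers)})
        self.norm = nn.LayerNorm(8)
        self.output = nn.Linear(8, 16)


model = ToyLlama()
module_to_name = {m: n for n, m in model.named_modules()}

# Same plan shape as the diff builds: embedding alone, norm + output together,
# then one entry per transformer block.
plan = [
    module_to_name[model.tok_embeddings],
    [module_to_name[model.norm], module_to_name[model.output]],
] + [module_to_name[block] for _, block in model.layers.items()]

print(plan)  # ['tok_embeddings', ['norm', 'output'], 'layers.0', 'layers.1']
```
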
torchtitan/experiments/simple_fsdp/tests/integration_tests.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
             "1D",
             "1d",
         ),
-        OverrideDefinitions(
-            [
-                [
-                    "--model.name simple_fsdp.llama3",
-                    "--compile.enable",
-                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
-                    "--compile.model_backend_override aot_eager_autobucketing",
-                ],
-            ],
-            "1D+aot_eager_autobucketing",
-            "1d_aot_eager_autobucketing",
-        ),
+        # TODO(ruisizhang123): add back after autobucketing pass is mature
+        # OverrideDefinitions(
+        #     [
+        #         [
+        #             "--model.name simple_fsdp.llama3",
+        #             "--compile.enable",
+        #             "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
+        #             "--compile.model_backend_override aot_eager_autobucketing",
+        #         ],
+        #     ],
+        #     "1D+aot_eager_autobucketing",
+        #     "1d_aot_eager_autobucketing",
+        # ),
         OverrideDefinitions(
             [
                 [
```
