Skip to content

Commit f1a96b3

Browse files
[Compiler Toolkit] Apply autobucketing_reordering_pass (#1951)
Apply autobucketing_reordering_pass
1 parent f1b3c9f commit f1a96b3

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
1010
from torch._guards import tracing
11+
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
1112

1213
from torch.distributed.tensor import DTensor
1314
from torch.fx.passes.regional_inductor import regional_inductor
@@ -31,6 +32,36 @@
3132
from torchtitan.tools.logging import logger
3233

3334

35+
# TODO: support passing configs into schedule_overlap_bucketing
36+
def autobucketing_reordering_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Run the overlap-bucketing scheduler on ``gm`` and return the same module.

    ``schedule_overlap_bucketing`` is invoked with collective bucketing turned
    on. Its return value is discarded, so this pass relies on the graph being
    rewritten in place (presumably that is what the scheduler does -- confirm
    against its implementation).

    Args:
        gm: FX graph module to schedule.

    Returns:
        The ``gm`` object that was passed in, after ``recompile()``.
    """
    # NOTE(review): the second keyword has the same name as the function being
    # called; confirm it is a real parameter of schedule_overlap_bucketing.
    schedule_overlap_bucketing(
        gm,
        collective_bucketing=True,
        schedule_overlap_bucketing=False,
    )

    # Regenerate the module's Python code from the (possibly rewritten) graph.
    gm.recompile()
    return gm
42+
43+
44+
def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):
    """Apply the bucketing pass and regional inductor to ``gm``, logging around them.

    Args:
        name: Label prefixed to the "before"/"after" log lines.
        gm: FX graph module to compile.
        example_inputs: Forwarded unchanged to ``regional_inductor``.

    Returns:
        Whatever ``regional_inductor`` returns for the bucketed graph.
    """
    # Dump the incoming graph so it can be compared with the compiled one.
    logger.info(f"{name} before compiler:")
    logger.info(gm.print_readable(print_output=False))

    # Bucketing/reordering runs first; regional inductor compiles its output.
    bucketed_gm = autobucketing_reordering_pass(gm)
    compiled_gm = regional_inductor(bucketed_gm, example_inputs)

    logger.info(f"{name} after compiler:")
    logger.info(compiled_gm.print_readable(print_output=False))
    return compiled_gm
55+
56+
57+
def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """Compile the forward graph, labelling its log output ``"fwd_gm"``.

    The return annotation was previously ``None`` although the compiled module
    is returned; it now reflects the value handed back to the caller.

    Args:
        gm: Forward FX graph module.
        example_inputs: Forwarded unchanged to ``compiler``.

    Returns:
        The result of ``compiler("fwd_gm", gm, example_inputs)``.
    """
    return compiler("fwd_gm", gm, example_inputs)
59+
60+
61+
def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """Compile the backward graph, labelling its log output ``"bwd_gm"``.

    The return annotation was previously ``None`` although the compiled module
    is returned; it now reflects the value handed back to the caller.

    Args:
        gm: Backward FX graph module.
        example_inputs: Forwarded unchanged to ``compiler``.

    Returns:
        The result of ``compiler("bwd_gm", gm, example_inputs)``.
    """
    return compiler("bwd_gm", gm, example_inputs)
63+
64+
3465
def joint_graph_builder(model, *args, **kwargs):
3566
assert isinstance(model, SimpleFSDPTransformer)
3667
assert isinstance(args, tuple)
@@ -51,21 +82,9 @@ def joint_graph_builder(model, *args, **kwargs):
5182
}:
5283
assert "compile_with_inductor" in node.meta.get("custom", {})
5384

54-
def compiler(gm: torch.fx.GraphModule, example_inputs):
55-
logger.info("Before compiler:")
56-
logger.info(gm.print_readable(print_output=False))
57-
58-
# gm = schedule_overlap_bucketing(gm)
59-
60-
gm = regional_inductor(gm, example_inputs)
61-
62-
logger.info("After compiler:")
63-
logger.info(gm.print_readable(print_output=False))
64-
return gm
65-
6685
with tracing(tracing_context):
6786
fn = aot_compile_joint_with_descriptors(
68-
joint_with_descriptors, fw_compiler=compiler, bw_compiler=compiler
87+
joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
6988
)
7089

7190
def wrapper_fn(args, kwargs):

0 commit comments

Comments
 (0)