88import torch
99from torch ._functorch .aot_autograd import aot_compile_joint_with_descriptors
1010from torch ._guards import tracing
11+ from torch ._inductor .fx_passes .overlap_scheduling import schedule_overlap_bucketing
1112
1213from torch .distributed .tensor import DTensor
1314from torch .fx .passes .regional_inductor import regional_inductor
3132from torchtitan .tools .logging import logger
3233
3334
35+ # TODO: support passing configs into schedule_overlap_bucketing
36+ def autobucketing_reordering_pass (gm : torch .fx .GraphModule ) -> torch .fx .GraphModule :
37+ schedule_overlap_bucketing (
38+ gm , collective_bucketing = True , schedule_overlap_bucketing = False
39+ )
40+ gm .recompile ()
41+ return gm
42+
43+
44+ def compiler (name : str , gm : torch .fx .GraphModule , example_inputs ):
45+ logger .info (f"{ name } before compiler:" )
46+ logger .info (gm .print_readable (print_output = False ))
47+
48+ gm = autobucketing_reordering_pass (gm )
49+
50+ gm = regional_inductor (gm , example_inputs )
51+
52+ logger .info (f"{ name } after compiler:" )
53+ logger .info (gm .print_readable (print_output = False ))
54+ return gm
55+
56+
57+ def fw_compiler (gm : torch .fx .GraphModule , example_inputs ) -> None :
58+ return compiler ("fwd_gm" , gm , example_inputs )
59+
60+
61+ def bw_compiler (gm : torch .fx .GraphModule , example_inputs ) -> None :
62+ return compiler ("bwd_gm" , gm , example_inputs )
63+
64+
3465def joint_graph_builder (model , * args , ** kwargs ):
3566 assert isinstance (model , SimpleFSDPTransformer )
3667 assert isinstance (args , tuple )
@@ -51,21 +82,9 @@ def joint_graph_builder(model, *args, **kwargs):
5182 }:
5283 assert "compile_with_inductor" in node .meta .get ("custom" , {})
5384
54- def compiler (gm : torch .fx .GraphModule , example_inputs ):
55- logger .info ("Before compiler:" )
56- logger .info (gm .print_readable (print_output = False ))
57-
58- # gm = schedule_overlap_bucketing(gm)
59-
60- gm = regional_inductor (gm , example_inputs )
61-
62- logger .info ("After compiler:" )
63- logger .info (gm .print_readable (print_output = False ))
64- return gm
65-
6685 with tracing (tracing_context ):
6786 fn = aot_compile_joint_with_descriptors (
68- joint_with_descriptors , fw_compiler = compiler , bw_compiler = compiler
87+ joint_with_descriptors , fw_compiler = fw_compiler , bw_compiler = bw_compiler
6988 )
7089
7190 def wrapper_fn (args , kwargs ):
0 commit comments