
Commit db47630

[Compiler Toolkit] Avoid DTensorize BlockMask for FlexAttention (#1952)
```
# TODO: When using flex_attention, BlockMask would show up in kwargs,
# and it's unclear how to convert it to DTensor. If I use to_dtensor,
# it would fail with Dynamo Error: P2011360347
# dt_kwargs = tree_map(to_dtensor, kwargs)
```
1 parent 8228c08 commit db47630
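
The BlockMask mentioned in the TODO is the block-sparse mask object produced by FlexAttention. A minimal sketch of how one arises, purely illustrative and not part of this commit (the `causal` mask_mod and the CPU device are assumptions; requires a torch build with `torch.nn.attention.flex_attention`):

```python
# Illustrative sketch (not from the commit): shows the kind of non-Tensor
# kwarg the TODO refers to.
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


def causal(b, h, q_idx, kv_idx):
    # Standard causal mask_mod: a query position may only attend to earlier positions.
    return q_idx >= kv_idx


# create_block_mask returns a BlockMask, a structured object rather than a
# plain torch.Tensor, so DTensor.from_local cannot wrap it directly.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")
assert isinstance(block_mask, BlockMask)
assert not isinstance(block_mask, torch.Tensor)
```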

File tree: 3 files changed (+32, -30 lines)


torchtitan/experiments/compiler_toolkit/common_utils.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -6,6 +6,10 @@
 
 from contextlib import contextmanager
 
+import torch
+from torch.distributed.tensor import DTensor, Replicate
+from torch.utils._pytree import tree_map
+
 from torchtitan.config import JobConfig
 
 
@@ -18,3 +22,21 @@ def disable_compile(job_config: JobConfig):
         yield
     finally:
         job_config.compile.enable = original_value
+
+
+def parallelize_inputs(world_mesh, args, kwargs):
+    def to_dtensor(tensor):
+        if isinstance(tensor, torch.Tensor):
+            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
+        return tensor
+
+    dt_args = tree_map(to_dtensor, args)
+
+    # TODO: When using flex_attention, BlockMask would show up in kwargs,
+    # and it's unclear how to convert it to DTensor. If I use to_dtensor,
+    # it would fail with Dynamo Error: P2011360347
+    # dt_kwargs = tree_map(to_dtensor, kwargs)
+
+    dt_kwargs = kwargs
+
+    return dt_args, dt_kwargs
```
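
A usage sketch of the new helper, not from the commit: it assumes `torch.distributed` is already initialized with a 1-D device mesh whose dimension is named "tp", and the kwarg name `block_mask` is illustrative. Positional tensors come back as replicated DTensors; kwargs pass through untouched.

```python
# Minimal usage sketch (assumptions: an initialized distributed environment,
# 8 GPUs, and a world mesh with a "tp" dimension).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor

from torchtitan.experiments.compiler_toolkit.common_utils import parallelize_inputs

world_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

tokens = torch.randint(0, 32_000, (2, 128), device="cuda")
# kwargs may carry non-Tensor objects (e.g. a FlexAttention BlockMask);
# parallelize_inputs returns them untouched instead of wrapping them in DTensor.
dt_args, dt_kwargs = parallelize_inputs(world_mesh, (tokens,), {"block_mask": None})

assert isinstance(dt_args[0], DTensor)    # positional tensors become replicated DTensors
assert dt_kwargs == {"block_mask": None}  # kwargs are returned as-is
```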

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 5 additions & 15 deletions

```diff
@@ -11,14 +11,16 @@
 from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
 from torch._guards import tracing
 
-from torch.distributed.tensor import DTensor, Replicate
+from torch.distributed.tensor import DTensor
 
 from torch.fx.traceback import annotate_fn
-from torch.utils._pytree import tree_map
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.expert_parallel import ExpertParallel
-from torchtitan.experiments.compiler_toolkit.common_utils import disable_compile
+from torchtitan.experiments.compiler_toolkit.common_utils import (
+    disable_compile,
+    parallelize_inputs,
+)
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
@@ -75,18 +77,6 @@ def wrapper_fn(args, kwargs):
     return wrapper_fn
 
 
-def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
-
-    dt_args = tree_map(to_dtensor, args)
-    dt_kwargs = tree_map(to_dtensor, kwargs)
-
-    return dt_args, dt_kwargs
-
-
 def annotate_model() -> None:
     # annotate the MoE with dispatch, compute and combine
     ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(
```

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 5 additions & 15 deletions

```diff
@@ -9,13 +9,15 @@
 from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
 from torch._guards import tracing
 
-from torch.distributed.tensor import DTensor, Replicate
+from torch.distributed.tensor import DTensor
 from torch.fx.passes.regional_inductor import regional_inductor
-from torch.utils._pytree import tree_map
 
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.experiments.compiler_toolkit.common_utils import disable_compile
+from torchtitan.experiments.compiler_toolkit.common_utils import (
+    disable_compile,
+    parallelize_inputs,
+)
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
@@ -78,18 +80,6 @@ def wrapper_fn(args, kwargs):
     return wrapper_fn
 
 
-def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
-
-    dt_args = tree_map(to_dtensor, args)
-    dt_kwargs = tree_map(to_dtensor, kwargs)
-
-    return dt_args, dt_kwargs
-
-
 def annotate_model() -> None:
     from torch.fx.traceback import annotate_fn
     from torchtitan.models.attention import FlexAttentionWrapper
```
