6 changes: 4 additions & 2 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -21,12 +21,14 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+
+        from torch._inductor.config import aten_distributed_optimizations as dist_opts
         from torch._inductor.fx_passes.overlap_scheduling import (
             schedule_overlap_bucketing,
         )

-        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
-        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        dist_opts.collective_bucketing = True
+        dist_opts.insert_overlap_deps = False
         torch._inductor.config.allow_buffer_reuse = False

         def aten_autobucketing_reordering_pass(
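For illustration, here is a minimal sketch of how a backend returned by get_compile_backend could be handed to torch.compile. The backend name string "aot_eager_autobucketing" is an assumption made for this example, not something confirmed by the diff.

# Sketch only: "aot_eager_autobucketing" is an assumed backend name for this example.
import torch
import torch.nn as nn

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

model = nn.Linear(1024, 1024)
# get_compile_backend returns either a backend name string or a callable backend,
# both of which torch.compile accepts.
backend = get_compile_backend("aot_eager_autobucketing")
compiled_model = torch.compile(model, backend=backend)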
16 changes: 2 additions & 14 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -19,7 +19,7 @@
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -95,19 +95,7 @@ def _distribute_dtensor(
"""
inner_spec = tensor._spec
outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
if outer_global_mesh != inner_global_mesh or (
outer_global_mesh is None or inner_global_mesh is None
):
raise AssertionError(
"Cannot distribute tensor across two meshes without the same root mesh: \n"
f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
)
assert outer_mesh.mesh_dim_names is not None
assert inner_mesh.mesh_dim_names is not None
submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
spanned_mesh = outer_global_mesh[submesh_names]
spanned_mesh = DeviceMesh._concatenate((outer_mesh, inner_mesh))
Nice!

One nit: isn't the argument supposed to be a list, and not a tuple?

If so, how come there is no type checking or other linting to catch this?

Note that we've already observed TorchTitan being somewhat incorrect with types, e.g., it passes lists to init_device_mesh instead of tuples.

Contributor Author
Good point, I believe we have not enabled type checking for TorchTitan, which we should.


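To make the type-checking point concrete, here is a small sketch of how an annotated signature plus a checker such as mypy would catch a tuple passed where a list is expected. The signature below is hypothetical and simplified; it is not the actual signature of DeviceMesh._concatenate.

# Sketch only: concatenate_meshes is a hypothetical stand-in, not the real API.
from torch.distributed.device_mesh import DeviceMesh


def concatenate_meshes(meshes: list[DeviceMesh]) -> DeviceMesh:
    """Stand-in for an API annotated to take a list of meshes."""
    raise NotImplementedError


def build_spanned_mesh(outer: DeviceMesh, inner: DeviceMesh) -> DeviceMesh:
    # mypy reports something like:
    #   Argument 1 to "concatenate_meshes" has incompatible type
    #   "tuple[DeviceMesh, DeviceMesh]"; expected "list[DeviceMesh]"
    return concatenate_meshes((outer, inner))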
     if len(dp_placements) == 1:
         assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()
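For context on the hunk above, a minimal sketch of the spanned-mesh construction that the deleted lines performed by hand. The 2x2 mesh shape and the dim names are illustrative only, and the snippet assumes a distributed launch (e.g. torchrun) with 4 ranks.

# Sketch only: mesh shape and dim names are illustrative; assumes a 4-rank launch.
from torch.distributed.device_mesh import init_device_mesh

root = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_replicate", "dp_shard"))
outer = root["dp_replicate"]   # plays the role of the mesh passed to _distribute_dtensor
inner = root["dp_shard"]       # plays the role of the mesh carried by the inner DTensor spec

# Old path (deleted above): require a common root mesh, recombine the dim names,
# and re-slice the root to obtain the spanned mesh.
submesh_names = outer.mesh_dim_names + inner.mesh_dim_names  # ("dp_replicate", "dp_shard")
spanned_mesh = root[submesh_names]

# New path in this diff: DeviceMesh._concatenate((outer_mesh, inner_mesh)) builds the
# spanned mesh directly from the two submeshes, without going through the root mesh.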