diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index b1d19ad17f..2efec2a494 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -113,7 +113,7 @@ def test_generate( logger.info(f"Init model on init_device: {init_device}") model = train_spec.model_cls(model_args) - world_mesh = None + parallel_dims = None # Init distributed env if world_size > 1: dist_utils.init_distributed(config.comm) @@ -127,15 +127,14 @@ def test_generate( etp=1, world_size=world_size, ) - world_mesh = parallel_dims.world_mesh # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 - apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) + apply_tp_minus_sp(model, parallel_dims.get_mesh("tp")) debug_config = DebugConfig(seed=seed, deterministic=deterministic) dist_utils.set_determinism( - world_mesh=world_mesh, + parallel_dims=parallel_dims, device=device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/tests/unit_tests/test_parallel_dims.py b/tests/unit_tests/test_parallel_dims.py new file mode 100644 index 0000000000..1c3276dc6c --- /dev/null +++ b/tests/unit_tests/test_parallel_dims.py @@ -0,0 +1,561 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest.mock import patch + +import torch.distributed as dist +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torchtitan.distributed import ParallelDims + + +class TestParallelDimsValidation(unittest.TestCase): + """Test ParallelDims validation logic without mesh building.""" + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_basic_initialization(self): + """Test basic initialization with valid parameters.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_replicate, 2) + self.assertEqual(parallel_dims.dp_shard, 2) + self.assertEqual(parallel_dims.cp, 1) + self.assertEqual(parallel_dims.tp, 2) + self.assertEqual(parallel_dims.pp, 1) + self.assertEqual(parallel_dims.ep, 1) + self.assertEqual(parallel_dims.etp, 1) + self.assertEqual(parallel_dims.world_size, 8) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_auto_calculate_dp_shard(self): + """Test automatic calculation of dp_shard when set to -1.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=-1, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_shard, 2) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_world_size(self): + """Test validation fails when parallelism degrees don't match world_size.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=10, # Invalid: 2*2*1*2*1 = 8, not 10 + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_etp(self): + """Test validation fails when etp is not equal to tp or 1.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=4, + pp=1, + ep=2, + etp=2, # Invalid: etp must be tp or 1 
when ep > 1 + world_size=8, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_zero_parallelism(self): + """Test validation fails when parallelism degree is 0.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=0, # Invalid: must be >= 1 + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_dp_shard(self): + """Test validation fails when dp_shard is invalid (not -1 and not >=1).""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=0, # Invalid: must be -1 or >= 1 + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_enabled_properties(self): + """Test all enabled properties.""" + # Test with DP enabled + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with CP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=2, + ) + self.assertFalse(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.dp_cp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with EP and ETP enabled (EP * ETP must not contribute to world_size) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with PP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=2, + ep=1, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.pp_enabled) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_fsdp_gradient_divide_factor(self): + """Test fsdp_gradient_divide_factor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=3, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=12, + ) + # Should be dp_replicate * dp_shard * cp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.fsdp_gradient_divide_factor, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_non_data_parallel_size(self): + """Test non_data_parallel_size calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=2, + tp=3, + pp=2, + ep=1, + etp=1, + world_size=48, + ) + # Should be cp * tp * pp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.non_data_parallel_size, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_seq_len_divisor(self): + """Test seq_len_divisor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=1, + cp=2, + tp=4, + pp=1, + ep=1, + etp=1, + world_size=16, + ) + # Should be tp * (cp * 2) = 4 * 4 = 16 + self.assertEqual(parallel_dims.seq_len_divisor, 16) + + +class TestParallelDimsMeshOperations(unittest.TestCase): + """Test 
ParallelDims mesh operations with single-rank distributed environment.""" + + def setUp(self): + """Initialize distributed environment for CPU testing.""" + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method="tcp://localhost:12356", + world_size=1, + rank=0, + ) + + def tearDown(self): + """Clean up distributed environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_invalid_name(self): + """Test getting mesh with invalid name raises error.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + with self.assertRaises(ValueError) as context: + parallel_dims.get_mesh("invalid_mesh") + self.assertIn("Invalid mesh dim", str(context.exception)) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_lazy_initialization(self): + """Test that get_mesh triggers build_mesh if not built yet.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + # Don't call build_mesh explicitly + self.assertEqual(len(parallel_dims._meshes), 0) + + # get_mesh should trigger build_mesh + result = parallel_dims.get_mesh("tp") + # Result is None because tp has size 1, but build_mesh should have been called + self.assertGreater(len(parallel_dims._meshes), 0) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_single_rank_mesh_operations(self): + """Comprehensive test for all single-rank mesh operations. + + This test verifies mesh building, mesh retrieval, mesh sizes, and property + access when all parallelism dimensions are set to 1 (single rank). 
+ """ + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 1) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + + # Validate 1D mesh sizes - all should be 1 for single rank + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 1) + self.assertEqual(parallel_dims._meshes["fsdp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 1) + self.assertEqual(parallel_dims._meshes["batch"].size(), 1) + self.assertEqual(parallel_dims._meshes["loss"].size(), 1) + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual(parallel_dims._meshes["efsdp"].size(), 1) + + # Validate 2D mesh shapes + self.assertEqual(parallel_dims._meshes["dp_replicate_fsdp"].shape, (1, 1)) + self.assertEqual(parallel_dims._meshes["dp_replicate_efsdp"].shape, (1, 1)) + self.assertEqual(parallel_dims._meshes["ep_etp"].shape, (1, 1)) + + # Test get_mesh returns None when all dimensions have size 1 + self.assertIsNone(parallel_dims.get_mesh("tp")) + self.assertIsNone(parallel_dims.get_mesh("dp_replicate")) + self.assertIsNone(parallel_dims.get_mesh("pp")) + self.assertIsNone(parallel_dims.get_mesh("cp")) + self.assertIsNone(parallel_dims.get_mesh("fsdp")) + + # Test get_mesh with list input + self.assertIsNone(parallel_dims.get_mesh(["dp_replicate", "fsdp"])) + + # Test get_all_meshes returns empty when all dimensions have size 1 + one_d_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=True) + self.assertEqual(len(one_d_meshes), 0) + + all_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=False) + self.assertEqual(len(all_meshes), 0) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 1) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_with_list_input(self): + """Test get_mesh accepts both string and list inputs.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + # Should accept list input + result = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + # Returns None because both dimensions have size 1 + self.assertIsNone(result) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_expert_parallelism_validation(self): + """Test expert parallelism configurations.""" + # EP with ETP = 1 (valid) - world_size = dp_replicate * dp_shard * cp * tp * pp + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, # 1 * 2 * 1 * 1 * 1 = 2 + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with larger configuration + parallel_dims 
= ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=3, + etp=1, + world_size=4, # 2 * 2 * 1 * 1 * 1 = 4 + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + + +class TestParallelDimsWorld8MeshOperations(DTensorTestBase): + """Test ParallelDims mesh operations with 8-rank distributed environment.""" + + @property + def world_size(self): + return 8 + + @with_comms + def test_world_size_8_mesh_operations(self): + """Comprehensive test for world_size=8 mesh operations. + + This test validates mesh building, mesh retrieval, mesh sizes, and properties + for a world_size=8 configuration with multiple parallelism dimensions enabled. + Configuration: dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1 (2*2*1*2*1 = 8) + """ + with patch( + "torchtitan.distributed.parallel_dims.device_type", self.device_type + ): + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 8) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + self.assertIn("ep", parallel_dims._meshes) + self.assertIn("etp", parallel_dims._meshes) + self.assertIn("efsdp", parallel_dims._meshes) + + # Validate 1D mesh sizes match parallelism configuration + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual( + parallel_dims._meshes["batch"].size(), 4 + ) # dp_replicate * dp_shard = 2 * 2 + self.assertEqual( + parallel_dims._meshes["loss"].size(), 4 + ) # dp_replicate * dp_shard * cp = 2 * 2 * 1 + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 2) + self.assertEqual( + parallel_dims._meshes["fsdp"].size(), 2 + ) # dp_shard * cp = 2 * 1 + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 2) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual( + parallel_dims._meshes["efsdp"].size(), 4 + ) # fsdp * tp / (etp * ep) = 2 * 2 / (1 * 1) = 4 + + # Validate 2D mesh shapes + self.assertEqual( + parallel_dims._meshes["dp_replicate_fsdp"].shape, (2, 2) + ) # (dp_replicate, fsdp) + self.assertEqual( + parallel_dims._meshes["dp_replicate_efsdp"].shape, (2, 4) + ) # (dp_replicate, efsdp) + self.assertEqual(parallel_dims._meshes["ep_etp"].shape, (1, 1)) # (ep, etp) + + # Test get_mesh returns valid meshes for enabled dimensions (size > 1) + self.assertIsNotNone(parallel_dims.get_mesh("tp")) + self.assertIsNotNone(parallel_dims.get_mesh("dp_replicate")) + self.assertIsNotNone(parallel_dims.get_mesh("fsdp")) + self.assertIsNotNone(parallel_dims.get_mesh("batch")) + self.assertIsNotNone(parallel_dims.get_mesh("loss")) + + # Test get_mesh returns None for disabled dimensions (size = 1) + self.assertIsNone(parallel_dims.get_mesh("pp")) + self.assertIsNone(parallel_dims.get_mesh("cp")) + self.assertIsNone(parallel_dims.get_mesh("ep")) + + # Test 
get_mesh with 2D mesh names + self.assertIsNotNone(parallel_dims.get_mesh(["dp_replicate", "fsdp"])) + hsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertEqual(hsdp_mesh.shape, (2, 2)) + + # Test get_all_meshes returns only meshes with size > 1 + one_d_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=True) + self.assertGreater(len(one_d_meshes), 0) + # Should include: dp_replicate, fsdp, tp, batch, loss, efsdp (all with size > 1) + self.assertIn("dp_replicate", one_d_meshes) + self.assertIn("fsdp", one_d_meshes) + self.assertIn("tp", one_d_meshes) + self.assertIn("batch", one_d_meshes) + self.assertIn("loss", one_d_meshes) + self.assertIn("efsdp", one_d_meshes) + # Should not include: pp, cp, ep, etp (all with size = 1) + self.assertNotIn("pp", one_d_meshes) + self.assertNotIn("cp", one_d_meshes) + self.assertNotIn("ep", one_d_meshes) + self.assertNotIn("etp", one_d_meshes) + + all_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=False) + self.assertGreater(len(all_meshes), len(one_d_meshes)) + # Should also include 2D meshes + self.assertIn("dp_replicate_fsdp", all_meshes) + self.assertIn("dp_replicate_efsdp", all_meshes) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 8) + + # Validate enabled properties + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + + # Validate calculated properties + self.assertEqual( + parallel_dims.fsdp_gradient_divide_factor, 4 + ) # dp_replicate * dp_shard * cp = 2 * 2 * 1 + self.assertEqual( + parallel_dims.non_data_parallel_size, 2 + ) # cp * tp * pp = 1 * 2 * 1 + self.assertEqual( + parallel_dims.seq_len_divisor, 4 + ) # tp * (cp * 2) = 2 * (1 * 2) = 2 * 2 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py index c8087731c5..545603e4e8 100644 --- a/tests/unit_tests/test_set_determinism.py +++ b/tests/unit_tests/test_set_determinism.py @@ -13,8 +13,8 @@ from torchtitan.distributed.utils import set_determinism -class FakeDeviceMesh: - """Fake DeviceMesh for testing seed uniqueness. +class FakeParallelDims: + """Fake ParallelDims for testing seed uniqueness. 
Args: mesh_dim_names: List of dimension names (e.g., ["dp", "pp", "tp"]) @@ -26,25 +26,46 @@ def __init__(self, mesh_dim_names, mesh_sizes, rank_coords): self.mesh_dim_names = mesh_dim_names self.mesh_sizes = dict(zip(mesh_dim_names, mesh_sizes)) self.rank_coords = dict(zip(mesh_dim_names, rank_coords)) + # Calculate world_size as product of all mesh sizes + self.world_size = 1 + for size in mesh_sizes: + self.world_size *= size - def __getitem__(self, key): - """Return a submesh for the given dimension(s).""" + # Create a world_mesh mock + self.world_mesh = MagicMock() + + def get_mesh(self, key): + """Return a submesh for the given dimension.""" if isinstance(key, str): # Single dimension + if key not in self.mesh_dim_names: + return None submesh = MagicMock() submesh.get_local_rank.return_value = self.rank_coords[key] submesh.size.return_value = self.mesh_sizes[key] submesh.get_coordinate.return_value = self.rank_coords[key] + submesh.device_type = "cpu" return submesh elif isinstance(key, list): - # Multiple dimensions + # Multiple dimensions - check if all exist + if not all(dim in self.mesh_dim_names for dim in key): + return None submesh = MagicMock() # For multiple dimensions, get_coordinate should return None # since we're not testing this path submesh.get_coordinate.return_value = None + submesh.device_type = "cpu" return submesh else: - raise ValueError(f"Unsupported key type: {type(key)}") + return None + + def get_all_meshes(self): + """Return a dict of all meshes.""" + return {dim: self.get_mesh(dim) for dim in self.mesh_dim_names} + + def __getitem__(self, key): + """Return a submesh for the given dimension(s) - for backward compatibility.""" + return self.get_mesh(key) def get_coordinate(self): """Return the coordinate tuple for this rank.""" @@ -85,12 +106,12 @@ def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_rank, pp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims(mesh_dim_names, mesh_sizes, rank_coords) # Call set_determinism with distinct seeds only on PP dimension debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], @@ -154,12 +175,14 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims( + mesh_dim_names, mesh_sizes, rank_coords + ) # Call set_determinism with distinct seeds on dp_shard and dp_replicate only debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], @@ -218,12 +241,14 @@ def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size): base_seed = 42 fake_mesh = MagicMock() - fake_mesh.mesh_dim_names = None - fake_mesh.get_coordinate.return_value = None + fake_mesh.world_size = 1 + fake_mesh.world_mesh = MagicMock() + fake_mesh.get_mesh.return_value = None + fake_mesh.get_all_meshes.return_value = {} debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, 
debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 7fc5098800..3ebdead308 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -356,8 +356,8 @@ def _update_expert_bias( model_parts: list[nn.Module], parallel_dims: ParallelDims, ): - dp_cp_mesh = ( - parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + loss_mesh = ( + parallel_dims.get_mesh("loss") if parallel_dims.dp_cp_enabled else None ) # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. @@ -379,9 +379,9 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) - if dp_cp_mesh is not None: + if loss_mesh is not None: # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() + pg = loss_mesh.get_group() torch.distributed.all_reduce( tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM ) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 93fb68a3cc..5dcebfb94b 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -113,7 +113,7 @@ def validate( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -166,9 +166,7 @@ def validate( loss = torch.sum(torch.stack(accumulated_losses)) loss /= num_steps if parallel_dims.dp_cp_enabled: - global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] - ) + global_avg_loss = dist_utils.dist_mean(loss, parallel_dims.get_mesh("loss")) else: global_avg_loss = loss.item() diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..400efac47b 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -388,19 +388,7 @@ class Parallelism: """ Expert parallelism degree. 1 means disabled. No effect for non-MoE models. - Currently, it is supported with the following constraints: - - - when etp = tp: - - - cp <= ep <= dp_shard * cp - - ep % cp == 0 - - dp_shard * cp % ep == 0 - - - when etp = 1: - - - cp * tp <= ep <= dp_shard * cp * tp - - ep % (cp * tp) == 0 - - dp_shard * cp * tp % ep == 0 + Currently, etp is either 1 or is the same as tp. Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index e9986b9974..edaf9f8108 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -177,7 +177,7 @@ def _token_dispatch(self, mod, inputs, device_mesh): # The grad_placements on inputs is set to Partial so that necessary # reductions are performed during backward. 
routed_input = DTensor.from_local( - routed_input, device_mesh["tp"], (Replicate(),) + routed_input, device_mesh["etp"], (Replicate(),) ).to_local(grad_placements=(Partial(),)) inputs = (routed_input, num_tokens_per_expert) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 44822039a6..0e7420e26a 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -26,7 +26,8 @@ class ParallelDims: etp: int world_size: int - _world_mesh: DeviceMesh = None + _meshes: dict[str, DeviceMesh] = field(default_factory=dict) + _world_mesh: DeviceMesh | None = None def __post_init__(self): self._validate() @@ -56,145 +57,207 @@ def _validate(self): if ep > 1: assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1" - if etp == tp: - # EP would borrow all cp and some dp_shard degree - assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - elif etp == 1: - # EP would borrow all cp and tp and some dp_shard degree - assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 + + def _mesh_exist(self, name: str, degree: int) -> bool: + if name == "efsdp": + return True if self.ep > 1 else False + return degree > 1 def build_mesh(self) -> DeviceMesh: - # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel - # is not very clean, due to the limited support from DeviceMesh - # for creating two staggered meshes. Will improve. - if self.ep > 1: - return self._build_mesh_with_ep() - else: - return self._build_mesh_without_ep() - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - dims = [] - names = [] - for d, name in zip( - [ - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ], - ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + """ + Build the device mesh with the required mesh dimensions. + + The following mesh dimensions will be created: + + pp: Pipeline Parallelism (PP). + batch: Used by data loading to determine the global batch size and which + part of the data each rank should read. This dimension includes both + ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for + this dimension to avoid unnecessary process group creation. + loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``, + ``dp_shard``, and ``cp`` degrees, as all are data parallelisms. + dp_replicate: For DDP or HSDP replicate dimension. + fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``. + cp: Context Parallelism (CP). + tp: Tensor Parallelism (TP). + ep: Expert Parallelism (EP). + efsdp: FSDP in the EP region. + etp: TP in the EP region. + + Note: Most dimensions above are created by unflattening the world mesh, except for loss, + which is created by flattening the batch and cp dimensions. 
+ This API performs the following unflatten operations: + + ["pp", "batch", "cp", "tp"] + ["pp", "dp_replicate", "fsdp", "tp"] + ["pp", "dp_replicate", "efsdp", "ep", "etp"] + + Note: DeviceMesh currently recreates the process group for each dimension. + It should share the process group for the same dim group to avoid unnecessary + process group creation. + """ + + def unflatten_mesh( + world_mesh: DeviceMesh, dim_names: tuple[str], dim_degrees: tuple[int] ): - # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping - # helps the MoE layers do mixed precision training - if d > 1 or name == "dp_shard_mod_ep": - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - # Mesh for ep - ep_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # dp_shard_mod_ep is always needed, even if it's 1 - dp_mesh_dim_names.append("dp_shard_mod_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") - dp_cp_mesh_dim_names.append("dp_shard_mod_ep") - if "dp_shard_in_ep" in names: - dp_mesh_dim_names.append("dp_shard_in_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") - dp_cp_mesh_dim_names.append("dp_shard_in_ep") - ep_mesh_dim_names.append("dp_shard_in_ep") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - ep_mesh_dim_names.append("cp") - if self.etp == 1 and self.tp_enabled: - ep_mesh_dim_names.append("tp") - - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") - - return mesh - - def _build_mesh_without_ep(self) -> DeviceMesh: - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - if self.dp_shard_enabled: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - if dp_mesh_dim_names != []: - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - if dp_shard_cp_mesh_dim_names != []: - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( - mesh_dim_name="dp_shard_cp" + """Unflatten the world mesh to create the required mesh 
dimensions. + + Uses fake backend for dimensions with degree 1 or for 'batch' dimension + to avoid unnecessary process group creation. + """ + backend_override = {} + for name, degree in zip(dim_names, dim_degrees, strict=True): + if (not self._mesh_exist(name, degree)) or name == "batch": + backend_override[name] = "fake" + + return world_mesh._unflatten( + 0, dim_degrees, dim_names, backend_override=backend_override ) - if dp_cp_mesh_dim_names != []: - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - return mesh + logger.info( + f"Building device mesh with parallelism: " + f"pp={self.pp}, dp_replicate={self.dp_replicate}, dp_shard={self.dp_shard}, " + f"cp={self.cp}, tp={self.tp}, ep={self.ep}, etp={self.etp}" + ) + + batch = self.dp_replicate * self.dp_shard + loss = self.dp_replicate * self.dp_shard * self.cp + fsdp = self.dp_shard * self.cp + efsdp = fsdp * self.tp // (self.etp * self.ep) + + self._world_mesh = init_device_mesh( + device_type, (self.world_size,), mesh_dim_names=("world",) + ) + dataloading_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "batch", "cp", "tp"), + (self.pp, batch, self.cp, self.tp), + ) + loss_mesh = dataloading_mesh["batch", "cp"]._flatten("loss_mesh") + dense_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "fsdp", "tp"), + (self.pp, self.dp_replicate, fsdp, self.tp), + ) + sparse_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "efsdp", "ep", "etp"), + (self.pp, self.dp_replicate, efsdp, self.ep, self.etp), + ) + + # We have created all the required 1D meshes. This part is to create the + # all the 2D meshes. We pre-created 2D meshes and error out if the users + # try to access a 2D mesh that is not pre-created. + hsdp_mesh = dense_mesh["dp_replicate", "fsdp"] + ehsdp_mesh = sparse_mesh["dp_replicate", "efsdp"] + ep_etp_mesh = sparse_mesh["ep", "etp"] + + self._meshes = { + "pp": dataloading_mesh["pp"], + "batch": dataloading_mesh["batch"], + "loss": loss_mesh, + "dp_replicate": dense_mesh["dp_replicate"], + "fsdp": dense_mesh["fsdp"], + "cp": dataloading_mesh["cp"], + "tp": dataloading_mesh["tp"], + "ep": sparse_mesh["ep"], + "efsdp": sparse_mesh["efsdp"], + "etp": sparse_mesh["etp"], + "dp_replicate_fsdp": hsdp_mesh, + "dp_replicate_efsdp": ehsdp_mesh, + "ep_etp": ep_etp_mesh, + } + + # Validate mesh sizes + self._validate_meshes() + + logger.info( + f"Successfully created meshes with active dimensions: " + f"{list(self.get_all_meshes().keys())}" + ) + + return self._world_mesh + + def _validate_meshes(self): + """Validate that created meshes have the expected sizes.""" + expected_sizes = { + "pp": self.pp, + "batch": self.dp_replicate * self.dp_shard, + "loss": self.dp_replicate * self.dp_shard * self.cp, + "dp_replicate": self.dp_replicate, + "fsdp": self.dp_shard * self.cp, + "cp": self.cp, + "tp": self.tp, + "ep": self.ep, + "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep), + "etp": self.etp, + "dp_replicate_fsdp": (self.dp_replicate, self.dp_shard * self.cp), + "dp_replicate_efsdp": ( + self.dp_replicate, + self.dp_shard * self.cp * self.tp // (self.etp * self.ep), + ), + "ep_etp": (self.ep, self.etp), + } + + for mesh_name, expected_size in expected_sizes.items(): + if isinstance(expected_size, tuple): + actual_size = self._meshes[mesh_name].shape + else: + actual_size = self._meshes[mesh_name].size() + assert actual_size == expected_size, ( + f"Mesh '{mesh_name}' has unexpected size: " + f"expected {expected_size}, got {actual_size}" + ) + + def get_mesh(self, dims: 
str | list[str]) -> DeviceMesh | None: + """Get a device mesh by dimension names. + + Args: + dims: Names of the mesh dimension. Valid options include: + 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp', + 'cp', 'tp', 'ep', 'etp', 'efsdp' + + Returns: + DeviceMesh for the requested dimension(s). The DeviceMesh exists if + 1) dimension size is larger than 1 (the parallelism is enabled) + 2) efsdp is enabled even if size is 1 if ep is > 1. + The return value if None otherwise. + + Raises: + ValueError: If the requested dimension name(s) is not valid. + """ + if not self._meshes: + self.build_mesh() + + if isinstance(dims, str): + dims = [dims] + + mesh_name = "_".join(dims) + if mesh_name not in self._meshes: + raise ValueError( + f"Invalid mesh dim: '{mesh_name}'. " + f"Valid dimensions are: {list(self._meshes.keys())}" + ) + + if any(not self._mesh_exist(dim, self._meshes[dim].size()) for dim in dims): + return None + + return self._meshes[mesh_name] + + def get_all_meshes(self, one_dimensioal_only: bool = True) -> dict[str, DeviceMesh]: + if not self._meshes: + self.build_mesh() + if one_dimensioal_only: + return { + k: v for k, v in self._meshes.items() if v.ndim == 1 and v.size() > 1 + } + else: + return {k: v for k, v in self._meshes.items() if v.size() > 1} @property def world_mesh(self) -> DeviceMesh: - # doing late init so ParallelDims can still be used as a lightweight - # dataclass without having to initialize the world mesh if self._world_mesh is None: - self._world_mesh = self.build_mesh() + self.build_mesh() return self._world_mesh @property diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index 06dba40d6f..38f3bad1ba 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -47,7 +47,7 @@ def pipeline_llm( parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] + pp_mesh = parallel_dims.get_mesh("pp") # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index f424276a3c..3971cf8932 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import itertools import math import os from collections.abc import Generator, Iterable @@ -26,7 +27,7 @@ def _dist_reduce( x: torch.Tensor, reduceOp: str, - mesh: DeviceMesh, + mesh: DeviceMesh | None, extra_pg: dist.ProcessGroup | None, ) -> float: """Perform distributed reduction on a tensor. @@ -34,7 +35,8 @@ def _dist_reduce( Args: x (torch.Tensor): Input tensor. reduceOp (str): Reduce operation to perform. - mesh (DeviceMesh): Device mesh to use for reduction. + mesh (DeviceMesh | None): Device mesh to use for reduction. + If None, no reduction is performed but simply convert the tensor to a float. extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction. Defaults to None. If provided, this all_reduce will be called for the extra process group, and then the result will be all_reduced for the mesh. 
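To illustrate how the reducers above interact with the new `ParallelDims.get_mesh` API, here is a minimal sketch. It is not part of this patch: the degrees below are made-up values, it assumes an already-initialized 4-rank job, and it assumes `torchtitan.distributed.utils` is imported as `dist_utils` as in the call sites elsewhere in this diff.

import torch

from torchtitan.distributed import ParallelDims
from torchtitan.distributed import utils as dist_utils

# Hypothetical config: only dp_shard > 1, so world_size must be 4.
parallel_dims = ParallelDims(
    dp_replicate=1, dp_shard=4, cp=1, tp=1, pp=1, ep=1, etp=1, world_size=4
)

loss = torch.tensor(2.0)
# "loss" spans dp_replicate * dp_shard * cp ranks, so this mesh has size 4.
loss_mesh = parallel_dims.get_mesh("loss")
# "cp" has degree 1, so get_mesh returns None instead of a trivial mesh.
assert parallel_dims.get_mesh("cp") is None

# dist_mean all-reduces over the mesh when one is given; with mesh=None it
# simply returns loss.item() without issuing any collective.
global_avg_loss = dist_utils.dist_mean(loss, loss_mesh)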
@@ -46,13 +48,16 @@ def _dist_reduce( if extra_pg is not None: x = funcol.all_reduce(x, reduceOp=reduceOp, group=extra_pg) + if mesh is None: + return x.item() + assert x.numel() == 1 # required by `.item()` return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() def dist_max( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -62,7 +67,7 @@ def dist_max( def dist_sum( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -72,7 +77,7 @@ def dist_sum( def dist_mean( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -81,7 +86,7 @@ def dist_mean( def set_determinism( - world_mesh: DeviceMesh | None, + parallel_dims: ParallelDims, device: torch.device, debug_config: DebugConfig, distinct_seed_mesh_dims: list[str], @@ -99,9 +104,8 @@ def set_determinism( Args: world_mesh: Device mesh for distributed training device: Device to use + debug_config: Debug config to use distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across. - seed: Base seed value (if None, will be determined automatically) - deterministic: Whether to enable deterministic algorithms """ if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") @@ -124,7 +128,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) seed = debug_config.seed - if not world_mesh: + if parallel_dims.world_size == 1: if seed is not None: torch.manual_seed(seed) os.environ["PYTHONHASHSEED"] = str(seed % 2**32) @@ -144,20 +148,18 @@ def set_determinism( # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, # and choose a unique seed for each rank on the PP mesh. # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed. 
- distinct_dims_in_mesh = [ - dim - for dim in distinct_seed_mesh_dims - if world_mesh.mesh_dim_names and dim in world_mesh.mesh_dim_names + distinct_seed_meshes = [ + parallel_dims.get_mesh(dim) for dim in distinct_seed_mesh_dims ] + distinct_seed_meshes = [mesh for mesh in distinct_seed_meshes if mesh is not None] - if c10d.get_world_size() > 1 and distinct_dims_in_mesh: + if distinct_seed_meshes: # Each dimension contributes: local_rank * (product of all previous dimension sizes) # This guarantees uniqueness like multi-dimensional array indexing seed_offset = 0 cumulative_size = 1 - for dim in distinct_dims_in_mesh: - distinct_mesh = world_mesh[dim] + for distinct_mesh in distinct_seed_meshes: local_rank = distinct_mesh.get_local_rank() # Add contribution from this dimension seed_offset += local_rank * cumulative_size @@ -168,20 +170,17 @@ def set_determinism( seed %= 2**64 logger.debug( - f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}" + f"Distinct dims {distinct_seed_mesh_dims}, Global rank {c10d.get_rank()} using seed: {seed}" ) # Filter out all distinct dimensions to get duplicate_seed_mesh - duplicate_seed_mesh_dims = [ - name - for name in world_mesh.mesh_dim_names - if name not in distinct_dims_in_mesh - ] - duplicate_seed_mesh = ( - world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None + duplicate_seed_meshes = list( + v + for k, v in parallel_dims.get_all_meshes().items() + if k not in distinct_seed_mesh_dims ) else: - duplicate_seed_mesh = world_mesh + duplicate_seed_meshes = [parallel_dims.world_mesh] logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}") # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency. @@ -191,8 +190,10 @@ def set_determinism( # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. # IF PP is also used, this seed is unique per PP rank. - if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None: - torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh) + # TODO: remove the need of duplicate_seed_meshes once torch.distributed.tensor._random.manual_seed + # doesn't require a mesh input. + if duplicate_seed_meshes: + torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_meshes[0]) def create_context_parallel_ctx( @@ -306,7 +307,10 @@ def _get_distributed_backend(enable_cpu_backend): ) -def set_pg_timeouts(timeout, world_mesh): +def set_pg_timeouts( + timeout: timedelta, + parallel_dims: ParallelDims, +): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. 
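The seed-offset computation in set_determinism above is mixed-radix indexing over the distinct-seed dimensions, so every combination of local ranks maps to a unique offset. A small worked example of just that arithmetic, with made-up mesh sizes (pp=4, ep=2) and ranks, not taken from this patch:

base_seed = 42
# (local_rank, size) per distinct-seed mesh; say this rank is pp_rank=3, ep_rank=1.
distinct_ranks_and_sizes = [(3, 4), (1, 2)]

seed_offset = 0
cumulative_size = 1
for local_rank, size in distinct_ranks_and_sizes:
    # Each dimension contributes local_rank * (product of previous sizes),
    # exactly like flattening a multi-dimensional array index.
    seed_offset += local_rank * cumulative_size
    cumulative_size *= size

seed = (base_seed + seed_offset) % 2**64  # 42 + (3 * 1 + 1 * 4) = 49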
@@ -325,11 +329,10 @@ def set_pg_timeouts(timeout, world_mesh): torch.distributed.barrier(device_ids=[device_module.current_device()]) device_module.synchronize() - groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] - # None represents the 'default' PG, not part of the mesh - groups.append(None) - for group in groups: + for group in itertools.chain( + [None], [mesh.get_group() for mesh in parallel_dims.get_all_meshes().values()] + ): torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 965e027bdb..b1bab26d1e 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -24,10 +24,12 @@ def disable_compile(job_config: JobConfig): job_config.compile.enable = original_value -def parallelize_inputs(world_mesh, args, kwargs): +def parallelize_inputs(parallel_dims, args, kwargs): def to_dtensor(tensor): if isinstance(tensor, torch.Tensor): - return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()]) + return DTensor.from_local( + tensor, parallel_dims.get_mesh("tp"), [Replicate()] + ) return tensor dt_args = tree_map(to_dtensor, args) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 06bdd9305b..2e2076a599 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -233,9 +233,7 @@ def __delattr__(self, name: str) -> None: def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" - dt_args, dt_kwargs = self.parallelize_inputs( - self.parallel_dims.world_mesh, args, kwargs - ) + dt_args, dt_kwargs = self.parallelize_inputs(self.parallel_dims, args, kwargs) if self.joint_graph_module is None: self.joint_graph_module = self.joint_graph_builder( diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 2f1887b2d7..56a3af5da1 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -86,10 +86,9 @@ def __init__(self, job_config: ForgeJobConfig): world_size=world_size, ) - world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + dp_degree, dp_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 self.dp_degree, self.dp_rank = dp_degree, dp_rank @@ -102,9 +101,10 @@ def __init__(self, job_config: ForgeJobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). 
dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, + distinct_seed_mesh_dims=[], ) self.train_spec = get_train_spec(job_config.model.name) diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index d638a6bd26..78980504ba 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -168,7 +168,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -243,7 +243,7 @@ def train_step( self.job_config.training.max_norm, foreach=True, pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + parallel_dims.get_mesh("pp") if parallel_dims.pp_enabled else None ), ep_enabled=parallel_dims.ep_enabled, ) @@ -261,8 +261,8 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"]), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"]), + dist_utils.dist_mean(loss, parallel_dims.get_mesh("loss")), + dist_utils.dist_max(loss, parallel_dims.get_mesh("loss")), ) else: global_avg_loss = global_max_loss = loss.detach().item() @@ -328,7 +328,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 7714d497e4..b67f4f4c0d 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -53,8 +53,6 @@ def parallelize_gptoss( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh - assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -86,7 +84,7 @@ def parallelize_gptoss( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, @@ -95,10 +93,10 @@ def parallelize_gptoss( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] + tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, + ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, + ep_etp_mesh=( + parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled @@ -123,11 +121,10 @@ def parallelize_gptoss( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = 
parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] @@ -146,7 +143,7 @@ def parallelize_gptoss( reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) if parallel_dims.ep_enabled else None ), @@ -163,9 +160,9 @@ def parallelize_gptoss( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -256,7 +253,7 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, etp_enabled: bool, ): assert ep_mesh is not None or tp_mesh is not None @@ -301,7 +298,7 @@ def apply_moe_ep_tp( # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = ep_etp_mesh experts_plan = GptossExpertTensorParallel() parallelize_module( diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index ac6f9bdc9b..0a341bc34f 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -27,7 +27,6 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. 
@@ -61,26 +60,20 @@ def parallelize_deepseekv3( use_flex_attn = getattr(model.model_args, "use_flex_attn", False) apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, use_flex_attn=use_flex_attn, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=parallel_dims.get_mesh("tp"), + ep_mesh=parallel_dims.get_mesh("ep"), + etp_mesh=parallel_dims.get_mesh("etp"), + ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]), ) if job_config.activation_checkpoint.mode != "none": @@ -114,38 +107,37 @@ def parallelize_deepseekv3( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") - dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ["dp_replicate", "efsdp"] + else: + dp_mesh_dim_names = ["efsdp"] + edp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) for _, transformer_block in model.layers.items(): if transformer_block.moe_enabled and parallel_dims.ep_enabled: experts_shard_dim = 0 - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - dp_mod_ep_mesh.size() * parallel_dims.ep + edp_mesh.size() * parallel_dims.ep > transformer_block.moe.experts.num_experts ): experts_shard_dim = 1 # when EP is enable, the routed experts' gradient reduction is done over - # dp_mod_ep_mesh instead of whole dp_mesh. + # edp_mesh instead of whole dp_mesh. # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh # to be consistent with data. 
                 # TODO (ruisizhang123): update the logic following the link below instead
@@ -153,7 +145,7 @@ def parallelize_deepseekv3(
                 # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883
                 transformer_block.moe.experts = data_parallel(
                     transformer_block.moe.experts,
-                    dp_mod_ep_mesh,
+                    edp_mesh,
                     dp_mode,
                     ac_mode=job_config.activation_checkpoint.mode,
                     mp_policy=mp_policy,
diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index d61e74a5dd..5315b248d0 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -67,7 +67,7 @@ def parallelize_llama(
         # all-gather happens in high precision.
         enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
-        tp_mesh = parallel_dims.world_mesh["tp"]
+        tp_mesh = parallel_dims.get_mesh("tp")
         apply_tp(
             model,
             tp_mesh,
@@ -98,13 +98,13 @@ def parallelize_llama(
     ):
         if parallel_dims.dp_replicate_enabled:
             if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
-                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+                dp_mesh_dim_names = ["dp_replicate", "fsdp"]
                 dp_mode = "hybrid_shard"
             else:
-                dp_mesh_dim_names = ("dp_replicate",)
+                dp_mesh_dim_names = ["dp_replicate"]
                 dp_mode = "replicate"
         else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
+            dp_mesh_dim_names = ["fsdp"]
             dp_mode = "fully_shard"
         mp_policy = MixedPrecisionPolicy(
@@ -128,7 +128,7 @@ def parallelize_llama(
         model = data_parallel(
             model,
-            parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
+            parallel_dims.get_mesh(dp_mesh_dim_names),
             mode=dp_mode,
             ac_mode=job_config.activation_checkpoint.mode,
             mp_policy=mp_policy,
diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py
index 76233aeb87..aaf94a5023 100644
--- a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py
+++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py
@@ -20,13 +20,13 @@ def init_test(self):
        self.loss_fn = cross_entropy_loss
        data_parallel_shard_degree = -1
        if self.mode == "replicate":
-            self.dp_mesh_dim_names = ("dp_replicate",)
+            self.dp_mesh_dim_names = ["dp_replicate"]
            data_parallel_replicate_degree = self.world_size
        elif self.mode == "fully_shard":
-            self.dp_mesh_dim_names = ("dp_shard_cp",)
+            self.dp_mesh_dim_names = ["fsdp"]
            data_parallel_replicate_degree = 1
        elif self.mode == "hybrid_shard":
-            self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+            self.dp_mesh_dim_names = ["dp_replicate", "fsdp"]
            data_parallel_replicate_degree = self.world_size // 2
        else:
            raise ValueError(f"Unsupported mode {self.mode}")
@@ -41,7 +41,6 @@ def init_test(self):
            etp=1,
            world_size=self.world_size,
        )
-        self.device_mesh = self.parallel_dims.world_mesh
    def get_input(self):
        inputs = torch.randn(8, 8).cuda()
@@ -50,7 +49,7 @@ def get_input(self):
        return model, inputs, labels
    def run_fsdp2(self, model, inputs, labels, epoch=20):
-        fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
+        fully_shard(model, mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names))
        optim = self.optimizer(model.parameters(), lr=1e-4)
        losses = []
        for _ in range(epoch):
@@ -65,7 +64,7 @@ def run_fsdp2(self, model, inputs, labels, epoch=20):
    def run_simple_fsdp(self, model, inputs, labels, epoch=20):
        model = data_parallel(
            model,
-            device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
+            device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names),
            mode=self.mode,
        )
        optim = self.optimizer(model.parameters(), lr=1e-4)
@@ -82,7 +81,7 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20):
    def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20):
        model = data_parallel(
            model,
-            device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
+            device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names),
            mode=self.mode,
        )
        # TODO: Add "inductor" backend when it's numerical issues are fixed
diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py
index bba51f2819..7a3a490fb7 100644
--- a/torchtitan/experiments/vlm/infra/loss.py
+++ b/torchtitan/experiments/vlm/infra/loss.py
@@ -104,7 +104,7 @@ def build_token_imbalance_ce_loss(
    # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced:
    #   DP split the batch dim B
    #   CP split the sequence dim S
-    token_mesh = parallel_dims.world_mesh["dp_cp"]
+    token_mesh = parallel_dims.get_mesh("loss")
    ft_pg = ft_manager.loss_sync_pg
    loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg)
    if job_config.compile.enable and "loss" in job_config.compile.components:
diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py
index 6a97e4ece1..e57bc9254b 100644
--- a/torchtitan/experiments/vlm/infra/parallelize.py
+++ b/torchtitan/experiments/vlm/infra/parallelize.py
@@ -38,7 +38,6 @@ def parallelize_vlm(
    the model must fit on GPU or CPU memory.
    """
    assert isinstance(model.encoder, nn.Module)
-    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
@@ -75,14 +74,13 @@ def parallelize_vlm(
    if parallel_dims.fsdp_enabled:
        # apply FSDP or HSDP, potentially with Context Parallel
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
-        else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
        apply_fsdp(
            model,
-            world_mesh[tuple(dp_mesh_dim_names)],
+            parallel_dims.get_mesh(names),
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
@@ -101,11 +99,12 @@ def parallelize_vlm(
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
+        dp_mesh = parallel_dims.get_mesh("dp_replicate")
+        if dp_mesh is not None and dp_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
-            world_mesh,
+            dp_mesh,
            enable_compile=job_config.compile.enable,
        )
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 0793820ffd..f9ba017e6e 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -50,7 +50,6 @@ def parallelize_deepseekv3(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
 ):
-    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
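Editor's note: the diff never shows the ParallelDims.get_mesh implementation itself, so the snippet below is only a minimal sketch of the contract the call sites above appear to assume: a mesh dim name (or list of names) such as "tp", "fsdp", "efsdp", "loss", or "batch" resolves to the corresponding sub-mesh, and a disabled dim resolves to None so callers can pass the result through unconditionally. Class and attribute names here are hypothetical stand-ins, not the torchtitan code.

# Illustrative sketch only -- assumed semantics of get_mesh(), not the real implementation.
from typing import Sequence, Union


class ParallelDimsSketch:
    def __init__(self, degrees: dict):
        # e.g. {"dp_replicate": 2, "fsdp": 4, "tp": 2}; a missing entry (or 1) means the dim is disabled
        self.degrees = degrees

    def get_mesh(self, names: Union[str, Sequence[str]]):
        names = [names] if isinstance(names, str) else list(names)
        sizes = [self.degrees.get(n, 1) for n in names]
        if all(s == 1 for s in sizes):
            return None  # disabled dim(s): callers may pass this straight to the apply_* helpers
        return tuple(sizes)  # stand-in for the DeviceMesh sliced over `names`


pd = ParallelDimsSketch({"dp_replicate": 2, "fsdp": 4, "tp": 2})
assert pd.get_mesh("tp") == (2,)
assert pd.get_mesh(["dp_replicate", "fsdp"]) == (2, 4)
assert pd.get_mesh("ep") is None  # EP not enabled in this example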
@@ -81,26 +80,28 @@ def parallelize_deepseekv3(
        apply_non_moe_tp(
            model,
-            world_mesh["tp"],
+            parallel_dims.get_mesh("tp"),
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=False,
            use_flex_attn=use_flex_attn,
        )
-        maybe_enable_async_tp(job_config, world_mesh["tp"])
+        maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp"))
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
-            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
-            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
-            ep_tp_mesh=(
-                world_mesh["ep", "tp"]
+            tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None,
+            ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None,
+            etp_mesh=parallel_dims.get_mesh("etp")
+            if parallel_dims.etp_enabled
+            else None,
+            ep_etp_mesh=(
+                parallel_dims.get_mesh("ep_etp")
                if parallel_dims.tp_enabled
                and parallel_dims.ep_enabled
                and parallel_dims.etp_enabled
                else None
            ),
-            etp_enabled=parallel_dims.etp_enabled,
        )
    model_compile_enabled = (
@@ -123,18 +124,18 @@ def parallelize_deepseekv3(
    dp_mesh: DeviceMesh | None = None
    if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
        # apply FSDP or HSDP, potentially with Context Parallel
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
-        else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
-        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
+        dp_mesh = parallel_dims.get_mesh(names)
        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        dp_mod_ep_mesh_dim_names = []
-        if parallel_dims.ep_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                dp_mod_ep_mesh_dim_names.append("dp_replicate")
-            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+        names = (
+            ["dp_replicate", "efsdp"]
+            if parallel_dims.dp_replicate_enabled
+            else ["efsdp"]
+        )
+        edp_mesh = parallel_dims.get_mesh(names)
        apply_fsdp(
            model,
@@ -145,11 +146,7 @@ def parallelize_deepseekv3(
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
            ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=(
-                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if parallel_dims.ep_enabled
-                else None
-            ),
+            dp_mod_ep_mesh=edp_mesh,
            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
        )
@@ -164,9 +161,9 @@ def parallelize_deepseekv3(
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
+        dp_mesh = parallel_dims.get_mesh("dp_replicate")
+        if dp_mesh is not None and dp_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
        apply_ddp(
            model,
            dp_mesh,
diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py
index fc9c926af0..e6f6d934e9 100644
--- a/torchtitan/models/flux/infra/parallelize.py
+++ b/torchtitan/models/flux/infra/parallelize.py
@@ -28,14 +28,13 @@ def parallelize_flux(
    apply_ac(model, job_config.activation_checkpoint)
    if parallel_dims.fsdp_enabled:
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
-        else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
        apply_fsdp(
            model,
-            parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
+            parallel_dims.get_mesh(names),
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            cpu_offload=job_config.training.enable_cpu_offload,
@@ -130,17 +129,16 @@ def parallelize_encoders(
    job_config: JobConfig,
 ):
    if parallel_dims.dp_shard_enabled:  # apply FSDP or HSDP
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard")
-        else:
-            dp_mesh_dim_names = ("dp_shard",)
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )
        fsdp_config = {
-            "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
+            "mesh": parallel_dims.get_mesh(names),
            "mp_policy": mp_policy,
        }
        if job_config.training.enable_cpu_offload:
diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py
index 9bb3cd48bf..1831a6c7d7 100644
--- a/torchtitan/models/flux/train.py
+++ b/torchtitan/models/flux/train.py
@@ -32,10 +32,10 @@ def __init__(self, job_config: JobConfig):
        # (mainly for debugging, expect perf loss).
        # For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader
        dist_utils.set_determinism(
-            self.parallel_dims.world_mesh,
+            self.parallel_dims,
            self.device,
            job_config.debug,
-            distinct_seed_mesh_dims=["dp_shard", "dp_replicate"],
+            distinct_seed_mesh_dims=["fsdp", "dp_replicate"],
        )
        # NOTE: self._dtype is the data type used for encoders (image encoder, T5 text encoder, CLIP text encoder).
@@ -133,7 +133,7 @@ def forward_backward_step(
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
-                cp_mesh=self.parallel_dims.world_mesh["cp"],
+                cp_mesh=self.parallel_dims.get_mesh("cp"),
                cp_buffers=[
                    latents,
                    latent_pos_enc,
diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py
index 189385e0f2..f0646c9719 100644
--- a/torchtitan/models/flux/validate.py
+++ b/torchtitan/models/flux/validate.py
@@ -213,7 +213,7 @@ def validate(
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
+                cp_mesh=parallel_dims.get_mesh("cp"),
                cp_buffers=[
                    latents,
                    latent_pos_enc,
@@ -258,9 +258,7 @@ def validate(
        loss = torch.sum(torch.stack(accumulated_losses))
        loss /= num_steps
        if parallel_dims.dp_cp_enabled:
-            global_avg_loss = dist_utils.dist_mean(
-                loss, parallel_dims.world_mesh["dp_cp"]
-            )
+            global_avg_loss = dist_utils.dist_mean(loss, parallel_dims.get_mesh("loss"))
        else:
            global_avg_loss = loss.item()
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 86ac3a6dfe..e0e0682da7 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -56,7 +56,6 @@ def parallelize_llama(
    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
-    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
@@ -83,13 +82,14 @@ def parallelize_llama(
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        tp_mesh = parallel_dims.get_mesh("tp")
        apply_tp(
            model,
-            world_mesh["tp"],
+            tp_mesh,
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
        )
-        maybe_enable_async_tp(job_config, world_mesh["tp"])
+        maybe_enable_async_tp(job_config, tp_mesh)
    model_compile_enabled = (
        job_config.compile.enable and "model" in job_config.compile.components
@@ -110,15 +110,14 @@ def parallelize_llama(
        apply_compile(model, job_config.compile)
    if parallel_dims.fsdp_enabled:
-        # apply FSDP or HSDP, potentially with Context Parallel
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
-        else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
-
+        # dp_mesh is the mesh for FSDP/HSDP
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
+        dp_mesh = parallel_dims.get_mesh(names)
        apply_fsdp(
            model,
-            world_mesh[tuple(dp_mesh_dim_names)],
+            dp_mesh,
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
@@ -137,11 +136,12 @@ def parallelize_llama(
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
+        dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate")
+        if parallel_dims.world_size != dp_replicate_mesh.size():
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
-            world_mesh,
+            dp_replicate_mesh,
            enable_compile=model_compile_enabled,
        )
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index aff029f736..f3cc4f4f48 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -64,7 +64,6 @@ def parallelize_llama(
    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
-    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
@@ -91,27 +90,22 @@ def parallelize_llama(
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        tp_mesh = parallel_dims.get_mesh("tp")
        apply_non_moe_tp(
            model,
-            world_mesh["tp"],
+            tp_mesh,
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
        )
-        maybe_enable_async_tp(job_config, world_mesh["tp"])
+        maybe_enable_async_tp(job_config, tp_mesh)
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
-            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
-            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
-            ep_tp_mesh=(
-                world_mesh["ep", "tp"]
-                if parallel_dims.tp_enabled
-                and parallel_dims.ep_enabled
-                and parallel_dims.etp_enabled
-                else None
-            ),
-            etp_enabled=parallel_dims.etp_enabled,
+            tp_mesh=tp_mesh,
+            ep_mesh=parallel_dims.get_mesh("ep"),
+            etp_mesh=parallel_dims.get_mesh("etp"),
+            ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]),
        )
    model_compile_enabled = (
@@ -131,21 +125,20 @@ def parallelize_llama(
    if model_compile_enabled:
        apply_compile(model, job_config.compile)
-    dp_mesh: DeviceMesh | None = None
    if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
-        # apply FSDP or HSDP, potentially with Context Parallel
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
-        else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
-        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+        # dp_mesh is the mesh for FSDP/HSDP
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
+        dp_mesh = parallel_dims.get_mesh(names)
        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        dp_mod_ep_mesh_dim_names = []
+        dp_mod_ep_mesh = None
        if parallel_dims.ep_enabled:
            if parallel_dims.dp_replicate_enabled:
-                dp_mod_ep_mesh_dim_names.append("dp_replicate")
-            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+                dp_mod_ep_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
+            else:
+                dp_mod_ep_mesh = parallel_dims.get_mesh("efsdp")
        apply_fsdp(
            model,
@@ -156,11 +149,7 @@ def parallelize_llama(
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
            ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=(
-                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if parallel_dims.ep_enabled
-                else None
-            ),
+            dp_mod_ep_mesh=dp_mod_ep_mesh,
            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
        )
@@ -175,9 +164,9 @@ def parallelize_llama(
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
+        dp_mesh = parallel_dims.get_mesh("dp_replicate")
+        if parallel_dims.world_size != dp_mesh.size():
            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
        apply_ddp(
            model,
            dp_mesh,
@@ -441,8 +430,8 @@ def apply_moe_ep_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh | None,
    ep_mesh: DeviceMesh | None,
-    ep_tp_mesh: DeviceMesh | None,
-    etp_enabled: bool,
+    etp_mesh: DeviceMesh | None,
+    ep_etp_mesh: DeviceMesh | None,
 ):
    assert ep_mesh is not None or tp_mesh is not None
@@ -464,7 +453,7 @@ def apply_moe_ep_tp(
        # replicate computation for the router
        "moe.router.gate": NoParallel(),
    }
-    if ep_mesh is not None and not etp_enabled:
+    if ep_mesh is not None and etp_mesh is None:
        # If TP is borrowed for EP, then split the tokens across TP ranks so that
        # the reorderer, the all-to-all comms, and routed experts computation
        # are effectively running Sequence Parallel (split along the folded bs*slen dim)
@@ -488,15 +477,17 @@ def apply_moe_ep_tp(
    experts_mesh, experts_plan = None, None
    if ep_mesh is None:
+        assert ep_etp_mesh is None
        experts_mesh = tp_mesh
        # input Replicate, output Partial
        experts_plan = TensorParallel()
-    elif tp_mesh is None or not etp_enabled:
+    elif tp_mesh is None or etp_mesh is None:
+        assert ep_etp_mesh is None
        experts_mesh = ep_mesh
        # input / output sharding on the batch / tokens dim
        experts_plan = ExpertParallel()
    else:
-        experts_mesh = ep_tp_mesh
+        experts_mesh = ep_etp_mesh
        experts_plan = ExpertTensorParallel()
    parallelize_module(
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 6b8dc3d5a6..94a881aed4 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -51,7 +51,6 @@ def parallelize_qwen3(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
 ):
-    world_mesh = parallel_dims.world_mesh
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
@@ -86,7 +85,7 @@ def parallelize_qwen3(
        apply_non_moe_tp(
            model,
-            world_mesh["tp"],
+            parallel_dims.get_mesh("tp"),
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
@@ -95,16 +94,10 @@ def parallelize_qwen3(
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
-            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
-            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
-            ep_tp_mesh=(
-                world_mesh["ep", "tp"]
-                if parallel_dims.tp_enabled
-                and parallel_dims.ep_enabled
-                and parallel_dims.etp_enabled
-                else None
-            ),
-            etp_enabled=parallel_dims.etp_enabled,
+            tp_mesh=parallel_dims.get_mesh("tp"),
+            ep_mesh=parallel_dims.get_mesh("ep"),
+            etp_mesh=parallel_dims.get_mesh("etp"),
+            ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]),
        )
    if job_config.activation_checkpoint.mode != "none":
@@ -123,18 +116,18 @@ def parallelize_qwen3(
    if parallel_dims.fsdp_enabled:
        # apply FSDP or HSDP, potentially with Context Parallel
-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
-        else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
-        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+        names = (
+            ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+        )
+        dp_mesh = parallel_dims.get_mesh(names)
        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        dp_mod_ep_mesh_dim_names = []
-        if parallel_dims.ep_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                dp_mod_ep_mesh_dim_names.append("dp_replicate")
-            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+        names = (
+            ["dp_replicate", "efsdp"]
+            if parallel_dims.dp_replicate_enabled
+            else ["efsdp"]
+        )
+        edp_mesh = parallel_dims.get_mesh(names)
        apply_fsdp(
            model,
@@ -145,11 +138,7 @@ def parallelize_qwen3(
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
            ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=(
-                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if parallel_dims.ep_enabled
-                else None
-            ),
+            dp_mod_ep_mesh=edp_mesh,
            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
        )
@@ -164,11 +153,12 @@ def parallelize_qwen3(
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
+        dp_mesh = parallel_dims.get_mesh("dp_replicate")
+        if dp_mesh is not None and dp_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
-            world_mesh,
+            dp_mesh,
            enable_compile=model_compile_enabled,
        )
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 4d3ed12e8e..4f9c5c970c 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -11,7 +11,6 @@
 from typing import Any, Generator, Iterable, Optional
 import torch
-
 from torch.distributed.elastic.multiprocessing.errors import record
 import torchtitan.protocols.train_spec as train_spec_module
@@ -99,15 +98,14 @@ def __init__(self, job_config: JobConfig):
            parallelism_config, world_size
        )
-        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
-            dp_mesh = world_mesh["dp"]
-            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
+            batch_mesh = parallel_dims.get_mesh("batch")
+            batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank()
        else:
-            dp_degree, dp_rank = 1, 0
+            batch_degree, batch_rank = 1, 0
        self.ft_manager = FTManager(job_config.fault_tolerance)
-        dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
+        batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank)
        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
@@ -117,7 +115,7 @@ def __init__(self, job_config: JobConfig):
        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
-            world_mesh,
+            parallel_dims,
            self.device,
            job_config.debug,
            distinct_seed_mesh_dims=["pp"],
@@ -132,8 +130,8 @@ def __init__(self, job_config: JobConfig):
        )
        self.dataloader = self.train_spec.build_dataloader_fn(
-            dp_world_size=dp_degree,
-            dp_rank=dp_rank,
+            dp_world_size=batch_degree,
+            dp_rank=batch_rank,
            tokenizer=self.tokenizer,
            job_config=job_config,
        )
@@ -199,19 +197,20 @@ def __init__(self, job_config: JobConfig):
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
-            global_batch_size = job_config.training.local_batch_size * dp_degree
+            global_batch_size = job_config.training.local_batch_size * batch_degree
        assert global_batch_size > 0
        assert (
-            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
+            global_batch_size % (job_config.training.local_batch_size * batch_degree)
+            == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
-            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
+            f"% ({job_config.training.local_batch_size} * {batch_degree}) != 0)"
        )
        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
-            job_config.training.local_batch_size * dp_degree
+            job_config.training.local_batch_size * batch_degree
        )
        assert self.gradient_accumulation_steps > 0
        self.loss_fn = rescale_accumulated_loss(
@@ -344,8 +343,8 @@ def __init__(self, job_config: JobConfig):
        self.validator = self.train_spec.build_validator_fn(
            job_config=job_config,
-            dp_world_size=dp_degree,
-            dp_rank=dp_rank,
+            dp_world_size=batch_degree,
+            dp_rank=batch_rank,
            tokenizer=self.tokenizer,
            parallel_dims=parallel_dims,
            loss_fn=self.loss_fn,
@@ -434,7 +433,7 @@ def forward_backward_step(
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
+                cp_mesh=parallel_dims.get_mesh("cp"),
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
@@ -513,9 +512,7 @@ def train_step(
            [p for m in self.model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
-            pp_mesh=(
-                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
-            ),
+            pp_mesh=parallel_dims.get_mesh("pp"),
            ep_enabled=parallel_dims.ep_enabled,
        )
        self.checkpointer.maybe_wait_for_staging()
@@ -532,14 +529,15 @@ def train_step(
        if parallel_dims.dp_cp_enabled:
            loss = loss.detach()
            ft_pg = self.ft_manager.loss_sync_pg
+            loss_mesh = parallel_dims.get_mesh("loss")
            global_avg_loss, global_max_loss, global_ntokens_seen = (
-                dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_mean(loss, loss_mesh, ft_pg),
+                dist_utils.dist_max(loss, loss_mesh, ft_pg),
                dist_utils.dist_sum(
                    torch.tensor(
                        self.ntokens_seen, dtype=torch.int64, device=self.device
                    ),
-                    parallel_dims.world_mesh["dp_cp"],
+                    loss_mesh,
                    ft_pg,
                ),
            )
@@ -636,7 +634,7 @@ def train(self):
                timeout=timedelta(
                    seconds=job_config.comm.train_timeout_seconds
                ),
-                world_mesh=self.parallel_dims.world_mesh,
+                parallel_dims=self.parallel_dims,
            )
        if torch.distributed.get_rank() == 0:
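Editor's note: as a sanity check on the batch_degree arithmetic in the Trainer hunk above, here is a small standalone restatement of the gradient-accumulation bookkeeping. The helper name is made up for illustration; this is not the Trainer code itself.

# Hypothetical helper restating the accumulation-step math from Trainer.__init__ above.
def grad_accum_steps(global_batch_size: int, local_batch_size: int, batch_degree: int) -> int:
    if global_batch_size < 0:
        # default: exactly one gradient accumulation step
        global_batch_size = local_batch_size * batch_degree
    assert global_batch_size % (local_batch_size * batch_degree) == 0, (
        f"global batch size must be multiple of local batch size times "
        f"data-parallel degree ({global_batch_size} % "
        f"({local_batch_size} * {batch_degree}) != 0)"
    )
    return global_batch_size // (local_batch_size * batch_degree)


assert grad_accum_steps(-1, 8, 4) == 1   # defaults to a single step
assert grad_accum_steps(64, 8, 4) == 2   # 64 // (8 * 4)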