Commit 1e80bbb (1 parent: 1e4d28e)

feat: implemented stage FQN generation for pipeline parallelism

File tree

2 files changed: +88 −0 lines

src/modalities/models/parallelism/__init__.py

Whitespace-only changes.
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# Some portions of this implementation are inspired and/or adapted
# from Meta's open-source project TorchTitan,
# licensed under the BSD 3-Clause License.

import math
from abc import ABC, abstractmethod

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class

from modalities.running_env.fsdp.device_mesh import ParallelismDegrees


class FQNsPerStageGenerator(ABC):
    @abstractmethod
    def generate_fqns_per_stage(
        self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1
    ) -> list[list[str]]:
        """
        Generate a list of fully qualified names (FQNs) for each pipeline stage.

        Args:
            num_stages (int): Number of stages in the pipeline.
            num_layers (int): Total number of transformer layers in the model.
            input_layer_equivalence (int): How many transformer layers the input layer
                counts as when balancing stages. Default is 1.
            output_layer_equivalence (int): How many transformer layers the output layer
                counts as when balancing stages. Default is 1.

        Returns:
            list[list[str]]: A list containing an FQN list for each stage.
        """
        raise NotImplementedError("This method must be implemented by subclasses.")

class PipelineFactory:
    """Pipeline factory class to create pipelined models."""

    @staticmethod
    def create_pipeline_model(
        num_layers: int,
        fqns_per_stage_generator: FQNsPerStageGenerator,
        device_mesh: DeviceMesh,
        pp_schedule_name: str,
        num_layers_per_stage: int,
        input_layer_equivalence: int = 1,
        output_layer_equivalence: int = 1,
    ) -> list[list[str]]:
        """Validate the pipeline configuration and generate the FQNs for each virtual stage."""
        # the pipeline-parallel degree, i.e., the number of PP ranks in the device mesh
        pp_dims = device_mesh.size(ParallelismDegrees.PP.value)
        schedule_class = get_schedule_class(pp_schedule_name)
        is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
        if not is_single_stage_schedule:
            raise ValueError(
                f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules."
            )

        # calculate the number of virtual stages; the input and output layers are counted as
        # `input_layer_equivalence` / `output_layer_equivalence` transformer layers, respectively.
        # Example: num_layers=14, input_layer_equivalence=1, output_layer_equivalence=1,
        # num_layers_per_stage=4 -> ceil(16 / 4) = 4 virtual stages.
        num_virtual_stages = math.ceil(
            (num_layers + input_layer_equivalence + output_layer_equivalence) / num_layers_per_stage
        )
        if num_virtual_stages % pp_dims != 0:
            raise ValueError(
                f"Number of virtual stages {num_virtual_stages} is not divisible by the "
                f"pipeline-parallel degree {pp_dims}. For reference: {num_layers=} {input_layer_equivalence=} "
                f"{output_layer_equivalence=} {num_layers_per_stage=}"
            )

        # single-stage schedules place exactly one stage on each PP rank
        stages_per_rank = num_virtual_stages // pp_dims
        if stages_per_rank != 1:
            raise ValueError(
                f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. "
                f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage."
            )

        fqns_per_stage = fqns_per_stage_generator.generate_fqns_per_stage(
            num_stages=num_virtual_stages,
            num_layers=num_layers,
            input_layer_equivalence=input_layer_equivalence,
            output_layer_equivalence=output_layer_equivalence,
        )
        # model splitting itself is not implemented yet; return the per-stage FQN partitioning
        return fqns_per_stage

    @staticmethod
    def create_gpt2_model_splitter():
        """Create a GPT-2 model splitter for pipeline parallelism."""
        pass
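
For illustration, here is a minimal sketch of what a concrete FQNsPerStageGenerator could look like. It is not part of this commit: the module FQNs (transformer.wte, transformer.h.{i}, lm_head) and the greedy weighting scheme are assumptions for a GPT-2-style module tree, and the actual generator in modalities may partition differently.

# Hypothetical sketch (not in this commit): greedy FQN partitioning for an
# assumed GPT-2-style module tree with FQNs transformer.wte, transformer.h.{i}, lm_head.
class GPT2FQNsPerStageGenerator(FQNsPerStageGenerator):
    def generate_fqns_per_stage(
        self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1
    ) -> list[list[str]]:
        # weight each unit: the embedding counts as `input_layer_equivalence`
        # transformer layers, the LM head as `output_layer_equivalence`
        weighted_units: list[tuple[str, int]] = (
            [("transformer.wte", input_layer_equivalence)]
            + [(f"transformer.h.{i}", 1) for i in range(num_layers)]
            + [("lm_head", output_layer_equivalence)]
        )
        total_weight = sum(weight for _, weight in weighted_units)
        weight_per_stage = math.ceil(total_weight / num_stages)

        fqns_per_stage: list[list[str]] = [[] for _ in range(num_stages)]
        stage, used = 0, 0
        for fqn, weight in weighted_units:
            # advance to the next stage once the current one is full,
            # but never beyond the last stage
            if used + weight > weight_per_stage and stage < num_stages - 1:
                stage, used = stage + 1, 0
            fqns_per_stage[stage].append(fqn)
            used += weight
        return fqns_per_stage

A quick check of the partitioning (4 transformer layers over 2 stages):

generator = GPT2FQNsPerStageGenerator()
print(generator.generate_fqns_per_stage(num_stages=2, num_layers=4))
# [['transformer.wte', 'transformer.h.0', 'transformer.h.1'],
#  ['transformer.h.2', 'transformer.h.3', 'lm_head']]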
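
And a hedged sketch of how the factory might be wired up. The mesh construction is an assumption (in modalities the mesh would come from the running environment), distributed initialization via torchrun is elided, and "1F1B" is one of the single-stage schedules resolved by get_schedule_class.

from torch.distributed.device_mesh import init_device_mesh

# assumed 4-way pipeline mesh, named after the PP parallelism degree
device_mesh = init_device_mesh("cuda", mesh_shape=(4,), mesh_dim_names=(ParallelismDegrees.PP.value,))

fqns_per_stage = PipelineFactory.create_pipeline_model(
    num_layers=14,
    fqns_per_stage_generator=GPT2FQNsPerStageGenerator(),
    device_mesh=device_mesh,
    pp_schedule_name="1F1B",   # resolves to a PipelineScheduleSingle subclass
    num_layers_per_stage=4,    # ceil((14 + 1 + 1) / 4) = 4 virtual stages, one per PP rank
)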
