# Some portions of this implementation are inspired and/or adapted
# from Meta's open-source project TorchTitan,
# licensed under the BSD 3-Clause License.

import math
from abc import ABC, abstractmethod

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class

from modalities.running_env.fsdp.device_mesh import ParallelismDegrees


class FQNsPerStageGenerator(ABC):
    @abstractmethod
    def generate_fqns_per_stage(
        self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1
    ) -> list[list[str]]:
        """
        Generate a list of fully qualified names (FQNs) for each pipeline stage.

        Args:
            num_stages (int): Number of stages in the pipeline.
            num_layers (int): Total number of layers in the model.
            input_layer_equivalence (int): Number of transformer layers the input layer
                is counted as when balancing stages. Defaults to 1.
            output_layer_equivalence (int): Number of transformer layers the output layer
                is counted as when balancing stages. Defaults to 1.

        Returns:
            list[list[str]]: A list containing an FQN list for each stage.
        """
        raise NotImplementedError("This method must be implemented by subclasses.")


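# Illustrative sketch (not part of the module's API above): a minimal concrete
# generator, assuming module FQNs of the form "embedding", "layers.<i>", and "lm_head"
# (hypothetical names, not prescribed by this module). It fills stages greedily so
# that each stage carries roughly the same weighted number of layers, with the input
# and output layers weighted by their equivalence factors.
class ExampleBalancedFQNsPerStageGenerator(FQNsPerStageGenerator):
    def generate_fqns_per_stage(
        self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1
    ) -> list[list[str]]:
        # Represent each module as a (fqn, weight) unit, where the input and output
        # layers count as several transformer layers.
        units = (
            [("embedding", input_layer_equivalence)]
            + [(f"layers.{i}", 1) for i in range(num_layers)]
            + [("lm_head", output_layer_equivalence)]
        )
        total_weight = sum(weight for _, weight in units)
        weight_per_stage = math.ceil(total_weight / num_stages)
        fqns_per_stage: list[list[str]] = [[] for _ in range(num_stages)]
        stage_idx, remaining_budget = 0, weight_per_stage
        for fqn, weight in units:
            # Advance to the next stage once the current one is full; the last stage
            # absorbs any remainder.
            if remaining_budget <= 0 and stage_idx < num_stages - 1:
                stage_idx, remaining_budget = stage_idx + 1, weight_per_stage
            fqns_per_stage[stage_idx].append(fqn)
            remaining_budget -= weight
        return fqns_per_stage

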
class PipelineFactory:
    """Pipeline factory class to create pipeline-parallel models."""

    @staticmethod
    def create_pipeline_model(
        num_layers: int,
        fqns_per_stage_generator: FQNsPerStageGenerator,
        device_mesh: DeviceMesh,
        pp_schedule_name: str,
        num_layers_per_stage: int,
        input_layer_equivalence: int = 1,
        output_layer_equivalence: int = 1,
    ) -> torch.nn.Module:
        pp_degree = device_mesh.size(ParallelismDegrees.PP.value)
        schedule_class = get_schedule_class(pp_schedule_name)
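        # Single-stage schedules (e.g., GPipe or 1F1B) place exactly one pipeline stage
        # on each rank; interleaved schedules assign multiple virtual stages per rank.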
        is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
        if not is_single_stage_schedule:
            raise ValueError(
                f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules."
            )

        # Calculate the number of virtual pipeline stages, counting the input and output
        # layers as additional transformer layers according to their equivalence factors.
        num_virtual_stages = math.ceil(
            (num_layers + input_layer_equivalence + output_layer_equivalence) / num_layers_per_stage
        )
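        # Worked example (illustrative numbers): num_layers=12 with
        # input_layer_equivalence=1, output_layer_equivalence=1, and num_layers_per_stage=7
        # yields ceil((12 + 1 + 1) / 7) = 2 virtual stages.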
        if num_virtual_stages % pp_degree != 0:
            raise ValueError(
                f"Number of virtual stages {num_virtual_stages} is not divisible by "
                f"the pipeline parallel degree {pp_degree}. "
                f"For reference: {num_layers=} {input_layer_equivalence=} "
                f"{output_layer_equivalence=} {num_layers_per_stage=}"
            )

        stages_per_rank = num_virtual_stages // pp_degree
        if stages_per_rank != 1:
            raise ValueError(
                f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. "
                f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage."
            )

        fqns_per_stage = fqns_per_stage_generator.generate_fqns_per_stage(
            num_stages=num_virtual_stages,
            num_layers=num_layers,
            input_layer_equivalence=input_layer_equivalence,
            output_layer_equivalence=output_layer_equivalence,
        )
        # TODO: split the model into pipeline stages based on fqns_per_stage and wrap
        # them in the selected schedule; stage construction is not implemented yet.

    @staticmethod
    def create_gpt2_model_splitter():
        """Create a GPT-2 model splitter for pipeline parallelism."""
        pass
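
# Usage sketch (illustrative only, not executable as-is since stage construction is
# still a TODO): assumes a 2-way pipeline-parallel device mesh; "cuda", "GPipe", and
# the numeric values are placeholder assumptions.
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     device_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=(ParallelismDegrees.PP.value,))
#     pipeline_model = PipelineFactory.create_pipeline_model(
#         num_layers=12,
#         fqns_per_stage_generator=ExampleBalancedFQNsPerStageGenerator(),
#         device_mesh=device_mesh,
#         pp_schedule_name="GPipe",
#         num_layers_per_stage=7,  # ceil((12 + 1 + 1) / 7) = 2 virtual stages, one per PP rank
#     )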