diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 283c7ed1..05b47cc1 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -23,7 +23,7 @@ jobs: matrix: include: - name: "cpu-2.7.0" - container: mosaicml/pytorch:2.7.0_cpu-python3.12-ubuntu22.04 + container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cpu-python3.12-ubuntu22.04) markers: "not gpu and not only_release" pip_deps: "[cpu]" pytest_command: "coverage run -m pytest" diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 0cfc1784..efc5d4c4 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -23,7 +23,7 @@ jobs: matrix: include: - name: "gpu-2.7.0-1" - container: mosaicml/llm-foundry:2.7.0_cu128-latest + container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cu128-latest) markers: "gpu" pip_deps: "[gpu]" pytest_command: "coverage run -m pytest" @@ -52,7 +52,7 @@ jobs: matrix: include: - name: "gpu-2.7.0-2" - container: mosaicml/llm-foundry:2.7.0_cu128-latest + container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cu128-latest) markers: "gpu" pip_deps: "[gpu]" pytest_command: "coverage run -m pytest" diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 62c86d41..6f0e91c6 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -28,6 +28,7 @@ import torch import torch.distributed import torch.nn as nn +from composer.distributed.shared_utils import get_summon_params_fn from composer.utils import dist from ray.exceptions import GetTimeoutError from ray.util.placement_group import placement_group @@ -42,7 +43,9 @@ default_pg_timeout, rendezvous, ) +from torch.distributed.fsdp import FSDPModule from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor import DTensor from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor from compose_rl.algorithms.online.model_methods import ( @@ -277,33 +280,6 @@ def create_vllm_engines( return vllm_engines -def build_param_fullnames(top_module: nn.Module) -> dict: - """Builds a mapping of parameter objects to their fully-qualified names. - - Traverses the entire model from the top level and map each parameter - object to its fully-qualified name (e.g., - "lm_backbone.layer1.mlp.down_proj.weight"). - - Args: - top_module (nn.Module): The top-level module to traverse. - """ - param2fullname = {} - - def _dfs(current_module: nn.Module, prefix: str = ''): - # Get local parameters (without recursing into children). - for local_name, param in current_module.named_parameters(recurse=False): - full_name = f'{prefix}.{local_name}' if prefix else local_name - param2fullname[param] = full_name - - # Recurse on child modules. - for child_name, child_module in current_module.named_children(): - child_prefix = f'{prefix}.{child_name}' if prefix else child_name - _dfs(child_module, prefix=child_prefix) - - _dfs(top_module) - return param2fullname - - def simplify_param_path(path: str) -> str: """Simplifies the parameter path by removing unnecessary parts. 
@@ -333,15 +309,15 @@ def simplify_param_path(path: str) -> str: def is_fsdp_leaf(module: nn.Module) -> bool: - """Check if the module is a leaf in the FSDP hierarchy. + """Check if the module is a leaf in the FSDP(1/2) hierarchy. Args: module (nn.Module): The torch module to check """ - if not isinstance(module, FSDP): + if not isinstance(module, (FSDP, FSDPModule)): return False for subm in module.modules(): - if subm is not module and isinstance(subm, FSDP): + if subm is not module and isinstance(subm, (FSDP, FSDPModule)): return False return True @@ -377,6 +353,19 @@ def should_update_torch_module( return False +def get_name_for_param(model: nn.Module, param: torch.Tensor) -> str: + """Get the full name of a parameter in the model. + + Args: + model (nn.Module): The model that contains the parameter + param (torch.Tensor): The parameter to get the name for + """ + for name, p in model.named_parameters(): + if p is param: + return name + raise ValueError(f'Parameter {param} not found in model {model}') + + def broadcast_to_vllm( model: nn.Module, vllm_engines: list, @@ -399,19 +388,21 @@ def broadcast_to_vllm( torch.cuda.empty_cache() if loss_type == OnPolicyEnum.PPO: # Extract the lm_backbone params from the model - count, num_params = 0, len( + num_params = len( list(model.model.lm_backbone.named_parameters()), # type: ignore ) elif loss_type in ALGORITHM_TYPE.CRITIC_FREE: # Directly use the model params - count, num_params = 0, len( + num_params = len( list(model.model.named_parameters()), # type: ignore ) else: raise ValueError( f'Unsupported loss type: {loss_type}. Supported types are: ppo, grpo', ) + count = 0 + # Reset prefix caching if enabled refss = [] cache_reset_refss = [] if enable_prefix_caching and dist.get_global_rank() == 0: @@ -430,8 +421,6 @@ def broadcast_to_vllm( ] seen_fsdp_modules = set() seen_updated_parsed_names = set() - count = 0 - param_2_full_name = build_param_fullnames(model) with torch.no_grad(): # Adding a dummy forwards call. @@ -454,67 +443,96 @@ def broadcast_to_vllm( start_time = time.time() update_time = 0 + # Getting the correct summon_full_params function based on whether + # the model is FSDP1 vs FSDP2. + summon_full_params = get_summon_params_fn(model) + for module_name, module in model.named_modules(): - if isinstance(module, FSDP): - # This is needed otherwise FSDP will materialize parameters of size 0. - # So just for the joint actor critic models we have to actually skip this module. - if module_name == 'model' and loss_type == OnPolicyEnum.PPO: - continue - - # Only update if we haven't updated this module before - if module not in seen_fsdp_modules: - seen_fsdp_modules.add(module) - - # Materializes parameters for this specific FSDP module - with FSDP.summon_full_params( + # Skip non-FSDP modules + if not isinstance(module, (FSDP, FSDPModule)): + continue + + # This is needed otherwise FSDP will materialize parameters of size 0. + # So just for the joint actor critic models we have to actually skip this module. + if module_name == 'model' and loss_type == OnPolicyEnum.PPO: + continue + + # Only update if we haven't updated this module before + if module in seen_fsdp_modules: + continue + seen_fsdp_modules.add(module) + + # Materializes parameters for this specific FSDP module only BUT THIS + # INCLUDES any parameters from submodules that are not FSDP-wrapped themselves. + # We don't want to materialize the entire model to avoid potential OOM. 
+ # View NestedFSDPModel in the Composer repo and the related test in + # for an example of why this logic is needed. + with summon_full_params( + module, + writeback=False, + rank0_only=False, + recurse=False, + ): + # Note: For the following module.named_parameters(), we have to use recurse=True + # since the following case is possible where we still need NonFSDP_Child's params + # FSDP_Module + # |- direct_param (found with recurse=False) + # |- NonFSDP_Child + # | |- child_param (missed with recurse=False) + for _, param in module.named_parameters(recurse=True): + # Only distribute on rank 0 + if not dist.get_global_rank() == 0: + continue + + # Skip DTensor params at this level since they were not summoned + # and we only want to broadcast the summoned parameters. + # Encountering this conditional implies that a FSDP-wrapped submodule + # exists and will later be summoned to materialize this parameter. + if isinstance(param, DTensor): + continue + + full_name = get_name_for_param(model, param) + parsed_name = simplify_param_path(full_name) + + if parsed_name in seen_updated_parsed_names: + continue + + if 'critic_head' in parsed_name: + log.info('Critic head found, skipping sending') + continue + + update = should_update_torch_module( + parsed_name, + full_name, module, - writeback=False, - rank0_only=True, - recurse=False, - ): - for _, param in module.named_parameters(recurse=True): - if dist.get_global_rank() == 0: - full_name = param_2_full_name[param] - parsed_name = simplify_param_path(full_name) - - if 'critic_head' in parsed_name: - log.info('Critic head found, skipping sending') - continue - - update = should_update_torch_module( - parsed_name, - full_name, - module, - loss_type, - valid_non_leaf_module_names, - ) - - # We've already updated this module before, - if parsed_name in seen_updated_parsed_names: - continue - - # Usually if we have to skip a module, it's because we cannot - if update: - start_update_time = time.time() - seen_updated_parsed_names.add(parsed_name) - - count += 1 - shape = param.shape - refs = [ - engine.update_weight.remote( - parsed_name, - dtype=param.dtype, - shape=shape, - empty_cache=(count == num_params), - ) for engine in vllm_engines - ] - refss.extend(refs) - torch.distributed.broadcast( - param.data, - 0, - group=model_update_group, - ) - update_time += time.time() - start_update_time + loss_type, + valid_non_leaf_module_names, + ) + + if not update: + continue + + start_update_time = time.time() + seen_updated_parsed_names.add(parsed_name) + + count += 1 + shape = param.shape + refs = [ + engine.update_weight.remote( + parsed_name, + dtype=param.dtype, + shape=shape, + empty_cache=(count == num_params), + ) for engine in vllm_engines + ] + refss.extend(refs) + + torch.distributed.broadcast( + param.data, + 0, + group=model_update_group, + ) + update_time += time.time() - start_update_time # Issue (#67): Note this code will likely need to be updated for PEFT for efficiency reasons. 
if dist.get_global_rank() == 0: diff --git a/compose_rl/algorithms/online/hf_utils.py b/compose_rl/algorithms/online/hf_utils.py index 4ef8a407..db114db3 100644 --- a/compose_rl/algorithms/online/hf_utils.py +++ b/compose_rl/algorithms/online/hf_utils.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +from composer.distributed.shared_utils import get_summon_params_fn from composer.utils import is_model_fsdp from transformers import ( AutoConfig, @@ -93,14 +94,13 @@ def generate( **kwargs: Any, ): if is_model_fsdp(self.lm_backbone): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - # Note: We need to use the FSDP.summon_full_params context manager here because the generate function + # Note: We need to use the summon_full_params context manager here because the generate function # does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head # are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069 # for more info. # Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model. - with FSDP.summon_full_params( + summon_full_params = get_summon_params_fn(self.lm_backbone) + with summon_full_params( self.lm_backbone, writeback=False, recurse=False, diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index efe5eeda..1ed00e6a 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -8,6 +8,7 @@ from typing import Any, MutableMapping, Optional, Union import torch +from composer.distributed.shared_utils import get_summon_params_fn from composer.models import HuggingFaceModel from composer.utils import dist, is_model_fsdp from llmfoundry.models import ComposerHFCausalLM @@ -186,14 +187,13 @@ def generate(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any): # Note: it seems as if we need to summon FSDP parameters here to ensure that we don't break # the standard actor critic forward pass. if is_model_fsdp(self.model): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - # Note: We need to use the FSDP.summon_full_params context manager here because the generate function + # Note: We need to use the summon_full_params context manager here because the generate function # does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head # are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069 # for more info. # Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model. 
- with FSDP.summon_full_params( + summon_full_params = get_summon_params_fn(self.model) + with summon_full_params( self.model, writeback=False, recurse=False, diff --git a/compose_rl/algorithms/reward_modeling/model.py b/compose_rl/algorithms/reward_modeling/model.py index 2d3637ee..0986a533 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -9,6 +9,7 @@ from typing import Any, Mapping, MutableMapping, Optional, Union import torch +from composer.distributed.shared_utils import get_summon_params_fn from composer.utils import is_model_fsdp from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM @@ -241,7 +242,6 @@ def loss(self, outputs: SequenceClassifierOutput, class ComposerHFCausalClassifierRewardModel(ComposerHFCausalLM, RewardModel): - default_train_metrics: tuple = () default_eval_metrics: tuple = () @@ -292,9 +292,9 @@ def mask_last_embed_except_eos( context_manager = nullcontext if is_model_fsdp(self.model): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + summon_full_params = get_summon_params_fn(self.model) context_manager = partial( - FSDP.summon_full_params, + summon_full_params, self.model, writeback=True, recurse=False, diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index b51b7fc2..863745b5 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -11,7 +11,9 @@ import spacy_alignments as tokenizations import torch import torch.nn.functional as F +from composer.utils import dist from kubernetes import client, config +from torch.distributed.fsdp import FSDPModule from torch.utils.data import DataLoader from transformers import PretrainedConfig @@ -998,14 +1000,18 @@ def flip_pad_token_usage_for_generate(model: torch.nn.Module): assert len(model.transformer.blocks) > 0 # type: ignore block = model.transformer.blocks[0] # type: ignore # Logic takes care of the activation checkpointing case w/ FSDP - if hasattr( - block._fsdp_wrapped_module, # type: ignore - '_checkpoint_wrapped_module', - ): - needs_flipping = not block._fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn # type: ignore + if hasattr(block, '_fsdp_wrapped_module'): + fsdp_wrapped_module = block._fsdp_wrapped_module # type: ignore + elif isinstance(block, FSDPModule): + fsdp_wrapped_module = block + else: + return needs_flipping + + if hasattr(fsdp_wrapped_module, '_checkpoint_wrapped_module'): + needs_flipping = not fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn # type: ignore else: # Otherwise we avoid the activation checkpointing and toggle the flag here - needs_flipping = not block._fsdp_wrapped_module.use_pad_tok_in_ffn # type: ignore + needs_flipping = not fsdp_wrapped_module.use_pad_tok_in_ffn # type: ignore if needs_flipping: flip_pad_token_usage_in_ffn(model) @@ -1024,11 +1030,17 @@ def flip_pad_token_usage_in_ffn(model: torch.nn.Module): for block in model.transformer.blocks: # type: ignore # Logic takes care of the activation checkpointing case w/ FSDP - if hasattr(block._fsdp_wrapped_module, '_checkpoint_wrapped_module'): - block._fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn = not block._fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn + if hasattr(block, '_fsdp_wrapped_module'): + fsdp_wrapped_module = block._fsdp_wrapped_module + elif isinstance(block, FSDPModule): + fsdp_wrapped_module = block + else: + continue + if hasattr(fsdp_wrapped_module, '_checkpoint_wrapped_module'): + 
fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn = not fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn # type: ignore else: # Otherwise we avoid the activation checkpointing and toggle the flag here - block._fsdp_wrapped_module.use_pad_tok_in_ffn = not block._fsdp_wrapped_module.use_pad_tok_in_ffn + fsdp_wrapped_module.use_pad_tok_in_ffn = not fsdp_wrapped_module.use_pad_tok_in_ffn # type: ignore def get_remote_name(pod_name: str):
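
Note on the helper these changes lean on: get_summon_params_fn (imported from composer.distributed.shared_utils) lets the same call sites work whether the model is wrapped with FSDP1 (FullyShardedDataParallel) or FSDP2 (FSDPModule). The snippet below is a minimal, hypothetical sketch of that dispatch, not Composer's actual implementation: the names summon_params_fn_sketch and fsdp2_summon are invented for illustration, and the assumption that FSDP2 materialization goes through FSDPModule.unshard()/reshard() (with writeback and rank0_only simply ignored) is an assumption of this sketch, not something stated in the diff.

# Hypothetical sketch only -- the real helper is
# composer.distributed.shared_utils.get_summon_params_fn.
import contextlib

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def summon_params_fn_sketch(model: nn.Module):
    """Pick a summon-full-params style context manager for FSDP1 vs FSDP2."""
    if any(isinstance(m, FSDPModule) for m in model.modules()):

        @contextlib.contextmanager
        def fsdp2_summon(
            module: nn.Module,
            writeback: bool = False,
            recurse: bool = False,
            **kwargs,  # absorbs FSDP1-only args such as rank0_only
        ):
            # Assumed FSDP2 path: unshard() materializes this module's own
            # parameters; nested FSDPModule children keep their DTensor
            # parameters until they are unsharded themselves. writeback is
            # ignored in this sketch.
            assert isinstance(module, FSDPModule)
            module.unshard()
            try:
                yield
            finally:
                module.reshard()

        return fsdp2_summon

    # FSDP1 already ships a summon context manager with this signature.
    return FSDP.summon_full_params

Call sites would use the returned context manager exactly as in the diff, e.g. summon_full_params = summon_params_fn_sketch(model) followed by with summon_full_params(module, writeback=False, recurse=False): ... Under the sketch's assumption, after a given FSDP2 module is summoned, any parameter that is still a DTensor belongs to a nested FSDP-wrapped submodule that has not been summoned yet, which is the situation the isinstance(param, DTensor) skip in broadcast_to_vllm handles.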