From 86695d1b7cf4a65cc8eb89265b9230181de48076 Mon Sep 17 00:00:00 2001 From: Rithwik Ediga Lakhamsani Date: Fri, 11 Jul 2025 15:59:40 -0700 Subject: [PATCH 01/15] supporting fsdp2 formatting plz format working on testing added new test that fails using dtensor APIs added some more tests formatted finally works smh formatted plz work im begging undid changes to reward modeling since that's breaking on main why doesn't the formatter work smh testing out specific FSDP2 change added another test added some logging wip a comment --- .../online/generation_utils/vllm_utils.py | 177 ++++++++----- compose_rl/algorithms/online/hf_utils.py | 7 +- compose_rl/algorithms/online/model.py | 9 +- .../algorithms/reward_modeling/model.py | 5 +- compose_rl/utils/utils.py | 191 +++++++++++++- tests/common/models.py | 123 +++++++++ tests/test_offline.py | 4 + tests/test_utils.py | 249 ++++++++++++++++++ 8 files changed, 675 insertions(+), 90 deletions(-) create mode 100644 tests/common/models.py diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 62c86d41..7fdcbcc4 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -42,13 +42,18 @@ default_pg_timeout, rendezvous, ) +from torch.distributed.fsdp import FSDPModule from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +# DTensor debugging imports +from torch.distributed.tensor import DTensor + from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor from compose_rl.algorithms.online.model_methods import ( ALGORITHM_TYPE, OnPolicyEnum, ) +from compose_rl.utils.utils import summon_full_params log = logging.getLogger(__name__) @@ -113,7 +118,6 @@ def init_process_group( class WorkerWrap: - def init_process_group( self, master_address: str, @@ -312,6 +316,7 @@ def simplify_param_path(path: str) -> str: """ # Parts we want to remove remove_parts = [ + '_wrapped_module', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', 'lm_backbone', @@ -333,15 +338,15 @@ def simplify_param_path(path: str) -> str: def is_fsdp_leaf(module: nn.Module) -> bool: - """Check if the module is a leaf in the FSDP hierarchy. + """Check if the module is a leaf in the FSDP(1/2) hierarchy. Args: module (nn.Module): The torch module to check """ - if not isinstance(module, FSDP): + if not isinstance(module, (FSDP, FSDPModule)): return False for subm in module.modules(): - if subm is not module and isinstance(subm, FSDP): + if subm is not module and isinstance(subm, (FSDP, FSDPModule)): return False return True @@ -377,6 +382,19 @@ def should_update_torch_module( return False +def get_path_to_param(model: nn.Module, param: torch.Tensor) -> str: + """Get the path to a parameter in the model. + + Args: + model (nn.Module): The model to get the path to + param (torch.Tensor): The parameter to get the path to + """ + for name, p in model.named_parameters(): + if p is param: + return name + raise ValueError(f'Parameter {param} not found in model {model}') + + def broadcast_to_vllm( model: nn.Module, vllm_engines: list, @@ -395,23 +413,25 @@ def broadcast_to_vllm( loss_type (str): The loss type which decides whether to use critic-free or not. Defaults to `ppo`. enable_prefix_caching (bool): Whether to enable prefix caching. Defaults to `False`. 
""" - # avoid OOM + # To avoid OOM torch.cuda.empty_cache() if loss_type == OnPolicyEnum.PPO: # Extract the lm_backbone params from the model - count, num_params = 0, len( + num_params = len( list(model.model.lm_backbone.named_parameters()), # type: ignore ) elif loss_type in ALGORITHM_TYPE.CRITIC_FREE: # Directly use the model params - count, num_params = 0, len( + num_params = len( list(model.model.named_parameters()), # type: ignore ) else: raise ValueError( f'Unsupported loss type: {loss_type}. Supported types are: ppo, grpo', ) + count = 0 + # Reset prefix caching if enabled refss = [] cache_reset_refss = [] if enable_prefix_caching and dist.get_global_rank() == 0: @@ -430,8 +450,6 @@ def broadcast_to_vllm( ] seen_fsdp_modules = set() seen_updated_parsed_names = set() - count = 0 - param_2_full_name = build_param_fullnames(model) with torch.no_grad(): # Adding a dummy forwards call. @@ -455,66 +473,89 @@ def broadcast_to_vllm( update_time = 0 for module_name, module in model.named_modules(): - if isinstance(module, FSDP): - # This is needed otherwise FSDP will materialize parameters of size 0. - # So just for the joint actor critic models we have to actually skip this module. - if module_name == 'model' and loss_type == OnPolicyEnum.PPO: - continue - - # Only update if we haven't updated this module before - if module not in seen_fsdp_modules: - seen_fsdp_modules.add(module) - - # Materializes parameters for this specific FSDP module - with FSDP.summon_full_params( + # Skip non-FSDP modules + if not isinstance(module, (FSDP, FSDPModule)): + continue + + # This is needed otherwise FSDP will materialize parameters of size 0. + # So just for the joint actor critic models we have to actually skip this module. + if module_name == 'model' and loss_type == OnPolicyEnum.PPO: + continue + + # Only update if we haven't updated this module before + if module in seen_fsdp_modules: + continue + seen_fsdp_modules.add(module) + + # Materializes parameters for this specific FSDP module specifically. + # Don't materialize the entire model to avoid potential OOM. + with summon_full_params( + module, + writeback=False, + rank0_only=True, + recurse=False, + ): + # Note: We have to recurse=True since the following case is possible: + # FSDP_Module + # |- direct_param (found with recurse=False) + # |- NonFSDP_Child + # | |- child_param (missed with recurse=False) + for _, param in module.named_parameters(recurse=True): + # Only distribute on rank 0 + if not dist.get_global_rank() == 0: + continue + + # Skip DTensor params at this level since they were not summoned + # and we only want to broadcast the summoned parameters. + # Encountering this implies nested FSDPModules and a later module, + # when summoned, will convert this DTensor to a regular tensor. + # TODO: Investigate why this isn't an issue for FSDP1. 
+ if isinstance(param, DTensor): + continue + + full_name = get_path_to_param(model, param) + parsed_name = simplify_param_path(full_name) + + # We've already updated this module before, + if parsed_name in seen_updated_parsed_names: + continue + + if 'critic_head' in parsed_name: + log.info('Critic head found, skipping sending') + continue + + update = should_update_torch_module( + parsed_name, + full_name, module, - writeback=False, - rank0_only=True, - recurse=False, - ): - for _, param in module.named_parameters(recurse=True): - if dist.get_global_rank() == 0: - full_name = param_2_full_name[param] - parsed_name = simplify_param_path(full_name) - - if 'critic_head' in parsed_name: - log.info('Critic head found, skipping sending') - continue - - update = should_update_torch_module( - parsed_name, - full_name, - module, - loss_type, - valid_non_leaf_module_names, - ) - - # We've already updated this module before, - if parsed_name in seen_updated_parsed_names: - continue - - # Usually if we have to skip a module, it's because we cannot - if update: - start_update_time = time.time() - seen_updated_parsed_names.add(parsed_name) - - count += 1 - shape = param.shape - refs = [ - engine.update_weight.remote( - parsed_name, - dtype=param.dtype, - shape=shape, - empty_cache=(count == num_params), - ) for engine in vllm_engines - ] - refss.extend(refs) - torch.distributed.broadcast( - param.data, - 0, - group=model_update_group, - ) - update_time += time.time() - start_update_time + loss_type, + valid_non_leaf_module_names, + ) + + if not update: + continue + + start_update_time = time.time() + seen_updated_parsed_names.add(parsed_name) + + count += 1 + shape = param.shape + refs = [ + engine.update_weight.remote( + parsed_name, + dtype=param.dtype, + shape=shape, + empty_cache=(count == num_params), + ) for engine in vllm_engines + ] + refss.extend(refs) + + torch.distributed.broadcast( + param.data, + 0, + group=model_update_group, + ) + update_time += time.time() - start_update_time # Issue (#67): Note this code will likely need to be updated for PEFT for efficiency reasons. if dist.get_global_rank() == 0: diff --git a/compose_rl/algorithms/online/hf_utils.py b/compose_rl/algorithms/online/hf_utils.py index 4ef8a407..2c1a46ec 100644 --- a/compose_rl/algorithms/online/hf_utils.py +++ b/compose_rl/algorithms/online/hf_utils.py @@ -23,6 +23,7 @@ ) from compose_rl.algorithms.online.policy_configuration import HFPolicyConfig from compose_rl.utils.consts import _MASTER_WEIGHTS_PRECISION +from compose_rl.utils.utils import summon_full_params Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] @@ -93,14 +94,12 @@ def generate( **kwargs: Any, ): if is_model_fsdp(self.lm_backbone): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - # Note: We need to use the FSDP.summon_full_params context manager here because the generate function + # Note: We need to use the summon_full_params context manager here because the generate function # does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head # are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069 # for more info. # Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model. 
- with FSDP.summon_full_params( + with summon_full_params( self.lm_backbone, writeback=False, recurse=False, diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index efe5eeda..bbdeebd1 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -28,6 +28,7 @@ clear_mb_load_balancing_loss, get_mb_load_balancing_loss, ) +from compose_rl.utils.utils import summon_full_params Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] @@ -35,7 +36,6 @@ class ComposerMPTPolicyLM(HuggingFaceModel): - def __init__( self, tokenizer: Tokenizer, @@ -135,7 +135,6 @@ def set_batch_stats(self, batch_stats: dict[str, Any]): class ComposerHFPolicyLM(ComposerHFPolicy): - def __init__( self, *, @@ -186,14 +185,12 @@ def generate(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any): # Note: it seems as if we need to summon FSDP parameters here to ensure that we don't break # the standard actor critic forward pass. if is_model_fsdp(self.model): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - # Note: We need to use the FSDP.summon_full_params context manager here because the generate function + # Note: We need to use the summon_full_params context manager here because the generate function # does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head # are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069 # for more info. # Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model. - with FSDP.summon_full_params( + with summon_full_params( self.model, writeback=False, recurse=False, diff --git a/compose_rl/algorithms/reward_modeling/model.py b/compose_rl/algorithms/reward_modeling/model.py index 2d3637ee..ff22edc7 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -31,6 +31,7 @@ ComposerHFSequenceClassification from compose_rl.algorithms.reward_modeling.modeling_mpt import \ MPTForSequenceClassification +from compose_rl.utils.utils import summon_full_params log = logging.getLogger(__name__) @@ -241,7 +242,6 @@ def loss(self, outputs: SequenceClassifierOutput, class ComposerHFCausalClassifierRewardModel(ComposerHFCausalLM, RewardModel): - default_train_metrics: tuple = () default_eval_metrics: tuple = () @@ -292,9 +292,8 @@ def mask_last_embed_except_eos( context_manager = nullcontext if is_model_fsdp(self.model): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP context_manager = partial( - FSDP.summon_full_params, + summon_full_params, self.model, writeback=True, recurse=False, diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index b51b7fc2..405535bd 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -5,13 +5,16 @@ import re import warnings from collections.abc import Generator, Iterable +from contextlib import contextmanager from typing import Any, Optional, Union import spacy import spacy_alignments as tokenizations import torch import torch.nn.functional as F +from composer.utils import dist from kubernetes import client, config +from torch.distributed.fsdp import FSDPModule from torch.utils.data import DataLoader from transformers import PretrainedConfig @@ -998,14 +1001,18 @@ def flip_pad_token_usage_for_generate(model: torch.nn.Module): assert len(model.transformer.blocks) > 0 # type: ignore 
block = model.transformer.blocks[0] # type: ignore # Logic takes care of the activation checkpointing case w/ FSDP - if hasattr( - block._fsdp_wrapped_module, # type: ignore - '_checkpoint_wrapped_module', - ): - needs_flipping = not block._fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn # type: ignore + if hasattr(block, '_fsdp_wrapped_module'): + fsdp_wrapped_module = block._fsdp_wrapped_module # type: ignore + elif isinstance(block, FSDPModule): + fsdp_wrapped_module = block + else: + return needs_flipping + + if hasattr(fsdp_wrapped_module, '_checkpoint_wrapped_module'): + needs_flipping = not fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn # type: ignore else: # Otherwise we avoid the activation checkpointing and toggle the flag here - needs_flipping = not block._fsdp_wrapped_module.use_pad_tok_in_ffn # type: ignore + needs_flipping = not fsdp_wrapped_module.use_pad_tok_in_ffn # type: ignore if needs_flipping: flip_pad_token_usage_in_ffn(model) @@ -1024,11 +1031,17 @@ def flip_pad_token_usage_in_ffn(model: torch.nn.Module): for block in model.transformer.blocks: # type: ignore # Logic takes care of the activation checkpointing case w/ FSDP - if hasattr(block._fsdp_wrapped_module, '_checkpoint_wrapped_module'): - block._fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn = not block._fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn + if hasattr(block, '_fsdp_wrapped_module'): + fsdp_wrapped_module = block._fsdp_wrapped_module + elif isinstance(block, FSDPModule): + fsdp_wrapped_module = block + else: + continue + if hasattr(fsdp_wrapped_module, '_checkpoint_wrapped_module'): + fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn = not fsdp_wrapped_module._checkpoint_wrapped_module.use_pad_tok_in_ffn # type: ignore else: # Otherwise we avoid the activation checkpointing and toggle the flag here - block._fsdp_wrapped_module.use_pad_tok_in_ffn = not block._fsdp_wrapped_module.use_pad_tok_in_ffn + fsdp_wrapped_module.use_pad_tok_in_ffn = not fsdp_wrapped_module.use_pad_tok_in_ffn # type: ignore def get_remote_name(pod_name: str): @@ -1237,3 +1250,163 @@ def flatten(coll: Union[Iterable[Any], str]) -> Generator[Any, None, None]: yield subc else: yield i + + +@contextmanager +def _summon_full_params_fsdp2( + model: torch.nn.Module, + writeback: bool = True, + recurse: bool = True, +): + """Context manager to get full params for FSDP2 models with DTensor APIs. + + Note: We use DTensor APIs to materialize the full parameters instead of using `unshard` + and `reshard` as writeback doesn't seem to work correctly with DTensors + that uses DTensor APIs to materialize the full parameters. We currently don't support + rank0_only + """ + from torch.distributed.tensor import DTensor, Replicate, distribute_tensor + + dtensor_params = { + name: param + for name, param in + model.named_parameters(recurse=recurse, remove_duplicate=False) + if isinstance(param, DTensor) + } + + if not dtensor_params: + yield + return + + model_dtensors = {} + metadata = {} + tied_params = {} + + # We want to get the module and attr of the param, so we can assign + # module.attr = param.full_tensor() before we yield and + # module.attr = distributed (maybe updated) tensor after we yield. 
+ def _get_module_and_attr(model: torch.nn.Module, param_name: str): + parts = param_name.split('.') + module = model + for part in parts[:-1]: + module = getattr(module, part) + return module, parts[-1] + + # Group parameters by their underlying tensor to handle tied parameters + tensor_to_names = {} + for name, dtensor_param in dtensor_params.items(): + tensor_id = id(dtensor_param) + if tensor_id not in tensor_to_names: + tensor_to_names[tensor_id] = [] + tensor_to_names[tensor_id].append(name) + + # Process parameters, handling tied parameters correctly + processed_tensors = set() + for name, dtensor_param in dtensor_params.items(): + tensor_id = id(dtensor_param) + + metadata[name] = { + 'device_mesh': dtensor_param.device_mesh, + 'placements': dtensor_param.placements, + 'requires_grad': dtensor_param.requires_grad, + } + model_dtensors[name] = dtensor_param + + # Only materialize the full tensor once per unique tensor + if tensor_id not in processed_tensors: + full_tensor = dtensor_param.full_tensor() + new_param = torch.nn.Parameter(full_tensor.detach().clone()) + + # Set the same parameter instance for all tied parameters + for tied_name in tensor_to_names[tensor_id]: + module, attr_name = _get_module_and_attr(model, tied_name) + setattr(module, attr_name, new_param) + tied_params[tied_name] = new_param + + processed_tensors.add(tensor_id) + + try: + yield + finally: + # Process tied parameters to ensure writeback works correctly + processed_tensors = set() + tensor_to_updated_dtensor = {} + + for name in dtensor_params.keys(): + module, attr_name = _get_module_and_attr(model, name) + tensor_id = id(model_dtensors[name]) + + if writeback and tensor_id not in processed_tensors: + # We update model_dtensors[name] to use the updated param + # after the model changes. For tied parameters, we only need + # to do this once per unique tensor. + current_param = getattr(module, attr_name) + if hasattr( + current_param, + 'data', + ) and current_param.data is not None: + meta = metadata[name] + replicated = distribute_tensor( + current_param.data, + meta['device_mesh'], + [Replicate()], + ) + sharded = replicated.redistribute( + meta['device_mesh'], + meta['placements'], + ) + new_param = torch.nn.Parameter(sharded) + new_param.requires_grad = meta['requires_grad'] + tensor_to_updated_dtensor[tensor_id] = new_param + processed_tensors.add(tensor_id) + + # Restore the appropriate DTensor for this parameter + if writeback and tensor_id in tensor_to_updated_dtensor: + setattr(module, attr_name, tensor_to_updated_dtensor[tensor_id]) + else: + setattr(module, attr_name, model_dtensors[name]) + + +@contextmanager +def summon_full_params( + model: torch.nn.Module, + writeback: bool = True, + recurse: bool = True, + rank0_only: bool = False, +): + """Context manager to summon full parameters for an FSDP(1/2) model. + + Args: + model (torch.nn.Module): The FSDP model to summon full parameters for. + writeback (bool): Whether to write back parameter changes. Defaults to False. + recurse (bool): Whether to recurse into submodules. Defaults to True. + rank0_only (bool): Whether to summon full parameters on only rank 0. Defaults to False. + Only supported for FSDP1. FSDP2 by default materializes all parameters on all ranks. 
+ """ + + def is_fsdp2(model: torch.nn.Module) -> bool: + try: + from torch.distributed.fsdp import FSDPModule + for module in model.modules(): + if isinstance(module, FSDPModule): + return True + except ImportError: + pass + return False + + if is_fsdp2(model): + with _summon_full_params_fsdp2( + model, + writeback=writeback, + recurse=recurse, + ): + yield + else: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + with FSDP.summon_full_params( + model, + writeback=writeback, + recurse=recurse, + rank0_only=rank0_only, + ): + yield diff --git a/tests/common/models.py b/tests/common/models.py new file mode 100644 index 00000000..57dceba5 --- /dev/null +++ b/tests/common/models.py @@ -0,0 +1,123 @@ +# Copyright 2024 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +from functools import partial + +import torch +from composer.models import ComposerClassifier + + +class SimpleMLP(torch.nn.Module): + def __init__(self, num_features: int, device: str = 'cpu'): + super().__init__() + fc1 = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) + fc2 = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) + + self.net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class SimpleWeightTiedModel(ComposerClassifier): + """Small classification model with tied weights. + + Args: + num_features (int): number of input features (default: 1) + device (str): the device to initialize the model (default: 'cpu') + """ + + def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: + self.num_features = num_features + + mlp = SimpleMLP(num_features, device) + + net = torch.nn.Sequential( + mlp, + torch.nn.Softmax(dim=-1), + ) + + super().__init__(module=net, num_classes=num_features) + + self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] + + # Adding mlp.fc1.weight = mlp.fc2.weight without assignment to self.fc1 and self.fc2 + # since we don't want to create duplicate references to the same module + # since that will break mixed init. + mlp.net[0].weight = mlp.net[-1].weight + + def add_fsdp_wrap_attribute_to_children(self): + for child in self.children(): + child._fsdp_wrap = False # type: ignore + for child in self.module.children(): + child._fsdp_wrap = True # type: ignore + + def param_init_fn(self, module: torch.nn.Module): + init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) + + if isinstance(module, torch.nn.Linear): + init_fn(module.weight) + if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] + torch.nn.init.zeros_(module.bias) + + +class PartialWeightTiedModel(ComposerClassifier): + """Small classification model with partially tied weights. 
+ + Args: + num_features (int): number of input features (default: 1) + device (str): the device to initialize the model (default: 'cpu') + """ + + def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: + mlp = SimpleMLP(num_features, device) + + # a third fc layer that is not tied to the above mlp + fc3 = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) + + net = torch.nn.Sequential( + mlp, + fc3, + torch.nn.Softmax(dim=-1), + ) + + # fc1 would be a child module of the Sequential module now but only the mlp should be FSDP wrapped + # TODO support this or add negative test for this + # net.fc1 = mlp.fc1 + + super().__init__(module=net, num_classes=num_features) + self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] + + # Adding mlp.fc1.weight = mlp.fc2.weight without assignment to self.fc1 and self.fc2 + # since we don't want to create duplicate references to the same module since that + # will break mixed init. + mlp.net[0].weight = mlp.net[-1].weight + + def add_fsdp_wrap_attribute_to_children(self): + for child in self.children(): + child._fsdp_wrap = False # type: ignore + for child in self.module.children(): + child._fsdp_wrap = True # type: ignore + + def param_init_fn(self, module: torch.nn.Module): + init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) + + if isinstance(module, torch.nn.Linear): + init_fn(module.weight) + if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] + torch.nn.init.zeros_(module.bias) diff --git a/tests/test_offline.py b/tests/test_offline.py index 2838700b..5bd52548 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -174,11 +174,14 @@ def test_model_forward(tiny_gpt2_tokenizer: PreTrainedTokenizer): @pytest.mark.gpu @world_size(2) @pytest.mark.parametrize('fsdp_config', [None, {}]) # type: ignore +@pytest.mark.parametrize('fsdp_version', [1, 2]) def test_train( tiny_gpt2_tokenizer: PreTrainedTokenizer, world_size: int, fsdp_config: dict[str, Any], + fsdp_version: int, ): + os.environ['FSDP_VERSION'] = str(fsdp_version) max_seq_len = 10 dataset = PairwisePreference(max_seq_len=max_seq_len) dataloader = DataLoader( @@ -214,6 +217,7 @@ def test_train( max_duration='1ep', ) trainer.fit() + os.environ['FSDP_VERSION'] = '1' @pytest.mark.skip( diff --git a/tests/test_utils.py b/tests/test_utils.py index c913ccbc..0cc42026 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,13 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 +import os +from typing import Optional + import pytest import torch import torch.nn.functional as F +from transformers import PreTrainedTokenizer from compose_rl.utils import mask_eos from compose_rl.utils.utils import ( @@ -12,7 +16,10 @@ get_token_entropies, masked_mean, sample_wise_masked_mean, + summon_full_params, ) +from tests.common.markers import world_size +from tests.common.models import PartialWeightTiedModel def test_mask_eos_basic_functionality(): @@ -734,3 +741,245 @@ def test_get_entropies_integration(): peaky_probs = F.softmax(peaky_logits, dim=0) expected_entropy = -torch.sum(peaky_probs * torch.log(peaky_probs)) assert torch.isclose(entropies[1], expected_entropy, atol=1e-5) + + +def _setup_fsdp_test_environment( + tiny_gpt2_tokenizer: PreTrainedTokenizer, + fsdp_version: int, + model: Optional[torch.nn.Module] = None, +): + """Helper function to set up FSDP test environment.""" + import os + from functools import partial + + from composer 
import Trainer + from composer.utils import dist + from torch.utils.data import DataLoader + + from compose_rl.algorithms.offline import ComposerMPTPairwiseOfflinePolicyLM + from compose_rl.data import pairwise_preference_dataset_collate_fn + from tests.common import PairwisePreference + + # Set FSDP version + os.environ['FSDP_VERSION'] = str(fsdp_version) + + # Create a dataset and dataloader + max_seq_len = 10 + dataset = PairwisePreference(max_seq_len=max_seq_len) + dataloader = DataLoader( + dataset, + collate_fn=partial( + pairwise_preference_dataset_collate_fn, + tiny_gpt2_tokenizer, + max_seq_len, + ), + sampler=dist.get_sampler(dataset), + batch_size=2, + ) + + # Create model config + model_config = { + 'n_layers': 1, + 'attn_config': { + 'attn_impl': 'torch', + }, + 'tokenizer': tiny_gpt2_tokenizer, + } + + # Create model + if model is None: + model = ComposerMPTPairwiseOfflinePolicyLM(**model_config) + + # Enable FSDP + fsdp_config = {} + trainer = Trainer( + model=model, # type: ignore + train_dataloader=dataloader, + parallelism_config={'fsdp': fsdp_config}, + max_duration='1ba', + ) + + return trainer, trainer.state.model + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('fsdp_version', [1, 2]) +def test_summon_full_params( + tiny_gpt2_tokenizer: PreTrainedTokenizer, + world_size: int, + fsdp_version: int, +): + """Test summon_full_params actually works with FSDP(1/2) models.""" + del world_size + trainer, fsdp_model = _setup_fsdp_test_environment( + tiny_gpt2_tokenizer, + fsdp_version, + ) + + def get_total_param_size(model: torch.nn.Module): + total_size = 0 + for param in model.parameters(): + if hasattr(param, 'to_local'): + param = param.to_local() + if param.data is not None: + total_size += param.data.numel() + return total_size + + distributed_param_size = get_total_param_size(fsdp_model) + + # Test with writeback=True + with summon_full_params(fsdp_model): + local_param_size = get_total_param_size(fsdp_model) + + assert local_param_size > distributed_param_size * 1.5, \ + f'Local param size {local_param_size} should be > 1.5x distributed param size {distributed_param_size}' + + trainer.close() + os.environ['FSDP_VERSION'] = '1' + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('fsdp_version', [1, 2]) +def test_summon_full_params_with_fsdp_writeback( + tiny_gpt2_tokenizer: PreTrainedTokenizer, + world_size: int, + fsdp_version: int, +): + """Test summon_full_params with actual FSDP models.""" + del world_size + trainer, fsdp_model = _setup_fsdp_test_environment( + tiny_gpt2_tokenizer, + fsdp_version, + ) + + original_local_tensors = { + name: param.data.clone() for name, param in fsdp_model.named_parameters() + } + + # Test out writeback=False + with summon_full_params(fsdp_model, writeback=False): + # Modify parameters inside the context + for name, param in fsdp_model.named_parameters(): + if param.data is not None: # type: ignore + param.data.fill_(777.0) + + for name, param in fsdp_model.named_parameters(): + if param.data is not None: # type: ignore + assert torch.all( + param.data == original_local_tensors[name], + ), f'Parameter {name} should not be modified with writeback=False' + + # Test with writeback=True + with summon_full_params(fsdp_model, writeback=True): + for name, param in fsdp_model.named_parameters(): + if param.data is not None: # type: ignore + param.data.fill_(888.0) + + for name, param in fsdp_model.named_parameters(): + if param.data is not None: # type: ignore + assert torch.all( + param.data == 888.0, + ), 
f'Parameter {name} should be modified with writeback=True' + + trainer.close() + os.environ['FSDP_VERSION'] = '1' + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('fsdp_version', [1, 2]) +def test_summon_full_params_recurse( + tiny_gpt2_tokenizer: PreTrainedTokenizer, + world_size: int, + fsdp_version: int, +): + """Test summon_full_params with recurse=False parameter.""" + del world_size + trainer, fsdp_model = _setup_fsdp_test_environment( + tiny_gpt2_tokenizer, + fsdp_version, + ) + + with summon_full_params(fsdp_model, recurse=False): + for name, param in fsdp_model.named_parameters(recurse=False): + assert param.data is not None # type: ignore + assert '.' not in name + + with summon_full_params(fsdp_model, recurse=True): + param_names = [ + name for name, _ in fsdp_model.named_parameters(recurse=True) + ] + assert any('.' in name for name in param_names) + + trainer.close() + os.environ['FSDP_VERSION'] = '1' + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('fsdp_version', [1, 2]) +def test_summon_full_params_tied_weights_behavior( + world_size: int, + fsdp_version: int, + tiny_gpt2_tokenizer: PreTrainedTokenizer, +): + """Test summon_full_params with tied weights behavior verification.""" + del world_size + model = PartialWeightTiedModel(num_features=2) + + trainer, fsdp_model = _setup_fsdp_test_environment( + tiny_gpt2_tokenizer, + fsdp_version, + model, + ) + + # fill the tied weights with 999.0 + fsdp_model.module[0].net[0].weight.data.fill_(999.0) # type: ignore + + # Test writeback=False + with summon_full_params(fsdp_model, writeback=False): + error_msg = 'Tied weights should be the same tensor object inside context' + first_weight = fsdp_model.module[0].net[0].weight # type: ignore + last_weight = fsdp_model.module[0].net[-1].weight # type: ignore + assert first_weight is last_weight, error_msg + + first_weight.data.fill_(777.0) + error_msg = 'Tied weights should be consistent inside context' + assert torch.all(last_weight.data == 777.0), error_msg + + first_weight_same = torch.all( + fsdp_model.module[0].net[0].weight.data == 999.0, # type: ignore + ) + last_weight_same = torch.all( + fsdp_model.module[0].net[-1].weight.data == 999.0, # type: ignore + ) + + assert first_weight_same, 'First tied weight should be the same with writeback=False' + assert last_weight_same, 'Second tied weight should be the same with writeback=False' + + # Test writeback=True + with summon_full_params(fsdp_model, writeback=True): + first_weight = fsdp_model.module[0].net[0].weight # type: ignore + last_weight = fsdp_model.module[0].net[-1].weight # type: ignore + error_msg = 'Tied weights should be the same tensor object inside context' + assert first_weight is last_weight, error_msg + + first_weight.data.fill_(888.0) + + error_msg = 'Tied weights should be consistent inside context' + assert torch.all(last_weight.data == 888.0), error_msg + + first_weight_changed = torch.all( + fsdp_model.module[0].net[0].weight.data == 888.0, # type: ignore + ) + last_weight_changed = torch.all( + fsdp_model.module[0].net[-1].weight.data == 888.0, # type: ignore + ) + + assert first_weight_changed, 'First tied weight should keep modified value with writeback=True' + assert last_weight_changed, 'Second tied weight should keep modified value with writeback=True' + + trainer.close() + os.environ['FSDP_VERSION'] = '1' From 5c6410297540b3256d9d9b61d15e420eafe94c0b Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Jul 2025 17:31:07 +0000 Subject: [PATCH 02/15] additional logging --- 
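Reviewer note (placed below the "---" cut, so ignored by `git am`; not part of the change):
the prints added in this patch trace which parameters the FSDP2 broadcast path skips
because they are still DTensors, plus how many parameters end up being updated. For
reviewers less familiar with the DTensor calls that PATCH 01's `_summon_full_params_fsdp2`
builds on, here is a minimal illustrative sketch of the materialize/writeback round trip
it performs. The function name, mesh size, and tensor shapes below are hypothetical; it
assumes a 2-GPU launch via `torchrun --nproc_per_node=2` and is only a sketch under those
assumptions, not the helper itself.

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor

    def dtensor_roundtrip_sketch():
        # A 1-D mesh over 2 ranks; FSDP2 keeps each parameter sharded like this.
        mesh = init_device_mesh('cuda', (2,))
        full = torch.randn(4, 4, device='cuda')
        sharded = distribute_tensor(full, mesh, [Shard(0)])

        # "Summon": materialize the full parameter on every rank.
        materialized = sharded.full_tensor()
        materialized.add_(1.0)  # stand-in for an in-context modification

        # "Writeback": replicate the edited full tensor, then redistribute it
        # back to the original sharded placement (this mirrors the device_mesh /
        # placements metadata saved before the yield in _summon_full_params_fsdp2).
        replicated = distribute_tensor(materialized, mesh, [Replicate()])
        return replicated.redistribute(mesh, [Shard(0)])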
.../algorithms/online/generation_utils/vllm_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 7fdcbcc4..c1334683 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -471,9 +471,11 @@ def broadcast_to_vllm( model(dummy_batch) start_time = time.time() update_time = 0 + _params_considered = 0 for module_name, module in model.named_modules(): # Skip non-FSDP modules + print(f"module_name: {module_name}") if not isinstance(module, (FSDP, FSDPModule)): continue @@ -500,7 +502,7 @@ def broadcast_to_vllm( # |- direct_param (found with recurse=False) # |- NonFSDP_Child # | |- child_param (missed with recurse=False) - for _, param in module.named_parameters(recurse=True): + for _param_name, param in module.named_parameters(recurse=True): # Only distribute on rank 0 if not dist.get_global_rank() == 0: continue @@ -511,10 +513,13 @@ def broadcast_to_vllm( # when summoned, will convert this DTensor to a regular tensor. # TODO: Investigate why this isn't an issue for FSDP1. if isinstance(param, DTensor): + print(f"DTensor found: {_param_name}, skipping.") continue full_name = get_path_to_param(model, param) parsed_name = simplify_param_path(full_name) + print(f"Valid tensor found: {parsed_name}") + _params_considered += 1 # We've already updated this module before, if parsed_name in seen_updated_parsed_names: @@ -535,6 +540,8 @@ def broadcast_to_vllm( if not update: continue + print(f"Updating: {parsed_name}") + start_update_time = time.time() seen_updated_parsed_names.add(parsed_name) @@ -557,6 +564,9 @@ def broadcast_to_vllm( ) update_time += time.time() - start_update_time + print(f"Number of parameters considered: {_params_considered}") + print(f"Number of parameters updated: {count}") + print(f"Number of parameters in the model: {num_params}") # Issue (#67): Note this code will likely need to be updated for PEFT for efficiency reasons. 
if dist.get_global_rank() == 0: # Check if the number of parameters updated is equal to the number of parameters From 4acd443640fde73e2ac00be751b4a4dcd10622bb Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Jul 2025 19:59:18 +0000 Subject: [PATCH 03/15] hopefully this works --- .../online/generation_utils/vllm_utils.py | 19 +++--- compose_rl/utils/utils.py | 51 ++++++++++++-- tests/common/models.py | 66 +++++++++++++++++++ tests/test_utils.py | 40 ++++++++++- 4 files changed, 160 insertions(+), 16 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index c1334683..69be1234 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -473,9 +473,10 @@ def broadcast_to_vllm( update_time = 0 _params_considered = 0 + # Working FSDP1 Run: compose-rl-grpo-test-FIeoi7 + # Working FSDP2 Run: TODO for module_name, module in model.named_modules(): # Skip non-FSDP modules - print(f"module_name: {module_name}") if not isinstance(module, (FSDP, FSDPModule)): continue @@ -502,7 +503,7 @@ def broadcast_to_vllm( # |- direct_param (found with recurse=False) # |- NonFSDP_Child # | |- child_param (missed with recurse=False) - for _param_name, param in module.named_parameters(recurse=True): + for _, param in module.named_parameters(recurse=True): # Only distribute on rank 0 if not dist.get_global_rank() == 0: continue @@ -511,14 +512,16 @@ def broadcast_to_vllm( # and we only want to broadcast the summoned parameters. # Encountering this implies nested FSDPModules and a later module, # when summoned, will convert this DTensor to a regular tensor. - # TODO: Investigate why this isn't an issue for FSDP1. + # TODO: Validate this understanding: It seems that for FSDP1, + # summon_full_params takes control of all parameters within it's + # scope, including any parameters from the submodules that are not + # FSDP-wrapped themselves. if isinstance(param, DTensor): - print(f"DTensor found: {_param_name}, skipping.") continue full_name = get_path_to_param(model, param) parsed_name = simplify_param_path(full_name) - print(f"Valid tensor found: {parsed_name}") + print(f"[RICKY] Valid tensor found: {parsed_name}") _params_considered += 1 # We've already updated this module before, @@ -564,9 +567,9 @@ def broadcast_to_vllm( ) update_time += time.time() - start_update_time - print(f"Number of parameters considered: {_params_considered}") - print(f"Number of parameters updated: {count}") - print(f"Number of parameters in the model: {num_params}") + print(f"[RICKY] Number of parameters considered: {_params_considered}") + print(f"[RICKY] Number of parameters updated: {count}") + print(f"[RICKY] Number of parameters in the model: {num_params}") # Issue (#67): Note this code will likely need to be updated for PEFT for efficiency reasons. 
if dist.get_global_rank() == 0: # Check if the number of parameters updated is equal to the number of parameters diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 405535bd..e0c9227b 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -15,6 +15,7 @@ from composer.utils import dist from kubernetes import client, config from torch.distributed.fsdp import FSDPModule +from torch.distributed.tensor import DTensor from torch.utils.data import DataLoader from transformers import PretrainedConfig @@ -1252,6 +1253,47 @@ def flatten(coll: Union[Iterable[Any], str]) -> Generator[Any, None, None]: yield i +def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True): + """ + Gets the DTensors to materialize for an FSDP2 model based on recurse. + + If recurse=False, we can encounter the following state: + FSDPModule_1 + |- weight (DTensor) + |- FSDPModule_2 + | |- weight (DTensor) + |- RegularModule_1 + | |- weight (DTensor) + | |- FSDPModule_3 + | | |- weight (DTensor) + Where summon_full_params(FSDPModule_1) should materialize RegularModule_1.weight + alongside the original FSDPModule_1.weight. Therefore, we use a dfs traversal + to get all DTensors not owned by downstream FSDPModules. + """ + if recurse: + return { + name: param + for name, param in + module.named_parameters(recurse=True, remove_duplicate=False) + if isinstance(param, DTensor) + } + + dtensor_params = {} + def _dfs(module: torch.nn.Module, prefix: str = ''): + # Add all DTensors within this (FSDP)module + for name, param in module.named_parameters(recurse=False, remove_duplicate=False): + if isinstance(param, DTensor): + full_name = f'{prefix}.{name}' if prefix else name + dtensor_params[full_name] = param + for child_name, child in module.named_children(): + if isinstance(child, FSDPModule): + continue + full_name = f'{prefix}.{child_name}' if prefix else child_name + _dfs(child, full_name) + _dfs(module, '') + return dtensor_params + + @contextmanager def _summon_full_params_fsdp2( model: torch.nn.Module, @@ -1265,14 +1307,9 @@ def _summon_full_params_fsdp2( that uses DTensor APIs to materialize the full parameters. We currently don't support rank0_only """ - from torch.distributed.tensor import DTensor, Replicate, distribute_tensor + from torch.distributed.tensor import Replicate, distribute_tensor - dtensor_params = { - name: param - for name, param in - model.named_parameters(recurse=recurse, remove_duplicate=False) - if isinstance(param, DTensor) - } + dtensor_params = _get_params_to_summon_fsdp2(model, recurse=recurse) if not dtensor_params: yield diff --git a/tests/common/models.py b/tests/common/models.py index 57dceba5..773c741d 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -121,3 +121,69 @@ def param_init_fn(self, module: torch.nn.Module): init_fn(module.weight) if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] torch.nn.init.zeros_(module.bias) + + +class NestedFSDPModel(ComposerClassifier): + """Model to test nested FSDP structure for _get_params_to_summon_fsdp2. 
+ + Creates the following structure: + FSDPModule_1 (root) + |- weight (DTensor) <- 1s + |- FSDPModule_2 (nested FSDP) + | |- weight (DTensor) <- 2s + |- RegularModule_1 (regular module) + | |- weight (DTensor) <- 3s + | |- FSDPModule_3 (nested FSDP inside regular module) + | | |- weight (DTensor) <- 4s + + Args: + num_features (int): number of input features (default: 2) + device (str): the device to initialize the model (default: 'cpu') + """ + + def __init__(self, num_features: int = 2, device: str = 'cpu') -> None: + # Root level linear layer (will be FSDPModule_1) + root_linear = torch.nn.Linear(num_features, num_features, device=device, bias=False) + root_linear.weight.data.fill_(1.0) # All 1s + + # Nested FSDP module (FSDPModule_2) + nested_fsdp_linear = torch.nn.Linear(num_features, num_features, device=device, bias=False) + nested_fsdp_linear.weight.data.fill_(2.0) # All 2s + + # Regular module containing a linear layer and nested FSDP + regular_linear = torch.nn.Linear(num_features, num_features, device=device, bias=False) + regular_linear.weight.data.fill_(3.0) # All 3s + nested_fsdp_in_regular = torch.nn.Linear(num_features, num_features, device=device, bias=False) + nested_fsdp_in_regular.weight.data.fill_(4.0) # All 4s + + # Create the nested structure + regular_module = torch.nn.Sequential( + regular_linear, + nested_fsdp_in_regular, + ) + + # Main network structure + net = torch.nn.Sequential( + root_linear, + nested_fsdp_linear, + regular_module, + torch.nn.Softmax(dim=-1), + ) + + super().__init__(module=net, num_classes=num_features) + self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] + + def add_fsdp_wrap_attribute_to_children(self): + self.module[0]._fsdp_wrap = False + self.module[1]._fsdp_wrap = True + self.module[2]._fsdp_wrap = False + self.module[2][0]._fsdp_wrap = False + self.module[2][1]._fsdp_wrap = True + + def param_init_fn(self, module: torch.nn.Module): + init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) + + if isinstance(module, torch.nn.Linear): + init_fn(module.weight) + if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] + torch.nn.init.zeros_(module.bias) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0cc42026..ceb59dc8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,9 +17,11 @@ masked_mean, sample_wise_masked_mean, summon_full_params, + _get_params_to_summon_fsdp2, ) from tests.common.markers import world_size -from tests.common.models import PartialWeightTiedModel +from tests.common.models import NestedFSDPModel, PartialWeightTiedModel +from torch.distributed.tensor import DTensor def test_mask_eos_basic_functionality(): @@ -983,3 +985,39 @@ def test_summon_full_params_tied_weights_behavior( trainer.close() os.environ['FSDP_VERSION'] = '1' + + +@pytest.mark.gpu +@world_size(2) +def test_get_params_to_summon_fsdp2( + tiny_gpt2_tokenizer: PreTrainedTokenizer, + world_size: int, +): + """Test _get_params_to_summon_fsdp2 function with nested FSDP structure.""" + del world_size + + model = NestedFSDPModel(num_features=2) + model.add_fsdp_wrap_attribute_to_children() + + _, fsdp_model = _setup_fsdp_test_environment( + tiny_gpt2_tokenizer, + fsdp_version=2, + model=model, + ) + + dtensor_params_recurse = _get_params_to_summon_fsdp2(fsdp_model.module, recurse=True) + dtensor_params_no_recurse = _get_params_to_summon_fsdp2(fsdp_model.module, recurse=False) + + # Assert all are DTensors + for param in dtensor_params_recurse.values(): + assert 
isinstance(param, DTensor), f"Parameter {param.name} should be a DTensor" + for param in dtensor_params_no_recurse.values(): + assert isinstance(param, DTensor), f"Parameter {param.name} should be a DTensor" + + assert len(dtensor_params_recurse) == 4, "Should have 4 DTensors" + for (name, param), value in zip(dtensor_params_recurse.items(), [1.0, 2.0, 3.0, 4.0]): + assert torch.all(param.data == value), f"Parameter {name} should have value {value}" + assert len(dtensor_params_no_recurse) == 2, "Should have 2 DTensors" + for (name, param), value in zip(dtensor_params_no_recurse.items(), [1.0, 3.0]): + assert torch.all(param.data == value), f"Parameter {name} should have value {value}" + os.environ['FSDP_VERSION'] = '1' From 44a4345463a97c4c49e72431db45f7797c985303 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Jul 2025 21:49:33 +0000 Subject: [PATCH 04/15] works finally --- .../online/generation_utils/vllm_utils.py | 31 +++++++------------ compose_rl/algorithms/online/hf_utils.py | 3 +- compose_rl/algorithms/online/model.py | 4 +-- .../algorithms/reward_modeling/model.py | 3 +- compose_rl/utils/utils.py | 12 +++++-- tests/common/models.py | 9 +++--- tests/test_utils.py | 11 ++++--- 7 files changed, 38 insertions(+), 35 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 69be1234..1e2f38bb 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -471,10 +471,7 @@ def broadcast_to_vllm( model(dummy_batch) start_time = time.time() update_time = 0 - _params_considered = 0 - # Working FSDP1 Run: compose-rl-grpo-test-FIeoi7 - # Working FSDP2 Run: TODO for module_name, module in model.named_modules(): # Skip non-FSDP modules if not isinstance(module, (FSDP, FSDPModule)): @@ -491,14 +488,17 @@ def broadcast_to_vllm( seen_fsdp_modules.add(module) # Materializes parameters for this specific FSDP module specifically. - # Don't materialize the entire model to avoid potential OOM. + # Note that this also materializes params for all submodules that are + # not FSDP-wrapped themselves. We don't want to materialize the entire + # model to avoid potential OOM. with summon_full_params( module, writeback=False, rank0_only=True, recurse=False, ): - # Note: We have to recurse=True since the following case is possible: + # Note: For the following module.named_parameters(), we have to use recurse=True + # since the following case is possible: # FSDP_Module # |- direct_param (found with recurse=False) # |- NonFSDP_Child @@ -510,21 +510,19 @@ def broadcast_to_vllm( # Skip DTensor params at this level since they were not summoned # and we only want to broadcast the summoned parameters. - # Encountering this implies nested FSDPModules and a later module, - # when summoned, will convert this DTensor to a regular tensor. - # TODO: Validate this understanding: It seems that for FSDP1, - # summon_full_params takes control of all parameters within it's - # scope, including any parameters from the submodules that are not - # FSDP-wrapped themselves. + # Encountering this conditional implies that a FSDP-wrapped submodule + # exists and will later be summoned to materialize this parameter. + # + # It seems that for FSDP1, summon_full_params takes control of all parameters + # within its scope, including any parameters from submodules that are not + # FSDP-wrapped themselves. 
View NestedFSDPModel in tests/common/models.py + # as an example. if isinstance(param, DTensor): continue full_name = get_path_to_param(model, param) parsed_name = simplify_param_path(full_name) - print(f"[RICKY] Valid tensor found: {parsed_name}") - _params_considered += 1 - # We've already updated this module before, if parsed_name in seen_updated_parsed_names: continue @@ -543,8 +541,6 @@ def broadcast_to_vllm( if not update: continue - print(f"Updating: {parsed_name}") - start_update_time = time.time() seen_updated_parsed_names.add(parsed_name) @@ -567,9 +563,6 @@ def broadcast_to_vllm( ) update_time += time.time() - start_update_time - print(f"[RICKY] Number of parameters considered: {_params_considered}") - print(f"[RICKY] Number of parameters updated: {count}") - print(f"[RICKY] Number of parameters in the model: {num_params}") # Issue (#67): Note this code will likely need to be updated for PEFT for efficiency reasons. if dist.get_global_rank() == 0: # Check if the number of parameters updated is equal to the number of parameters diff --git a/compose_rl/algorithms/online/hf_utils.py b/compose_rl/algorithms/online/hf_utils.py index 2c1a46ec..5f91a5d2 100644 --- a/compose_rl/algorithms/online/hf_utils.py +++ b/compose_rl/algorithms/online/hf_utils.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn -from composer.utils import is_model_fsdp from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -23,7 +22,7 @@ ) from compose_rl.algorithms.online.policy_configuration import HFPolicyConfig from compose_rl.utils.consts import _MASTER_WEIGHTS_PRECISION -from compose_rl.utils.utils import summon_full_params +from compose_rl.utils.utils import is_model_fsdp, summon_full_params Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index bbdeebd1..09efca01 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -9,7 +9,7 @@ import torch from composer.models import HuggingFaceModel -from composer.utils import dist, is_model_fsdp +from composer.utils import dist from llmfoundry.models import ComposerHFCausalLM from transformers import ( PreTrainedTokenizer, @@ -28,7 +28,7 @@ clear_mb_load_balancing_loss, get_mb_load_balancing_loss, ) -from compose_rl.utils.utils import summon_full_params +from compose_rl.utils.utils import is_model_fsdp, summon_full_params Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] diff --git a/compose_rl/algorithms/reward_modeling/model.py b/compose_rl/algorithms/reward_modeling/model.py index ff22edc7..bf2ac0a9 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -9,7 +9,6 @@ from typing import Any, Mapping, MutableMapping, Optional, Union import torch -from composer.utils import is_model_fsdp from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM from compose_rl.algorithms.reward_modeling.base_reward import ( @@ -31,7 +30,7 @@ ComposerHFSequenceClassification from compose_rl.algorithms.reward_modeling.modeling_mpt import \ MPTForSequenceClassification -from compose_rl.utils.utils import summon_full_params +from compose_rl.utils.utils import is_model_fsdp, summon_full_params log = logging.getLogger(__name__) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index e0c9227b..8c5ae6d0 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from 
composer.utils import dist from kubernetes import client, config +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FSDPModule from torch.distributed.tensor import DTensor from torch.utils.data import DataLoader @@ -1423,7 +1424,6 @@ def summon_full_params( def is_fsdp2(model: torch.nn.Module) -> bool: try: - from torch.distributed.fsdp import FSDPModule for module in model.modules(): if isinstance(module, FSDPModule): return True @@ -1439,7 +1439,6 @@ def is_fsdp2(model: torch.nn.Module) -> bool: ): yield else: - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP with FSDP.summon_full_params( model, writeback=writeback, @@ -1447,3 +1446,12 @@ def is_fsdp2(model: torch.nn.Module) -> bool: rank0_only=rank0_only, ): yield + +def is_model_fsdp(model: torch.nn.Module) -> bool: + """Whether model has FSDP1/FSDP2 wrapped modules.""" + if isinstance(model, (FSDP, FSDPModule)): + return True + for module in model.modules(): + if isinstance(module, (FSDP, FSDPModule)): + return True + return False diff --git a/tests/common/models.py b/tests/common/models.py index 773c741d..5152dc11 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -107,12 +107,13 @@ def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: # since we don't want to create duplicate references to the same module since that # will break mixed init. mlp.net[0].weight = mlp.net[-1].weight + mlp.net[0]._fsdp_wrap = False + mlp.net[-1]._fsdp_wrap = False def add_fsdp_wrap_attribute_to_children(self): - for child in self.children(): - child._fsdp_wrap = False # type: ignore - for child in self.module.children(): - child._fsdp_wrap = True # type: ignore + for module in self.module.modules(): + if not hasattr(module, '_fsdp_wrap'): + module._fsdp_wrap = True def param_init_fn(self, module: torch.nn.Module): init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) diff --git a/tests/test_utils.py b/tests/test_utils.py index ceb59dc8..458e5dbf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -790,7 +790,7 @@ def _setup_fsdp_test_environment( # Create model if model is None: - model = ComposerMPTPairwiseOfflinePolicyLM(**model_config) + model = ComposerMPTPairwiseOfflinePolicyLM(**model_config).to('cuda') # Enable FSDP fsdp_config = {} @@ -801,7 +801,7 @@ def _setup_fsdp_test_environment( max_duration='1ba', ) - return trainer, trainer.state.model + return trainer, model @pytest.mark.gpu @@ -929,7 +929,8 @@ def test_summon_full_params_tied_weights_behavior( ): """Test summon_full_params with tied weights behavior verification.""" del world_size - model = PartialWeightTiedModel(num_features=2) + model = PartialWeightTiedModel(num_features=2).to('cuda') + model.add_fsdp_wrap_attribute_to_children() trainer, fsdp_model = _setup_fsdp_test_environment( tiny_gpt2_tokenizer, @@ -939,6 +940,8 @@ def test_summon_full_params_tied_weights_behavior( # fill the tied weights with 999.0 fsdp_model.module[0].net[0].weight.data.fill_(999.0) # type: ignore + # assert weight tying + assert torch.all(fsdp_model.module[0].net[-1].weight.data == 999.0) # type: ignore # Test writeback=False with summon_full_params(fsdp_model, writeback=False): @@ -996,7 +999,7 @@ def test_get_params_to_summon_fsdp2( """Test _get_params_to_summon_fsdp2 function with nested FSDP structure.""" del world_size - model = NestedFSDPModel(num_features=2) + model = NestedFSDPModel(num_features=2).to('cuda') model.add_fsdp_wrap_attribute_to_children() _, fsdp_model = 
_setup_fsdp_test_environment( From 7defc3e446fd3ed0abd9ba376a2bb08ed916f3eb Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Jul 2025 21:55:59 +0000 Subject: [PATCH 05/15] formatted --- .../online/generation_utils/vllm_utils.py | 1 - compose_rl/utils/utils.py | 15 +++-- tests/common/models.py | 56 +++++++++++++------ tests/test_utils.py | 54 ++++++++++++------ 4 files changed, 86 insertions(+), 40 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 1e2f38bb..cfa03806 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -44,7 +44,6 @@ ) from torch.distributed.fsdp import FSDPModule from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - # DTensor debugging imports from torch.distributed.tensor import DTensor diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 8c5ae6d0..9c9f8dec 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -14,8 +14,8 @@ import torch.nn.functional as F from composer.utils import dist from kubernetes import client, config -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FSDPModule +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.tensor import DTensor from torch.utils.data import DataLoader from transformers import PretrainedConfig @@ -1255,8 +1255,7 @@ def flatten(coll: Union[Iterable[Any], str]) -> Generator[Any, None, None]: def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True): - """ - Gets the DTensors to materialize for an FSDP2 model based on recurse. + """Gets the DTensors to materialize for an FSDP2 model based on recurse. If recurse=False, we can encounter the following state: FSDPModule_1 @@ -1278,11 +1277,15 @@ def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True): module.named_parameters(recurse=True, remove_duplicate=False) if isinstance(param, DTensor) } - + dtensor_params = {} + def _dfs(module: torch.nn.Module, prefix: str = ''): # Add all DTensors within this (FSDP)module - for name, param in module.named_parameters(recurse=False, remove_duplicate=False): + for name, param in module.named_parameters( + recurse=False, + remove_duplicate=False, + ): if isinstance(param, DTensor): full_name = f'{prefix}.{name}' if prefix else name dtensor_params[full_name] = param @@ -1291,6 +1294,7 @@ def _dfs(module: torch.nn.Module, prefix: str = ''): continue full_name = f'{prefix}.{child_name}' if prefix else child_name _dfs(child, full_name) + _dfs(module, '') return dtensor_params @@ -1447,6 +1451,7 @@ def is_fsdp2(model: torch.nn.Module) -> bool: ): yield + def is_model_fsdp(model: torch.nn.Module) -> bool: """Whether model has FSDP1/FSDP2 wrapped modules.""" if isinstance(model, (FSDP, FSDPModule)): diff --git a/tests/common/models.py b/tests/common/models.py index 5152dc11..47263002 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -106,14 +106,14 @@ def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: # Adding mlp.fc1.weight = mlp.fc2.weight without assignment to self.fc1 and self.fc2 # since we don't want to create duplicate references to the same module since that # will break mixed init. 
- mlp.net[0].weight = mlp.net[-1].weight - mlp.net[0]._fsdp_wrap = False - mlp.net[-1]._fsdp_wrap = False + mlp.net[0].weight = mlp.net[-1].weight # type: ignore + mlp.net[0]._fsdp_wrap = False # type: ignore + mlp.net[-1]._fsdp_wrap = False # type: ignore def add_fsdp_wrap_attribute_to_children(self): for module in self.module.modules(): if not hasattr(module, '_fsdp_wrap'): - module._fsdp_wrap = True + module._fsdp_wrap = True # type: ignore def param_init_fn(self, module: torch.nn.Module): init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) @@ -126,7 +126,7 @@ def param_init_fn(self, module: torch.nn.Module): class NestedFSDPModel(ComposerClassifier): """Model to test nested FSDP structure for _get_params_to_summon_fsdp2. - + Creates the following structure: FSDPModule_1 (root) |- weight (DTensor) <- 1s @@ -144,25 +144,45 @@ class NestedFSDPModel(ComposerClassifier): def __init__(self, num_features: int = 2, device: str = 'cpu') -> None: # Root level linear layer (will be FSDPModule_1) - root_linear = torch.nn.Linear(num_features, num_features, device=device, bias=False) + root_linear = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) root_linear.weight.data.fill_(1.0) # All 1s - + # Nested FSDP module (FSDPModule_2) - nested_fsdp_linear = torch.nn.Linear(num_features, num_features, device=device, bias=False) + nested_fsdp_linear = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) nested_fsdp_linear.weight.data.fill_(2.0) # All 2s - + # Regular module containing a linear layer and nested FSDP - regular_linear = torch.nn.Linear(num_features, num_features, device=device, bias=False) + regular_linear = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) regular_linear.weight.data.fill_(3.0) # All 3s - nested_fsdp_in_regular = torch.nn.Linear(num_features, num_features, device=device, bias=False) + nested_fsdp_in_regular = torch.nn.Linear( + num_features, + num_features, + device=device, + bias=False, + ) nested_fsdp_in_regular.weight.data.fill_(4.0) # All 4s - + # Create the nested structure regular_module = torch.nn.Sequential( regular_linear, nested_fsdp_in_regular, ) - + # Main network structure net = torch.nn.Sequential( root_linear, @@ -175,11 +195,11 @@ def __init__(self, num_features: int = 2, device: str = 'cpu') -> None: self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] def add_fsdp_wrap_attribute_to_children(self): - self.module[0]._fsdp_wrap = False - self.module[1]._fsdp_wrap = True - self.module[2]._fsdp_wrap = False - self.module[2][0]._fsdp_wrap = False - self.module[2][1]._fsdp_wrap = True + self.module[0]._fsdp_wrap = False # type: ignore + self.module[1]._fsdp_wrap = True # type: ignore + self.module[2]._fsdp_wrap = False # type: ignore + self.module[2][0]._fsdp_wrap = False # type: ignore + self.module[2][1]._fsdp_wrap = True # type: ignore def param_init_fn(self, module: torch.nn.Module): init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) diff --git a/tests/test_utils.py b/tests/test_utils.py index 458e5dbf..413b75c7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,21 +7,21 @@ import pytest import torch import torch.nn.functional as F +from torch.distributed.tensor import DTensor from transformers import PreTrainedTokenizer from compose_rl.utils import mask_eos from compose_rl.utils.utils import ( + _get_params_to_summon_fsdp2, get_entropies, get_sequence_entropies, get_token_entropies, masked_mean, 
sample_wise_masked_mean, summon_full_params, - _get_params_to_summon_fsdp2, ) from tests.common.markers import world_size from tests.common.models import NestedFSDPModel, PartialWeightTiedModel -from torch.distributed.tensor import DTensor def test_mask_eos_basic_functionality(): @@ -823,7 +823,7 @@ def get_total_param_size(model: torch.nn.Module): total_size = 0 for param in model.parameters(): if hasattr(param, 'to_local'): - param = param.to_local() + param = param.to_local() # type: ignore if param.data is not None: total_size += param.data.numel() return total_size @@ -941,7 +941,9 @@ def test_summon_full_params_tied_weights_behavior( # fill the tied weights with 999.0 fsdp_model.module[0].net[0].weight.data.fill_(999.0) # type: ignore # assert weight tying - assert torch.all(fsdp_model.module[0].net[-1].weight.data == 999.0) # type: ignore + assert torch.all( + fsdp_model.module[0].net[-1].weight.data == 999.0, # type: ignore + ) # type: ignore # Test writeback=False with summon_full_params(fsdp_model, writeback=False): @@ -1001,26 +1003,46 @@ def test_get_params_to_summon_fsdp2( model = NestedFSDPModel(num_features=2).to('cuda') model.add_fsdp_wrap_attribute_to_children() - + _, fsdp_model = _setup_fsdp_test_environment( tiny_gpt2_tokenizer, fsdp_version=2, model=model, ) - dtensor_params_recurse = _get_params_to_summon_fsdp2(fsdp_model.module, recurse=True) - dtensor_params_no_recurse = _get_params_to_summon_fsdp2(fsdp_model.module, recurse=False) + dtensor_params_recurse = _get_params_to_summon_fsdp2( + fsdp_model.module, # type: ignore + recurse=True, + ) + dtensor_params_no_recurse = _get_params_to_summon_fsdp2( + fsdp_model.module, # type: ignore + recurse=False, + ) # Assert all are DTensors for param in dtensor_params_recurse.values(): - assert isinstance(param, DTensor), f"Parameter {param.name} should be a DTensor" + assert isinstance( + param, + DTensor, + ), f"Parameter {param.name} should be a DTensor" for param in dtensor_params_no_recurse.values(): - assert isinstance(param, DTensor), f"Parameter {param.name} should be a DTensor" - - assert len(dtensor_params_recurse) == 4, "Should have 4 DTensors" - for (name, param), value in zip(dtensor_params_recurse.items(), [1.0, 2.0, 3.0, 4.0]): - assert torch.all(param.data == value), f"Parameter {name} should have value {value}" - assert len(dtensor_params_no_recurse) == 2, "Should have 2 DTensors" - for (name, param), value in zip(dtensor_params_no_recurse.items(), [1.0, 3.0]): - assert torch.all(param.data == value), f"Parameter {name} should have value {value}" + assert isinstance( + param, + DTensor, + ), f"Parameter {param.name} should be a DTensor" + + assert len(dtensor_params_recurse) == 4, 'Should have 4 DTensors' + for ( + name, + param, + ), value in zip(dtensor_params_recurse.items(), [1.0, 2.0, 3.0, 4.0]): + assert torch.all( + param.data == value, + ), f"Parameter {name} should have value {value}" + assert len(dtensor_params_no_recurse) == 2, 'Should have 2 DTensors' + for (name, + param), value in zip(dtensor_params_no_recurse.items(), [1.0, 3.0]): + assert torch.all( + param.data == value, + ), f"Parameter {name} should have value {value}" os.environ['FSDP_VERSION'] = '1' From 9e569908971bc6927855e3918356117b3a37ed1c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Jul 2025 22:05:45 +0000 Subject: [PATCH 06/15] comment changes --- .../online/generation_utils/vllm_utils.py | 45 +++---------------- tests/test_offline.py | 4 -- 2 files changed, 6 insertions(+), 43 deletions(-) diff --git 
a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index cfa03806..72a20b56 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -44,7 +44,6 @@ ) from torch.distributed.fsdp import FSDPModule from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -# DTensor debugging imports from torch.distributed.tensor import DTensor from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor @@ -280,33 +279,6 @@ def create_vllm_engines( return vllm_engines -def build_param_fullnames(top_module: nn.Module) -> dict: - """Builds a mapping of parameter objects to their fully-qualified names. - - Traverses the entire model from the top level and map each parameter - object to its fully-qualified name (e.g., - "lm_backbone.layer1.mlp.down_proj.weight"). - - Args: - top_module (nn.Module): The top-level module to traverse. - """ - param2fullname = {} - - def _dfs(current_module: nn.Module, prefix: str = ''): - # Get local parameters (without recursing into children). - for local_name, param in current_module.named_parameters(recurse=False): - full_name = f'{prefix}.{local_name}' if prefix else local_name - param2fullname[param] = full_name - - # Recurse on child modules. - for child_name, child_module in current_module.named_children(): - child_prefix = f'{prefix}.{child_name}' if prefix else child_name - _dfs(child_module, prefix=child_prefix) - - _dfs(top_module) - return param2fullname - - def simplify_param_path(path: str) -> str: """Simplifies the parameter path by removing unnecessary parts. @@ -315,7 +287,6 @@ def simplify_param_path(path: str) -> str: """ # Parts we want to remove remove_parts = [ - '_wrapped_module', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', 'lm_backbone', @@ -412,7 +383,7 @@ def broadcast_to_vllm( loss_type (str): The loss type which decides whether to use critic-free or not. Defaults to `ppo`. enable_prefix_caching (bool): Whether to enable prefix caching. Defaults to `False`. """ - # To avoid OOM + # avoid OOM torch.cuda.empty_cache() if loss_type == OnPolicyEnum.PPO: # Extract the lm_backbone params from the model @@ -486,10 +457,11 @@ def broadcast_to_vllm( continue seen_fsdp_modules.add(module) - # Materializes parameters for this specific FSDP module specifically. - # Note that this also materializes params for all submodules that are - # not FSDP-wrapped themselves. We don't want to materialize the entire - # model to avoid potential OOM. + # Materializes parameters for this specific FSDP module only BUT THIS + # INCLUDES any parameters from submodules that are not FSDP-wrapped themselves. + # We don't want to materialize the entire model to avoid potential OOM. + # View NestedFSDPModel in tests/common/models.py and the related test in + # test_utils.py for an example. with summon_full_params( module, writeback=False, @@ -511,11 +483,6 @@ def broadcast_to_vllm( # and we only want to broadcast the summoned parameters. # Encountering this conditional implies that a FSDP-wrapped submodule # exists and will later be summoned to materialize this parameter. - # - # It seems that for FSDP1, summon_full_params takes control of all parameters - # within its scope, including any parameters from submodules that are not - # FSDP-wrapped themselves. View NestedFSDPModel in tests/common/models.py - # as an example. 
if isinstance(param, DTensor): continue diff --git a/tests/test_offline.py b/tests/test_offline.py index 5bd52548..2838700b 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -174,14 +174,11 @@ def test_model_forward(tiny_gpt2_tokenizer: PreTrainedTokenizer): @pytest.mark.gpu @world_size(2) @pytest.mark.parametrize('fsdp_config', [None, {}]) # type: ignore -@pytest.mark.parametrize('fsdp_version', [1, 2]) def test_train( tiny_gpt2_tokenizer: PreTrainedTokenizer, world_size: int, fsdp_config: dict[str, Any], - fsdp_version: int, ): - os.environ['FSDP_VERSION'] = str(fsdp_version) max_seq_len = 10 dataset = PairwisePreference(max_seq_len=max_seq_len) dataloader = DataLoader( @@ -217,7 +214,6 @@ def test_train( max_duration='1ep', ) trainer.fit() - os.environ['FSDP_VERSION'] = '1' @pytest.mark.skip( From 786bf6c59ed5f5fd4a7a3b49db32adad40b358f4 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Jul 2025 22:45:51 +0000 Subject: [PATCH 07/15] formatted strings --- tests/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 413b75c7..87977870 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1024,12 +1024,12 @@ def test_get_params_to_summon_fsdp2( assert isinstance( param, DTensor, - ), f"Parameter {param.name} should be a DTensor" + ), f'Parameter {param.name} should be a DTensor' for param in dtensor_params_no_recurse.values(): assert isinstance( param, DTensor, - ), f"Parameter {param.name} should be a DTensor" + ), f'Parameter {param.name} should be a DTensor' assert len(dtensor_params_recurse) == 4, 'Should have 4 DTensors' for ( @@ -1038,11 +1038,11 @@ def test_get_params_to_summon_fsdp2( ), value in zip(dtensor_params_recurse.items(), [1.0, 2.0, 3.0, 4.0]): assert torch.all( param.data == value, - ), f"Parameter {name} should have value {value}" + ), f'Parameter {name} should have value {value}' assert len(dtensor_params_no_recurse) == 2, 'Should have 2 DTensors' for (name, param), value in zip(dtensor_params_no_recurse.items(), [1.0, 3.0]): assert torch.all( param.data == value, - ), f"Parameter {name} should have value {value}" + ), f'Parameter {name} should have value {value}' os.environ['FSDP_VERSION'] = '1' From 46180342b32c6f0eebe796e5d9304e2fa3bf8169 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Jul 2025 20:55:12 +0000 Subject: [PATCH 08/15] should be the same thing --- compose_rl/utils/utils.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 9c9f8dec..6ce65b36 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -1270,16 +1270,7 @@ def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True): alongside the original FSDPModule_1.weight. Therefore, we use a dfs traversal to get all DTensors not owned by downstream FSDPModules. 
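[Editor's note] This patch's claim ("should be the same thing") refers to the change just below: the recurse=True fast path is deleted and the DFS gains `and not recurse`. A small hypothetical check of why the two paths agree for recurse=True, reusing names from this diff (not code in the repo); the hunk resumes right after it:

from torch.distributed.tensor import DTensor

def dtensors_via_named_parameters(module):
    # The deleted fast path: every DTensor reachable from `module`, keyed by
    # its dotted parameter name.
    return {
        name: param
        for name, param in module.named_parameters(
            recurse=True,
            remove_duplicate=False,
        )
        if isinstance(param, DTensor)
    }

# With recurse=True the DFS no longer skips FSDPModule children, so it visits
# exactly the parameters named_parameters() visits and builds the same dotted
# names, i.e. the mappings should match:
#   dtensors_via_named_parameters(m) == _get_params_to_summon_fsdp2(m, recurse=True)
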
""" - if recurse: - return { - name: param - for name, param in - module.named_parameters(recurse=True, remove_duplicate=False) - if isinstance(param, DTensor) - } - dtensor_params = {} - def _dfs(module: torch.nn.Module, prefix: str = ''): # Add all DTensors within this (FSDP)module for name, param in module.named_parameters( @@ -1290,7 +1281,7 @@ def _dfs(module: torch.nn.Module, prefix: str = ''): full_name = f'{prefix}.{name}' if prefix else name dtensor_params[full_name] = param for child_name, child in module.named_children(): - if isinstance(child, FSDPModule): + if isinstance(child, FSDPModule) and not recurse: continue full_name = f'{prefix}.{child_name}' if prefix else child_name _dfs(child, full_name) From 8bdd61e930d15d70ba851ea455d1aa7ef1f3cef6 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 23 Jul 2025 03:06:27 +0000 Subject: [PATCH 09/15] moved code to composer --- .../online/generation_utils/vllm_utils.py | 12 +- compose_rl/algorithms/online/hf_utils.py | 4 +- compose_rl/algorithms/online/model.py | 4 +- .../algorithms/reward_modeling/model.py | 4 +- compose_rl/utils/utils.py | 202 ------------ tests/common/models.py | 210 ------------ tests/test_utils.py | 312 ------------------ 7 files changed, 17 insertions(+), 731 deletions(-) delete mode 100644 tests/common/models.py diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 72a20b56..5498c119 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -51,7 +51,7 @@ ALGORITHM_TYPE, OnPolicyEnum, ) -from compose_rl.utils.utils import summon_full_params +from composer.distributed.shared_utils import get_summon_params_fn log = logging.getLogger(__name__) @@ -442,6 +442,10 @@ def broadcast_to_vllm( start_time = time.time() update_time = 0 + # Getting the correct summon_full_params function based on whether + # the model is FSDP1 vs FSDP2. + summon_full_params = get_summon_params_fn(model) + for module_name, module in model.named_modules(): # Skip non-FSDP modules if not isinstance(module, (FSDP, FSDPModule)): @@ -460,8 +464,8 @@ def broadcast_to_vllm( # Materializes parameters for this specific FSDP module only BUT THIS # INCLUDES any parameters from submodules that are not FSDP-wrapped themselves. # We don't want to materialize the entire model to avoid potential OOM. - # View NestedFSDPModel in tests/common/models.py and the related test in - # test_utils.py for an example. + # View NestedFSDPModel in the Composer repo and the related test in + # for an example of why this logic is needed. 
with summon_full_params( module, writeback=False, @@ -469,7 +473,7 @@ def broadcast_to_vllm( recurse=False, ): # Note: For the following module.named_parameters(), we have to use recurse=True - # since the following case is possible: + # since the following case is possible where we still need NonFSDP_Child's params # FSDP_Module # |- direct_param (found with recurse=False) # |- NonFSDP_Child diff --git a/compose_rl/algorithms/online/hf_utils.py b/compose_rl/algorithms/online/hf_utils.py index 5f91a5d2..85c0aac4 100644 --- a/compose_rl/algorithms/online/hf_utils.py +++ b/compose_rl/algorithms/online/hf_utils.py @@ -22,7 +22,8 @@ ) from compose_rl.algorithms.online.policy_configuration import HFPolicyConfig from compose_rl.utils.consts import _MASTER_WEIGHTS_PRECISION -from compose_rl.utils.utils import is_model_fsdp, summon_full_params +from composer.distributed.shared_utils import get_summon_params_fn +from composer.utils import is_model_fsdp Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] @@ -98,6 +99,7 @@ def generate( # are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069 # for more info. # Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model. + summon_full_params = get_summon_params_fn(self.lm_backbone) with summon_full_params( self.lm_backbone, writeback=False, diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 09efca01..9264602c 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -28,7 +28,8 @@ clear_mb_load_balancing_loss, get_mb_load_balancing_loss, ) -from compose_rl.utils.utils import is_model_fsdp, summon_full_params +from composer.utils import is_model_fsdp +from composer.distributed.shared_utils import get_summon_params_fn Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] @@ -190,6 +191,7 @@ def generate(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any): # are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069 # for more info. # Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model. 
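[Editor's note] The same pattern recurs in the hunks above and below; a self-contained sketch of it follows. The wrapper function here is hypothetical, and `get_summon_params_fn` is assumed to be the Composer helper this patch imports, accepting the same writeback/recurse arguments as the in-repo context manager it replaces:

from contextlib import nullcontext

from composer.distributed.shared_utils import get_summon_params_fn
from composer.utils import is_model_fsdp

def generate_with_lm_head(lm_backbone, input_ids, **generate_kwargs):
    ctx = nullcontext()
    if is_model_fsdp(lm_backbone):
        summon_full_params = get_summon_params_fn(lm_backbone)
        # recurse=False: only the root module's own params (the tied LM head)
        # are gathered, not the entire model.
        ctx = summon_full_params(lm_backbone, writeback=False, recurse=False)
    with ctx:
        return lm_backbone.generate(input_ids=input_ids, **generate_kwargs)
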
+ summon_full_params = get_summon_params_fn(self.model) with summon_full_params( self.model, writeback=False, diff --git a/compose_rl/algorithms/reward_modeling/model.py b/compose_rl/algorithms/reward_modeling/model.py index bf2ac0a9..908711a2 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -30,7 +30,8 @@ ComposerHFSequenceClassification from compose_rl.algorithms.reward_modeling.modeling_mpt import \ MPTForSequenceClassification -from compose_rl.utils.utils import is_model_fsdp, summon_full_params +from composer.utils import is_model_fsdp +from composer.distributed.shared_utils import get_summon_params_fn log = logging.getLogger(__name__) @@ -291,6 +292,7 @@ def mask_last_embed_except_eos( context_manager = nullcontext if is_model_fsdp(self.model): + summon_full_params = get_summon_params_fn(self.model) context_manager = partial( summon_full_params, self.model, diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 6ce65b36..863745b5 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -5,7 +5,6 @@ import re import warnings from collections.abc import Generator, Iterable -from contextlib import contextmanager from typing import Any, Optional, Union import spacy @@ -15,8 +14,6 @@ from composer.utils import dist from kubernetes import client, config from torch.distributed.fsdp import FSDPModule -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.tensor import DTensor from torch.utils.data import DataLoader from transformers import PretrainedConfig @@ -1252,202 +1249,3 @@ def flatten(coll: Union[Iterable[Any], str]) -> Generator[Any, None, None]: yield subc else: yield i - - -def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True): - """Gets the DTensors to materialize for an FSDP2 model based on recurse. - - If recurse=False, we can encounter the following state: - FSDPModule_1 - |- weight (DTensor) - |- FSDPModule_2 - | |- weight (DTensor) - |- RegularModule_1 - | |- weight (DTensor) - | |- FSDPModule_3 - | | |- weight (DTensor) - Where summon_full_params(FSDPModule_1) should materialize RegularModule_1.weight - alongside the original FSDPModule_1.weight. Therefore, we use a dfs traversal - to get all DTensors not owned by downstream FSDPModules. - """ - dtensor_params = {} - def _dfs(module: torch.nn.Module, prefix: str = ''): - # Add all DTensors within this (FSDP)module - for name, param in module.named_parameters( - recurse=False, - remove_duplicate=False, - ): - if isinstance(param, DTensor): - full_name = f'{prefix}.{name}' if prefix else name - dtensor_params[full_name] = param - for child_name, child in module.named_children(): - if isinstance(child, FSDPModule) and not recurse: - continue - full_name = f'{prefix}.{child_name}' if prefix else child_name - _dfs(child, full_name) - - _dfs(module, '') - return dtensor_params - - -@contextmanager -def _summon_full_params_fsdp2( - model: torch.nn.Module, - writeback: bool = True, - recurse: bool = True, -): - """Context manager to get full params for FSDP2 models with DTensor APIs. - - Note: We use DTensor APIs to materialize the full parameters instead of using `unshard` - and `reshard` as writeback doesn't seem to work correctly with DTensors - that uses DTensor APIs to materialize the full parameters. 
We currently don't support - rank0_only - """ - from torch.distributed.tensor import Replicate, distribute_tensor - - dtensor_params = _get_params_to_summon_fsdp2(model, recurse=recurse) - - if not dtensor_params: - yield - return - - model_dtensors = {} - metadata = {} - tied_params = {} - - # We want to get the module and attr of the param, so we can assign - # module.attr = param.full_tensor() before we yield and - # module.attr = distributed (maybe updated) tensor after we yield. - def _get_module_and_attr(model: torch.nn.Module, param_name: str): - parts = param_name.split('.') - module = model - for part in parts[:-1]: - module = getattr(module, part) - return module, parts[-1] - - # Group parameters by their underlying tensor to handle tied parameters - tensor_to_names = {} - for name, dtensor_param in dtensor_params.items(): - tensor_id = id(dtensor_param) - if tensor_id not in tensor_to_names: - tensor_to_names[tensor_id] = [] - tensor_to_names[tensor_id].append(name) - - # Process parameters, handling tied parameters correctly - processed_tensors = set() - for name, dtensor_param in dtensor_params.items(): - tensor_id = id(dtensor_param) - - metadata[name] = { - 'device_mesh': dtensor_param.device_mesh, - 'placements': dtensor_param.placements, - 'requires_grad': dtensor_param.requires_grad, - } - model_dtensors[name] = dtensor_param - - # Only materialize the full tensor once per unique tensor - if tensor_id not in processed_tensors: - full_tensor = dtensor_param.full_tensor() - new_param = torch.nn.Parameter(full_tensor.detach().clone()) - - # Set the same parameter instance for all tied parameters - for tied_name in tensor_to_names[tensor_id]: - module, attr_name = _get_module_and_attr(model, tied_name) - setattr(module, attr_name, new_param) - tied_params[tied_name] = new_param - - processed_tensors.add(tensor_id) - - try: - yield - finally: - # Process tied parameters to ensure writeback works correctly - processed_tensors = set() - tensor_to_updated_dtensor = {} - - for name in dtensor_params.keys(): - module, attr_name = _get_module_and_attr(model, name) - tensor_id = id(model_dtensors[name]) - - if writeback and tensor_id not in processed_tensors: - # We update model_dtensors[name] to use the updated param - # after the model changes. For tied parameters, we only need - # to do this once per unique tensor. - current_param = getattr(module, attr_name) - if hasattr( - current_param, - 'data', - ) and current_param.data is not None: - meta = metadata[name] - replicated = distribute_tensor( - current_param.data, - meta['device_mesh'], - [Replicate()], - ) - sharded = replicated.redistribute( - meta['device_mesh'], - meta['placements'], - ) - new_param = torch.nn.Parameter(sharded) - new_param.requires_grad = meta['requires_grad'] - tensor_to_updated_dtensor[tensor_id] = new_param - processed_tensors.add(tensor_id) - - # Restore the appropriate DTensor for this parameter - if writeback and tensor_id in tensor_to_updated_dtensor: - setattr(module, attr_name, tensor_to_updated_dtensor[tensor_id]) - else: - setattr(module, attr_name, model_dtensors[name]) - - -@contextmanager -def summon_full_params( - model: torch.nn.Module, - writeback: bool = True, - recurse: bool = True, - rank0_only: bool = False, -): - """Context manager to summon full parameters for an FSDP(1/2) model. - - Args: - model (torch.nn.Module): The FSDP model to summon full parameters for. - writeback (bool): Whether to write back parameter changes. Defaults to False. 
- recurse (bool): Whether to recurse into submodules. Defaults to True. - rank0_only (bool): Whether to summon full parameters on only rank 0. Defaults to False. - Only supported for FSDP1. FSDP2 by default materializes all parameters on all ranks. - """ - - def is_fsdp2(model: torch.nn.Module) -> bool: - try: - for module in model.modules(): - if isinstance(module, FSDPModule): - return True - except ImportError: - pass - return False - - if is_fsdp2(model): - with _summon_full_params_fsdp2( - model, - writeback=writeback, - recurse=recurse, - ): - yield - else: - with FSDP.summon_full_params( - model, - writeback=writeback, - recurse=recurse, - rank0_only=rank0_only, - ): - yield - - -def is_model_fsdp(model: torch.nn.Module) -> bool: - """Whether model has FSDP1/FSDP2 wrapped modules.""" - if isinstance(model, (FSDP, FSDPModule)): - return True - for module in model.modules(): - if isinstance(module, (FSDP, FSDPModule)): - return True - return False diff --git a/tests/common/models.py b/tests/common/models.py deleted file mode 100644 index 47263002..00000000 --- a/tests/common/models.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2024 MosaicML ComposeRL authors -# SPDX-License-Identifier: Apache-2.0 - -from functools import partial - -import torch -from composer.models import ComposerClassifier - - -class SimpleMLP(torch.nn.Module): - def __init__(self, num_features: int, device: str = 'cpu'): - super().__init__() - fc1 = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - fc2 = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - - self.net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) - - -class SimpleWeightTiedModel(ComposerClassifier): - """Small classification model with tied weights. - - Args: - num_features (int): number of input features (default: 1) - device (str): the device to initialize the model (default: 'cpu') - """ - - def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: - self.num_features = num_features - - mlp = SimpleMLP(num_features, device) - - net = torch.nn.Sequential( - mlp, - torch.nn.Softmax(dim=-1), - ) - - super().__init__(module=net, num_classes=num_features) - - self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] - - # Adding mlp.fc1.weight = mlp.fc2.weight without assignment to self.fc1 and self.fc2 - # since we don't want to create duplicate references to the same module - # since that will break mixed init. - mlp.net[0].weight = mlp.net[-1].weight - - def add_fsdp_wrap_attribute_to_children(self): - for child in self.children(): - child._fsdp_wrap = False # type: ignore - for child in self.module.children(): - child._fsdp_wrap = True # type: ignore - - def param_init_fn(self, module: torch.nn.Module): - init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - - if isinstance(module, torch.nn.Linear): - init_fn(module.weight) - if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] - torch.nn.init.zeros_(module.bias) - - -class PartialWeightTiedModel(ComposerClassifier): - """Small classification model with partially tied weights. 
- - Args: - num_features (int): number of input features (default: 1) - device (str): the device to initialize the model (default: 'cpu') - """ - - def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: - mlp = SimpleMLP(num_features, device) - - # a third fc layer that is not tied to the above mlp - fc3 = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - - net = torch.nn.Sequential( - mlp, - fc3, - torch.nn.Softmax(dim=-1), - ) - - # fc1 would be a child module of the Sequential module now but only the mlp should be FSDP wrapped - # TODO support this or add negative test for this - # net.fc1 = mlp.fc1 - - super().__init__(module=net, num_classes=num_features) - self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] - - # Adding mlp.fc1.weight = mlp.fc2.weight without assignment to self.fc1 and self.fc2 - # since we don't want to create duplicate references to the same module since that - # will break mixed init. - mlp.net[0].weight = mlp.net[-1].weight # type: ignore - mlp.net[0]._fsdp_wrap = False # type: ignore - mlp.net[-1]._fsdp_wrap = False # type: ignore - - def add_fsdp_wrap_attribute_to_children(self): - for module in self.module.modules(): - if not hasattr(module, '_fsdp_wrap'): - module._fsdp_wrap = True # type: ignore - - def param_init_fn(self, module: torch.nn.Module): - init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - - if isinstance(module, torch.nn.Linear): - init_fn(module.weight) - if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] - torch.nn.init.zeros_(module.bias) - - -class NestedFSDPModel(ComposerClassifier): - """Model to test nested FSDP structure for _get_params_to_summon_fsdp2. - - Creates the following structure: - FSDPModule_1 (root) - |- weight (DTensor) <- 1s - |- FSDPModule_2 (nested FSDP) - | |- weight (DTensor) <- 2s - |- RegularModule_1 (regular module) - | |- weight (DTensor) <- 3s - | |- FSDPModule_3 (nested FSDP inside regular module) - | | |- weight (DTensor) <- 4s - - Args: - num_features (int): number of input features (default: 2) - device (str): the device to initialize the model (default: 'cpu') - """ - - def __init__(self, num_features: int = 2, device: str = 'cpu') -> None: - # Root level linear layer (will be FSDPModule_1) - root_linear = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - root_linear.weight.data.fill_(1.0) # All 1s - - # Nested FSDP module (FSDPModule_2) - nested_fsdp_linear = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - nested_fsdp_linear.weight.data.fill_(2.0) # All 2s - - # Regular module containing a linear layer and nested FSDP - regular_linear = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - regular_linear.weight.data.fill_(3.0) # All 3s - nested_fsdp_in_regular = torch.nn.Linear( - num_features, - num_features, - device=device, - bias=False, - ) - nested_fsdp_in_regular.weight.data.fill_(4.0) # All 4s - - # Create the nested structure - regular_module = torch.nn.Sequential( - regular_linear, - nested_fsdp_in_regular, - ) - - # Main network structure - net = torch.nn.Sequential( - root_linear, - nested_fsdp_linear, - regular_module, - torch.nn.Softmax(dim=-1), - ) - - super().__init__(module=net, num_classes=num_features) - self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues] - - def add_fsdp_wrap_attribute_to_children(self): - 
self.module[0]._fsdp_wrap = False # type: ignore - self.module[1]._fsdp_wrap = True # type: ignore - self.module[2]._fsdp_wrap = False # type: ignore - self.module[2][0]._fsdp_wrap = False # type: ignore - self.module[2][1]._fsdp_wrap = True # type: ignore - - def param_init_fn(self, module: torch.nn.Module): - init_fn = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - - if isinstance(module, torch.nn.Linear): - init_fn(module.weight) - if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison] - torch.nn.init.zeros_(module.bias) diff --git a/tests/test_utils.py b/tests/test_utils.py index 87977870..c913ccbc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,27 +1,18 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -import os -from typing import Optional - import pytest import torch import torch.nn.functional as F -from torch.distributed.tensor import DTensor -from transformers import PreTrainedTokenizer from compose_rl.utils import mask_eos from compose_rl.utils.utils import ( - _get_params_to_summon_fsdp2, get_entropies, get_sequence_entropies, get_token_entropies, masked_mean, sample_wise_masked_mean, - summon_full_params, ) -from tests.common.markers import world_size -from tests.common.models import NestedFSDPModel, PartialWeightTiedModel def test_mask_eos_basic_functionality(): @@ -743,306 +734,3 @@ def test_get_entropies_integration(): peaky_probs = F.softmax(peaky_logits, dim=0) expected_entropy = -torch.sum(peaky_probs * torch.log(peaky_probs)) assert torch.isclose(entropies[1], expected_entropy, atol=1e-5) - - -def _setup_fsdp_test_environment( - tiny_gpt2_tokenizer: PreTrainedTokenizer, - fsdp_version: int, - model: Optional[torch.nn.Module] = None, -): - """Helper function to set up FSDP test environment.""" - import os - from functools import partial - - from composer import Trainer - from composer.utils import dist - from torch.utils.data import DataLoader - - from compose_rl.algorithms.offline import ComposerMPTPairwiseOfflinePolicyLM - from compose_rl.data import pairwise_preference_dataset_collate_fn - from tests.common import PairwisePreference - - # Set FSDP version - os.environ['FSDP_VERSION'] = str(fsdp_version) - - # Create a dataset and dataloader - max_seq_len = 10 - dataset = PairwisePreference(max_seq_len=max_seq_len) - dataloader = DataLoader( - dataset, - collate_fn=partial( - pairwise_preference_dataset_collate_fn, - tiny_gpt2_tokenizer, - max_seq_len, - ), - sampler=dist.get_sampler(dataset), - batch_size=2, - ) - - # Create model config - model_config = { - 'n_layers': 1, - 'attn_config': { - 'attn_impl': 'torch', - }, - 'tokenizer': tiny_gpt2_tokenizer, - } - - # Create model - if model is None: - model = ComposerMPTPairwiseOfflinePolicyLM(**model_config).to('cuda') - - # Enable FSDP - fsdp_config = {} - trainer = Trainer( - model=model, # type: ignore - train_dataloader=dataloader, - parallelism_config={'fsdp': fsdp_config}, - max_duration='1ba', - ) - - return trainer, model - - -@pytest.mark.gpu -@world_size(2) -@pytest.mark.parametrize('fsdp_version', [1, 2]) -def test_summon_full_params( - tiny_gpt2_tokenizer: PreTrainedTokenizer, - world_size: int, - fsdp_version: int, -): - """Test summon_full_params actually works with FSDP(1/2) models.""" - del world_size - trainer, fsdp_model = _setup_fsdp_test_environment( - tiny_gpt2_tokenizer, - fsdp_version, - ) - - def get_total_param_size(model: torch.nn.Module): - total_size = 0 - for param in model.parameters(): - if hasattr(param, 
'to_local'): - param = param.to_local() # type: ignore - if param.data is not None: - total_size += param.data.numel() - return total_size - - distributed_param_size = get_total_param_size(fsdp_model) - - # Test with writeback=True - with summon_full_params(fsdp_model): - local_param_size = get_total_param_size(fsdp_model) - - assert local_param_size > distributed_param_size * 1.5, \ - f'Local param size {local_param_size} should be > 1.5x distributed param size {distributed_param_size}' - - trainer.close() - os.environ['FSDP_VERSION'] = '1' - - -@pytest.mark.gpu -@world_size(2) -@pytest.mark.parametrize('fsdp_version', [1, 2]) -def test_summon_full_params_with_fsdp_writeback( - tiny_gpt2_tokenizer: PreTrainedTokenizer, - world_size: int, - fsdp_version: int, -): - """Test summon_full_params with actual FSDP models.""" - del world_size - trainer, fsdp_model = _setup_fsdp_test_environment( - tiny_gpt2_tokenizer, - fsdp_version, - ) - - original_local_tensors = { - name: param.data.clone() for name, param in fsdp_model.named_parameters() - } - - # Test out writeback=False - with summon_full_params(fsdp_model, writeback=False): - # Modify parameters inside the context - for name, param in fsdp_model.named_parameters(): - if param.data is not None: # type: ignore - param.data.fill_(777.0) - - for name, param in fsdp_model.named_parameters(): - if param.data is not None: # type: ignore - assert torch.all( - param.data == original_local_tensors[name], - ), f'Parameter {name} should not be modified with writeback=False' - - # Test with writeback=True - with summon_full_params(fsdp_model, writeback=True): - for name, param in fsdp_model.named_parameters(): - if param.data is not None: # type: ignore - param.data.fill_(888.0) - - for name, param in fsdp_model.named_parameters(): - if param.data is not None: # type: ignore - assert torch.all( - param.data == 888.0, - ), f'Parameter {name} should be modified with writeback=True' - - trainer.close() - os.environ['FSDP_VERSION'] = '1' - - -@pytest.mark.gpu -@world_size(2) -@pytest.mark.parametrize('fsdp_version', [1, 2]) -def test_summon_full_params_recurse( - tiny_gpt2_tokenizer: PreTrainedTokenizer, - world_size: int, - fsdp_version: int, -): - """Test summon_full_params with recurse=False parameter.""" - del world_size - trainer, fsdp_model = _setup_fsdp_test_environment( - tiny_gpt2_tokenizer, - fsdp_version, - ) - - with summon_full_params(fsdp_model, recurse=False): - for name, param in fsdp_model.named_parameters(recurse=False): - assert param.data is not None # type: ignore - assert '.' not in name - - with summon_full_params(fsdp_model, recurse=True): - param_names = [ - name for name, _ in fsdp_model.named_parameters(recurse=True) - ] - assert any('.' 
in name for name in param_names) - - trainer.close() - os.environ['FSDP_VERSION'] = '1' - - -@pytest.mark.gpu -@world_size(2) -@pytest.mark.parametrize('fsdp_version', [1, 2]) -def test_summon_full_params_tied_weights_behavior( - world_size: int, - fsdp_version: int, - tiny_gpt2_tokenizer: PreTrainedTokenizer, -): - """Test summon_full_params with tied weights behavior verification.""" - del world_size - model = PartialWeightTiedModel(num_features=2).to('cuda') - model.add_fsdp_wrap_attribute_to_children() - - trainer, fsdp_model = _setup_fsdp_test_environment( - tiny_gpt2_tokenizer, - fsdp_version, - model, - ) - - # fill the tied weights with 999.0 - fsdp_model.module[0].net[0].weight.data.fill_(999.0) # type: ignore - # assert weight tying - assert torch.all( - fsdp_model.module[0].net[-1].weight.data == 999.0, # type: ignore - ) # type: ignore - - # Test writeback=False - with summon_full_params(fsdp_model, writeback=False): - error_msg = 'Tied weights should be the same tensor object inside context' - first_weight = fsdp_model.module[0].net[0].weight # type: ignore - last_weight = fsdp_model.module[0].net[-1].weight # type: ignore - assert first_weight is last_weight, error_msg - - first_weight.data.fill_(777.0) - error_msg = 'Tied weights should be consistent inside context' - assert torch.all(last_weight.data == 777.0), error_msg - - first_weight_same = torch.all( - fsdp_model.module[0].net[0].weight.data == 999.0, # type: ignore - ) - last_weight_same = torch.all( - fsdp_model.module[0].net[-1].weight.data == 999.0, # type: ignore - ) - - assert first_weight_same, 'First tied weight should be the same with writeback=False' - assert last_weight_same, 'Second tied weight should be the same with writeback=False' - - # Test writeback=True - with summon_full_params(fsdp_model, writeback=True): - first_weight = fsdp_model.module[0].net[0].weight # type: ignore - last_weight = fsdp_model.module[0].net[-1].weight # type: ignore - error_msg = 'Tied weights should be the same tensor object inside context' - assert first_weight is last_weight, error_msg - - first_weight.data.fill_(888.0) - - error_msg = 'Tied weights should be consistent inside context' - assert torch.all(last_weight.data == 888.0), error_msg - - first_weight_changed = torch.all( - fsdp_model.module[0].net[0].weight.data == 888.0, # type: ignore - ) - last_weight_changed = torch.all( - fsdp_model.module[0].net[-1].weight.data == 888.0, # type: ignore - ) - - assert first_weight_changed, 'First tied weight should keep modified value with writeback=True' - assert last_weight_changed, 'Second tied weight should keep modified value with writeback=True' - - trainer.close() - os.environ['FSDP_VERSION'] = '1' - - -@pytest.mark.gpu -@world_size(2) -def test_get_params_to_summon_fsdp2( - tiny_gpt2_tokenizer: PreTrainedTokenizer, - world_size: int, -): - """Test _get_params_to_summon_fsdp2 function with nested FSDP structure.""" - del world_size - - model = NestedFSDPModel(num_features=2).to('cuda') - model.add_fsdp_wrap_attribute_to_children() - - _, fsdp_model = _setup_fsdp_test_environment( - tiny_gpt2_tokenizer, - fsdp_version=2, - model=model, - ) - - dtensor_params_recurse = _get_params_to_summon_fsdp2( - fsdp_model.module, # type: ignore - recurse=True, - ) - dtensor_params_no_recurse = _get_params_to_summon_fsdp2( - fsdp_model.module, # type: ignore - recurse=False, - ) - - # Assert all are DTensors - for param in dtensor_params_recurse.values(): - assert isinstance( - param, - DTensor, - ), f'Parameter {param.name} should be a 
DTensor' - for param in dtensor_params_no_recurse.values(): - assert isinstance( - param, - DTensor, - ), f'Parameter {param.name} should be a DTensor' - - assert len(dtensor_params_recurse) == 4, 'Should have 4 DTensors' - for ( - name, - param, - ), value in zip(dtensor_params_recurse.items(), [1.0, 2.0, 3.0, 4.0]): - assert torch.all( - param.data == value, - ), f'Parameter {name} should have value {value}' - assert len(dtensor_params_no_recurse) == 2, 'Should have 2 DTensors' - for (name, - param), value in zip(dtensor_params_no_recurse.items(), [1.0, 3.0]): - assert torch.all( - param.data == value, - ), f'Parameter {name} should have value {value}' - os.environ['FSDP_VERSION'] = '1' From 5f0977079aaffda1cf30fb5ef1668f509d9b7b98 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 23 Jul 2025 03:09:11 +0000 Subject: [PATCH 10/15] formatted --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- compose_rl/algorithms/online/hf_utils.py | 4 ++-- compose_rl/algorithms/online/model.py | 5 ++--- compose_rl/algorithms/reward_modeling/model.py | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 5498c119..5ae40b59 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -28,6 +28,7 @@ import torch import torch.distributed import torch.nn as nn +from composer.distributed.shared_utils import get_summon_params_fn from composer.utils import dist from ray.exceptions import GetTimeoutError from ray.util.placement_group import placement_group @@ -51,7 +52,6 @@ ALGORITHM_TYPE, OnPolicyEnum, ) -from composer.distributed.shared_utils import get_summon_params_fn log = logging.getLogger(__name__) diff --git a/compose_rl/algorithms/online/hf_utils.py b/compose_rl/algorithms/online/hf_utils.py index 85c0aac4..db114db3 100644 --- a/compose_rl/algorithms/online/hf_utils.py +++ b/compose_rl/algorithms/online/hf_utils.py @@ -7,6 +7,8 @@ import torch import torch.nn as nn +from composer.distributed.shared_utils import get_summon_params_fn +from composer.utils import is_model_fsdp from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -22,8 +24,6 @@ ) from compose_rl.algorithms.online.policy_configuration import HFPolicyConfig from compose_rl.utils.consts import _MASTER_WEIGHTS_PRECISION -from composer.distributed.shared_utils import get_summon_params_fn -from composer.utils import is_model_fsdp Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 9264602c..154be9fa 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -8,8 +8,9 @@ from typing import Any, MutableMapping, Optional, Union import torch +from composer.distributed.shared_utils import get_summon_params_fn from composer.models import HuggingFaceModel -from composer.utils import dist +from composer.utils import dist, is_model_fsdp from llmfoundry.models import ComposerHFCausalLM from transformers import ( PreTrainedTokenizer, @@ -28,8 +29,6 @@ clear_mb_load_balancing_loss, get_mb_load_balancing_loss, ) -from composer.utils import is_model_fsdp -from composer.distributed.shared_utils import get_summon_params_fn Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] diff --git a/compose_rl/algorithms/reward_modeling/model.py 
b/compose_rl/algorithms/reward_modeling/model.py index 908711a2..0986a533 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -9,6 +9,8 @@ from typing import Any, Mapping, MutableMapping, Optional, Union import torch +from composer.distributed.shared_utils import get_summon_params_fn +from composer.utils import is_model_fsdp from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM from compose_rl.algorithms.reward_modeling.base_reward import ( @@ -30,8 +32,6 @@ ComposerHFSequenceClassification from compose_rl.algorithms.reward_modeling.modeling_mpt import \ MPTForSequenceClassification -from composer.utils import is_model_fsdp -from composer.distributed.shared_utils import get_summon_params_fn log = logging.getLogger(__name__) From 8f7ea67e6a4d47ba27a3a9a08aed092313c661c1 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 23 Jul 2025 04:21:52 +0000 Subject: [PATCH 11/15] updated version requirements to test out changes --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 043391af..1dbe1c9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,13 +23,13 @@ dependencies = [ [project.optional-dependencies] cpu = [ - 'llm-foundry[all-cpu]@git+https://github.com/mosaicml/llm-foundry.git@main#egg=llmfoundry', + 'llm-foundry[all-cpu]@git+https://github.com/mosaicml/llm-foundry.git@ricky-fsdp2-temp-version#egg=llmfoundry', ] gpu = [ - 'llm-foundry[all]@git+https://github.com/mosaicml/llm-foundry.git@main#egg=llmfoundry', + 'llm-foundry[all]@git+https://github.com/mosaicml/llm-foundry.git@ricky-fsdp2-temp-version#egg=llmfoundry', ] dev = [ - 'llm-foundry[dev]@git+https://github.com/mosaicml/llm-foundry.git@main#egg=llmfoundry', + 'llm-foundry[dev]@git+https://github.com/mosaicml/llm-foundry.git@ricky-fsdp2-temp-version#egg=llmfoundry', ] released = [ 'llm-foundry[all]>=0.21.0', From 244463d228cde10a5931f66ba0c1d86ade023aec Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Jul 2025 00:01:02 +0000 Subject: [PATCH 12/15] minor name change --- .../algorithms/online/generation_utils/vllm_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 5ae40b59..151b58de 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -352,12 +352,12 @@ def should_update_torch_module( return False -def get_path_to_param(model: nn.Module, param: torch.Tensor) -> str: - """Get the path to a parameter in the model. +def get_name_for_param(model: nn.Module, param: torch.Tensor) -> str: + """Get the full name of a parameter in the model. 
Args: - model (nn.Module): The model to get the path to - param (torch.Tensor): The parameter to get the path to + model (nn.Module): The model that contains the parameter + param (torch.Tensor): The parameter to get the name for """ for name, p in model.named_parameters(): if p is param: @@ -490,7 +490,7 @@ def broadcast_to_vllm( if isinstance(param, DTensor): continue - full_name = get_path_to_param(model, param) + full_name = get_name_for_param(model, param) parsed_name = simplify_param_path(full_name) if parsed_name in seen_updated_parsed_names: From 65c936c14bc10dd9858348a8a0435c6c85c65035 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Jul 2025 17:31:23 +0000 Subject: [PATCH 13/15] we don't support rank0_only=True anymore --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 151b58de..1a9afc0e 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -469,7 +469,7 @@ def broadcast_to_vllm( with summon_full_params( module, writeback=False, - rank0_only=True, + rank0_only=False, recurse=False, ): # Note: For the following module.named_parameters(), we have to use recurse=True From 19e8e3c43b4c6da03b5b2b6ff061b79c5081df18 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 28 Jul 2025 17:46:35 +0000 Subject: [PATCH 14/15] formatted + removed toml changes --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 1 + compose_rl/algorithms/online/model.py | 2 ++ pyproject.toml | 6 +++--- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 1a9afc0e..6f0e91c6 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -116,6 +116,7 @@ def init_process_group( class WorkerWrap: + def init_process_group( self, master_address: str, diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 154be9fa..1ed00e6a 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -36,6 +36,7 @@ class ComposerMPTPolicyLM(HuggingFaceModel): + def __init__( self, tokenizer: Tokenizer, @@ -135,6 +136,7 @@ def set_batch_stats(self, batch_stats: dict[str, Any]): class ComposerHFPolicyLM(ComposerHFPolicy): + def __init__( self, *, diff --git a/pyproject.toml b/pyproject.toml index 1dbe1c9d..043391af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,13 +23,13 @@ dependencies = [ [project.optional-dependencies] cpu = [ - 'llm-foundry[all-cpu]@git+https://github.com/mosaicml/llm-foundry.git@ricky-fsdp2-temp-version#egg=llmfoundry', + 'llm-foundry[all-cpu]@git+https://github.com/mosaicml/llm-foundry.git@main#egg=llmfoundry', ] gpu = [ - 'llm-foundry[all]@git+https://github.com/mosaicml/llm-foundry.git@ricky-fsdp2-temp-version#egg=llmfoundry', + 'llm-foundry[all]@git+https://github.com/mosaicml/llm-foundry.git@main#egg=llmfoundry', ] dev = [ - 'llm-foundry[dev]@git+https://github.com/mosaicml/llm-foundry.git@ricky-fsdp2-temp-version#egg=llmfoundry', + 'llm-foundry[dev]@git+https://github.com/mosaicml/llm-foundry.git@main#egg=llmfoundry', ] released = [ 'llm-foundry[all]>=0.21.0', From 0dbe84e043051182275be6be3989326a25fd59d8 Mon Sep 17 
00:00:00 2001 From: root Date: Mon, 28 Jul 2025 18:11:00 +0000 Subject: [PATCH 15/15] test if this allows the right imports --- .github/workflows/pr-cpu.yaml | 2 +- .github/workflows/pr-gpu.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 283c7ed1..05b47cc1 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -23,7 +23,7 @@ jobs: matrix: include: - name: "cpu-2.7.0" - container: mosaicml/pytorch:2.7.0_cpu-python3.12-ubuntu22.04 + container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cpu-python3.12-ubuntu22.04) markers: "not gpu and not only_release" pip_deps: "[cpu]" pytest_command: "coverage run -m pytest" diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 0cfc1784..efc5d4c4 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -23,7 +23,7 @@ jobs: matrix: include: - name: "gpu-2.7.0-1" - container: mosaicml/llm-foundry:2.7.0_cu128-latest + container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cu128-latest) markers: "gpu" pip_deps: "[gpu]" pytest_command: "coverage run -m pytest" @@ -52,7 +52,7 @@ jobs: matrix: include: - name: "gpu-2.7.0-2" - container: mosaicml/llm-foundry:2.7.0_cu128-latest + container: mosaicml/dle:nightly-latest # Update after the next release of llm-foundry (mosaicml/llm-foundry:2.7.0_cu128-latest) markers: "gpu" pip_deps: "[gpu]" pytest_command: "coverage run -m pytest"