From d1569d2f51809ab5f2cb2976d21f3e4e722dad6c Mon Sep 17 00:00:00 2001 From: NikhilNayak-debug <51783304+NikhilNayak-debug@users.noreply.github.com> Date: Tue, 15 Jul 2025 10:34:24 -0400 Subject: [PATCH 01/19] Add orthogonal subspace learning via SVD --- src/peft/__init__.py | 12 ++ src/peft/utils/__init__.py | 14 ++ src/peft/utils/svd_utils.py | 290 ++++++++++++++++++++++++++++++++++++ tests/test_svd_utils.py | 14 ++ 4 files changed, 330 insertions(+) create mode 100644 src/peft/utils/svd_utils.py create mode 100644 tests/test_svd_utils.py diff --git a/src/peft/__init__.py b/src/peft/__init__.py index b2fcbe901f..c4740cce85 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -108,11 +108,17 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, PeftType, TaskType, + auto_generate_target_svd_config, bloom_model_postprocess_past_key_value, cast_mixed_precision_params, + create_svd_model_class, + decompose_weight_matrix, get_peft_model_state_dict, load_peft_weights, + optim_wrapper, prepare_model_for_kbit_training, + project_gradient_to_orthogonal_space, + reconstruct_weight_matrix, replace_lora_weights_loftq, set_peft_model_state_dict, shift_tokens_right, @@ -200,8 +206,11 @@ "VeraModel", "XLoraConfig", "XLoraModel", + "auto_generate_target_svd_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", + "create_svd_model_class", + "decompose_weight_matrix", "get_eva_state_dict", "get_layer_status", "get_model_status", @@ -211,7 +220,10 @@ "initialize_lora_eva_weights", "inject_adapter_in_model", "load_peft_weights", + "optim_wrapper", "prepare_model_for_kbit_training", + "project_gradient_to_orthogonal_space", + "reconstruct_weight_matrix", "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 3b992d8aac..29b25d4431 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -55,6 +55,14 @@ ) from .peft_types import PeftType, TaskType, register_peft_method from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict +from .svd_utils import ( + auto_generate_target_svd_config, + create_svd_model_class, + decompose_weight_matrix, + optim_wrapper, + project_gradient_to_orthogonal_space, + reconstruct_weight_matrix, +) __all__ = [ @@ -86,8 +94,11 @@ "_prepare_prompt_learning_config", "_set_adapter", "_set_trainable", + "auto_generate_target_svd_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", + "create_svd_model_class", + "decompose_weight_matrix", "get_auto_gptq_quant_linear", "get_gptqmodel_quant_linear", "get_peft_model_state_dict", @@ -96,7 +107,10 @@ "infer_device", "load_peft_weights", "map_cache_to_layer_device_map", + "optim_wrapper", "prepare_model_for_kbit_training", + "project_gradient_to_orthogonal_space", + "reconstruct_weight_matrix", "register_peft_method", "replace_lora_weights_loftq", "set_additional_trainable_modules", diff --git a/src/peft/utils/svd_utils.py b/src/peft/utils/svd_utils.py new file mode 100644 index 0000000000..b81342e0bd --- /dev/null +++ b/src/peft/utils/svd_utils.py @@ -0,0 +1,290 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for Orthogonal Subspace Learning with adaptive SVD.""" + +from __future__ import annotations + +import math +from typing import Any, Optional + +import torch +from torch import nn +from torch.nn import functional as F + + +__all__ = [ + "auto_generate_target_svd_config", + "create_svd_model_class", + "decompose_weight_matrix", + "optim_wrapper", + "project_gradient_to_orthogonal_space", + "reconstruct_weight_matrix", +] + + +def decompose_weight_matrix(weight: torch.Tensor, top_k: int) -> dict[str, Any]: + """Perform an SVD of ``weight`` and split it into frozen and trainable parts.""" + device_local = weight.device + orig_dtype = weight.dtype + W = weight.to(torch.float32) + U, S, Vt = torch.linalg.svd(W, full_matrices=False) + k = min(top_k, S.shape[0]) + + svd = { + "U_high": U[:, :k].contiguous().detach().to(device=device_local, dtype=orig_dtype), + "S_high": S[:k].contiguous().detach().to(device=device_local, dtype=orig_dtype), + "V_high": Vt[:k, :].contiguous().detach().to(device=device_local, dtype=orig_dtype), + "U_low": nn.Parameter(U[:, k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)), + "S_low": nn.Parameter(S[k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)), + "V_low": nn.Parameter(Vt[k:, :].contiguous().detach().to(device=device_local, dtype=orig_dtype)), + "rank_high": k, + } + return svd + + +def reconstruct_weight_matrix(svd_dict: dict[str, torch.Tensor]) -> torch.Tensor: + """Reconstruct a weight matrix from its SVD components.""" + U_high = svd_dict["U_high"] + S_high = svd_dict["S_high"] + V_high = svd_dict["V_high"] + U_low = svd_dict["U_low"] + S_low = svd_dict["S_low"] + V_low = svd_dict["V_low"] + + high_part = ( + torch.mm(U_high * S_high.unsqueeze(0), V_high) + if U_high.numel() > 0 and S_high.numel() > 0 + else torch.zeros(U_low.size(0), V_low.size(1), device=U_high.device) + ) + low_part = ( + torch.mm(U_low * S_low.unsqueeze(0), V_low) + if U_low.numel() > 0 and S_low.numel() > 0 + else torch.zeros(U_high.size(0), V_high.size(1), device=U_low.device) + ) + return high_part + low_part + + +def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: + """Project gradients of ``U_low`` and ``V_low`` to be orthogonal to the high rank space.""" + if svd_dict["U_low"].grad is None and svd_dict["S_low"].grad is None and svd_dict["V_low"].grad is None: + return + + U_high = svd_dict["U_high"] + V_high = svd_dict["V_high"] + + if svd_dict["U_low"].grad is not None: + dU = svd_dict["U_low"].grad + local_U_high = getattr(U_high, "to_local", lambda: U_high)() + local_dU = getattr(dU, "to_local", lambda: dU)() + if local_U_high.size(0) != local_dU.size(0): + rank = torch.distributed.get_rank() + start = rank * local_dU.size(0) + end = start + local_dU.size(0) + local_U_high = local_U_high[start:end] + proj = local_U_high @ (local_U_high.transpose(0, 1) @ local_dU) + local_dU.sub_(proj) + if hasattr(dU, "_local_tensor"): + dU._local_tensor.copy_(local_dU) + else: + dU.copy_(local_dU) + + if svd_dict["V_low"].grad is not None: + dV = svd_dict["V_low"].grad + local_V_high = getattr(V_high, "to_local", 
lambda: V_high)() + local_dV = getattr(dV, "to_local", lambda: dV)() + if local_V_high.size(1) != local_dV.size(1): + rank = torch.distributed.get_rank() + start = rank * local_dV.size(1) + end = start + local_dV.size(1) + local_V_high = local_V_high[:, start:end] + proj = (local_dV @ local_V_high.transpose(0, 1)) @ local_V_high + local_dV.sub_(proj) + if hasattr(dV, "_local_tensor"): + dV._local_tensor.copy_(local_dV) + else: + dV.copy_(local_dV) + + +def auto_generate_target_svd_config(model: nn.Module) -> dict[str, int]: + """Create a mapping from parameter names to ``top_k`` based on layer size.""" + target_patterns = [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.down_proj", + "mlp.up_proj", + ] + config: dict[str, int] = {} + for name, param in model.named_parameters(): + if any(pat in name for pat in target_patterns) and len(param.shape) == 2: + top_k = int(math.floor(min(param.shape) * 0.5)) + full_rank = min(param.shape) + if top_k >= full_rank: + top_k = full_rank - 1 + config[name] = top_k + return config + + +def create_svd_model_class(base_cls: type) -> type: + """Create a subclass of ``base_cls`` where selected linear weights are replaced by SVD decompositions.""" + + class ModelWithSVD(base_cls): + def __init__(self, config, svd_config: Optional[dict[str, int]] = None, initialize_svd: bool = True, **kwargs): + super().__init__(config, **kwargs) + self.svd_config = svd_config or {} + self.name_mapping: dict[str, str] = {} + self.svd_params = nn.ModuleDict() + if initialize_svd: + self._initialize_svd_parameters() + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, *model_args, svd_config: Optional[dict[str, int]] = None, **kwargs + ): + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + svd_config=svd_config or {}, + **kwargs, + ) + if svd_config is None: + svd_config = auto_generate_target_svd_config(model) + model.svd_config = svd_config + model.reinitialize_svd() + return model + + def reinitialize_svd(self) -> None: + self.name_mapping = {} + self.svd_params = nn.ModuleDict() + self._initialize_svd_parameters() + + def _get_module_by_name(self, name: str): + parts = name.split(".") + attr = parts[-1] + mod = self + for p in parts[:-1]: + if hasattr(mod, p): + mod = getattr(mod, p) + elif p.isdigit(): + mod = mod[int(p)] + else: + return None, None + return mod, attr + + def _initialize_svd_parameters(self) -> None: + for name, param in list(self.named_parameters()): + if len(param.shape) == 2 and name in self.svd_config and self.svd_config[name] > 0: + top_k = self.svd_config[name] + svd_dict = decompose_weight_matrix(param.data, top_k=top_k) + safe_name = name.replace(".", "_") + self.name_mapping[name] = safe_name + self.register_buffer(f"{safe_name}_U_high", svd_dict["U_high"]) + self.register_buffer(f"{safe_name}_S_high", svd_dict["S_high"]) + self.register_buffer(f"{safe_name}_V_high", svd_dict["V_high"]) + + module_svd = nn.Module() + module_svd.U_low = svd_dict["U_low"] + module_svd.S_low = svd_dict["S_low"] + module_svd.V_low = svd_dict["V_low"] + module_svd.rank_high = svd_dict["rank_high"] + module_svd.safe_name = safe_name + self.svd_params[safe_name] = module_svd + + mod, attr = self._get_module_by_name(name) + bias = mod.bias if hasattr(mod, "bias") else None + + def make_forward(sn: str, bias: Optional[torch.Tensor]): + def forward(x): + W = self._reconstruct_weight_by_safe_name(sn) + if W.dtype != x.dtype: + W = W.to(x.dtype) + 
return F.linear(x, W, bias) + + return forward + + mod.forward = make_forward(safe_name, bias) + param.requires_grad = False + mod._parameters.pop(attr, None) + + def _reconstruct_weight_by_safe_name(self, safe_name: str) -> torch.Tensor: + U_high = getattr(self, f"{safe_name}_U_high") + S_high = getattr(self, f"{safe_name}_S_high") + V_high = getattr(self, f"{safe_name}_V_high") + module_svd = self.svd_params[safe_name] + svd_dict = { + "U_high": U_high, + "S_high": S_high, + "V_high": V_high, + "U_low": module_svd.U_low, + "S_low": module_svd.S_low, + "V_low": module_svd.V_low, + } + return reconstruct_weight_matrix(svd_dict) + + def project_gradients(self) -> None: + for safe_name, module_svd in self.svd_params.items(): + svd_dict = { + "U_high": getattr(self, f"{safe_name}_U_high"), + "S_high": getattr(self, f"{safe_name}_S_high"), + "V_high": getattr(self, f"{safe_name}_V_high"), + "U_low": module_svd.U_low, + "S_low": module_svd.S_low, + "V_low": module_svd.V_low, + } + project_gradient_to_orthogonal_space(svd_dict) + + def prepare_state_dict_for_save(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + if not hasattr(self, "name_mapping"): + return state_dict + for orig, safe in self.name_mapping.items(): + U_high = state_dict.pop(f"{safe}_U_high") + S_high = state_dict.pop(f"{safe}_S_high") + V_high = state_dict.pop(f"{safe}_V_high") + U_low = state_dict.pop(f"svd_params.{safe}.U_low") + S_low = state_dict.pop(f"svd_params.{safe}.S_low") + V_low = state_dict.pop(f"svd_params.{safe}.V_low") + W = reconstruct_weight_matrix( + { + "U_high": U_high, + "S_high": S_high, + "V_high": V_high, + "U_low": U_low, + "S_low": S_low, + "V_low": V_low, + } + ) + state_dict[orig] = W + return state_dict + + ModelWithSVD.__name__ = f"{base_cls.__name__}WithSVD" + return ModelWithSVD + + +def optim_wrapper(optimizer: torch.optim.Optimizer, model: nn.Module) -> torch.optim.Optimizer: + """Wrap ``optimizer.step`` to project gradients before each update.""" + if not hasattr(model, "project_gradients"): + return optimizer + + import types + + orig_step = optimizer.step + + def step(self, *args, **kwargs): + model.project_gradients() + return orig_step(*args, **kwargs) + + optimizer.step = types.MethodType(step, optimizer) + return optimizer diff --git a/tests/test_svd_utils.py b/tests/test_svd_utils.py new file mode 100644 index 0000000000..24f8abee19 --- /dev/null +++ b/tests/test_svd_utils.py @@ -0,0 +1,14 @@ +import torch +from torch.testing import assert_close + +from peft.utils.svd_utils import ( + decompose_weight_matrix, + reconstruct_weight_matrix, +) + + +def test_svd_roundtrip(): + w = torch.randn(10, 8) + svd = decompose_weight_matrix(w, top_k=4) + w_rec = reconstruct_weight_matrix(svd) + assert_close(w_rec, w, atol=1e-5, rtol=1e-5) From 0a30d28f2b385b34fd902424a1f23b04966407bc Mon Sep 17 00:00:00 2001 From: NikhilNayak-debug <51783304+NikhilNayak-debug@users.noreply.github.com> Date: Tue, 15 Jul 2025 10:53:50 -0400 Subject: [PATCH 02/19] Add wrapper for SVD models --- src/peft/__init__.py | 2 ++ src/peft/utils/__init__.py | 2 ++ src/peft/utils/svd_utils.py | 31 +++++++++++++++++++++++++++---- tests/test_svd_utils.py | 25 +++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index c4740cce85..8d448c1ab5 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -122,6 +122,7 @@ replace_lora_weights_loftq, set_peft_model_state_dict, shift_tokens_right, + wrap_model_with_svd, ) @@ -227,4 
+228,5 @@ "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", + "wrap_model_with_svd", ] diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 29b25d4431..f4b5f789e2 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -62,6 +62,7 @@ optim_wrapper, project_gradient_to_orthogonal_space, reconstruct_weight_matrix, + wrap_model_with_svd, ) @@ -117,4 +118,5 @@ "set_peft_model_state_dict", "shift_tokens_right", "transpose", + "wrap_model_with_svd", ] diff --git a/src/peft/utils/svd_utils.py b/src/peft/utils/svd_utils.py index b81342e0bd..d173ca7446 100644 --- a/src/peft/utils/svd_utils.py +++ b/src/peft/utils/svd_utils.py @@ -16,7 +16,7 @@ from __future__ import annotations import math -from typing import Any, Optional +from typing import Any import torch from torch import nn @@ -30,6 +30,7 @@ "optim_wrapper", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", + "wrap_model_with_svd", ] @@ -142,7 +143,7 @@ def create_svd_model_class(base_cls: type) -> type: """Create a subclass of ``base_cls`` where selected linear weights are replaced by SVD decompositions.""" class ModelWithSVD(base_cls): - def __init__(self, config, svd_config: Optional[dict[str, int]] = None, initialize_svd: bool = True, **kwargs): + def __init__(self, config, svd_config: dict[str, int] | None = None, initialize_svd: bool = True, **kwargs): super().__init__(config, **kwargs) self.svd_config = svd_config or {} self.name_mapping: dict[str, str] = {} @@ -152,7 +153,7 @@ def __init__(self, config, svd_config: Optional[dict[str, int]] = None, initiali @classmethod def from_pretrained( - cls, pretrained_model_name_or_path, *model_args, svd_config: Optional[dict[str, int]] = None, **kwargs + cls, pretrained_model_name_or_path, *model_args, svd_config: dict[str, int] | None = None, **kwargs ): model = super().from_pretrained( pretrained_model_name_or_path, @@ -206,7 +207,7 @@ def _initialize_svd_parameters(self) -> None: mod, attr = self._get_module_by_name(name) bias = mod.bias if hasattr(mod, "bias") else None - def make_forward(sn: str, bias: Optional[torch.Tensor]): + def make_forward(sn: str, bias: torch.Tensor | None): def forward(x): W = self._reconstruct_weight_by_safe_name(sn) if W.dtype != x.dtype: @@ -288,3 +289,25 @@ def step(self, *args, **kwargs): optimizer.step = types.MethodType(step, optimizer) return optimizer + + +def wrap_model_with_svd(model: nn.Module, svd_config: dict[str, int] | None = None) -> nn.Module: + """Return a version of ``model`` where selected weights are decomposed using SVD. + + Parameters ---------- model: + The model to wrap. It must expose a ``config`` attribute that will be passed to the wrapped class' constructor. + svd_config: + A mapping from parameter names to ``top_k`` ranks. If not provided, it is automatically generated based on the + layer shapes using :func:`auto_generate_target_svd_config`. + + Returns ------- ``nn.Module`` + A new model instance with the same weights as ``model`` but with trainable low-rank parameters and frozen + high-rank components. 
+ """ + + svd_config = svd_config or auto_generate_target_svd_config(model) + SVDCls = create_svd_model_class(model.__class__) + wrapped = SVDCls(model.config, svd_config=svd_config, initialize_svd=False) + wrapped.load_state_dict(model.state_dict()) + wrapped.reinitialize_svd() + return wrapped diff --git a/tests/test_svd_utils.py b/tests/test_svd_utils.py index 24f8abee19..c48222de85 100644 --- a/tests/test_svd_utils.py +++ b/tests/test_svd_utils.py @@ -4,6 +4,7 @@ from peft.utils.svd_utils import ( decompose_weight_matrix, reconstruct_weight_matrix, + wrap_model_with_svd, ) @@ -12,3 +13,27 @@ def test_svd_roundtrip(): svd = decompose_weight_matrix(w, top_k=4) w_rec = reconstruct_weight_matrix(svd) assert_close(w_rec, w, atol=1e-5, rtol=1e-5) + + +class DummyConfig: + pass + + +class DummyModel(torch.nn.Module): + def __init__(self, config=None): + super().__init__() + self.config = config + self.linear = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.linear(x) + + +def test_wrap_model_with_svd_preserves_output(): + torch.manual_seed(0) + model = DummyModel(DummyConfig()) + x = torch.randn(2, 8) + y_ref = model(x) + wrapped = wrap_model_with_svd(model) + y = wrapped(x) + assert_close(y, y_ref, atol=1e-5, rtol=1e-5) From af65172302f9ce79ad3086c2141d6c004f213e55 Mon Sep 17 00:00:00 2001 From: NikhilNayak-debug <51783304+NikhilNayak-debug@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:17:32 -0400 Subject: [PATCH 03/19] docs: add adaptive SVD utilities --- docs/source/_toctree.yml | 2 ++ docs/source/package_reference/svd_utils.md | 26 +++++++++++++++++++ .../orthogonal_subspace_learning/README.md | 22 ++++++++++++++++ 3 files changed, 50 insertions(+) create mode 100644 docs/source/package_reference/svd_utils.md create mode 100644 examples/orthogonal_subspace_learning/README.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0d7b785d68..65c84d3f63 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -137,6 +137,8 @@ title: Model merge - local: package_reference/helpers title: Helpers + - local: package_reference/svd_utils + title: SVD utilities - local: package_reference/hotswap title: Hotswapping adapters title: Utilities diff --git a/docs/source/package_reference/svd_utils.md b/docs/source/package_reference/svd_utils.md new file mode 100644 index 0000000000..c1c4e537b4 --- /dev/null +++ b/docs/source/package_reference/svd_utils.md @@ -0,0 +1,26 @@ + + +# SVD utilities + +Helper functions for orthogonal subspace learning with adaptive SVD. + +[[autodoc]] utils.svd_utils.decompose_weight_matrix + - all + +[[autodoc]] utils.svd_utils.reconstruct_weight_matrix + - all + +[[autodoc]] utils.svd_utils.project_gradient_to_orthogonal_space + - all + +[[autodoc]] utils.svd_utils.auto_generate_target_svd_config + - all + +[[autodoc]] utils.svd_utils.create_svd_model_class + - all + +[[autodoc]] utils.svd_utils.wrap_model_with_svd + - all + +[[autodoc]] utils.svd_utils.optim_wrapper + - all diff --git a/examples/orthogonal_subspace_learning/README.md b/examples/orthogonal_subspace_learning/README.md new file mode 100644 index 0000000000..c2550bf168 --- /dev/null +++ b/examples/orthogonal_subspace_learning/README.md @@ -0,0 +1,22 @@ +# Orthogonal Subspace Learning with adaptive SVD + +This example shows how to wrap a pretrained model with SVD-decomposed weights to enable orthogonal subspace training. 
+ +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import wrap_model_with_svd, optim_wrapper + +model = AutoModelForCausalLM.from_pretrained("gpt2") +model = wrap_model_with_svd(model) # add trainable low-rank parameters + +optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) +optimizer = optim_wrapper(optimizer, model) + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +input_ids = tokenizer("Hello world", return_tensors="pt").input_ids +loss = model(input_ids, labels=input_ids).loss +loss.backward() +optimizer.step() +optimizer.zero_grad() +``` From bc3bb8804fca6d8150ea93c7c00d75a92fe317d1 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 23 Jul 2025 14:47:59 +0000 Subject: [PATCH 04/19] changes for unified OSF method implementation in PEFT and documentation changes --- docs/source/_toctree.yml | 4 +- docs/source/package_reference/osf_utils.md | 26 ++++++++ docs/source/package_reference/svd_utils.md | 26 -------- .../orthogonal_subspace_learning/README.md | 6 +- src/peft/__init__.py | 16 +++-- src/peft/tuners/__init__.py | 2 + src/peft/tuners/osf/__init__.py | 13 ++++ src/peft/tuners/osf/config.py | 20 ++++++ src/peft/tuners/osf/model.py | 62 +++++++++++++++++++ 9 files changed, 138 insertions(+), 37 deletions(-) create mode 100644 docs/source/package_reference/osf_utils.md delete mode 100644 docs/source/package_reference/svd_utils.md create mode 100644 src/peft/tuners/osf/__init__.py create mode 100644 src/peft/tuners/osf/config.py create mode 100644 src/peft/tuners/osf/model.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 65c84d3f63..fb71541940 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -137,8 +137,8 @@ title: Model merge - local: package_reference/helpers title: Helpers - - local: package_reference/svd_utils - title: SVD utilities + - local: package_reference/osf_utils + title: OSF utilities - local: package_reference/hotswap title: Hotswapping adapters title: Utilities diff --git a/docs/source/package_reference/osf_utils.md b/docs/source/package_reference/osf_utils.md new file mode 100644 index 0000000000..4ea210a827 --- /dev/null +++ b/docs/source/package_reference/osf_utils.md @@ -0,0 +1,26 @@ + + +# OSF utilities + +Helper functions for orthogonal subspace learning with Adaptive OSF. + +[[autodoc]] utils.osf_utils.decompose_weight_matrix + - all + +[[autodoc]] utils.osf_utils.reconstruct_weight_matrix + - all + +[[autodoc]] utils.osf_utils.project_gradient_to_orthogonal_space + - all + +[[autodoc]] utils.osf_utils.auto_generate_target_osf_config + - all + +[[autodoc]] utils.osf_utils.create_osf_model_class + - all + +[[autodoc]] utils.osf_utils.wrap_model_with_osf + - all + +[[autodoc]] utils.osf_utils.optim_wrapper + - all \ No newline at end of file diff --git a/docs/source/package_reference/svd_utils.md b/docs/source/package_reference/svd_utils.md deleted file mode 100644 index c1c4e537b4..0000000000 --- a/docs/source/package_reference/svd_utils.md +++ /dev/null @@ -1,26 +0,0 @@ - - -# SVD utilities - -Helper functions for orthogonal subspace learning with adaptive SVD. 
- -[[autodoc]] utils.svd_utils.decompose_weight_matrix - - all - -[[autodoc]] utils.svd_utils.reconstruct_weight_matrix - - all - -[[autodoc]] utils.svd_utils.project_gradient_to_orthogonal_space - - all - -[[autodoc]] utils.svd_utils.auto_generate_target_svd_config - - all - -[[autodoc]] utils.svd_utils.create_svd_model_class - - all - -[[autodoc]] utils.svd_utils.wrap_model_with_svd - - all - -[[autodoc]] utils.svd_utils.optim_wrapper - - all diff --git a/examples/orthogonal_subspace_learning/README.md b/examples/orthogonal_subspace_learning/README.md index c2550bf168..911fd59640 100644 --- a/examples/orthogonal_subspace_learning/README.md +++ b/examples/orthogonal_subspace_learning/README.md @@ -1,14 +1,14 @@ -# Orthogonal Subspace Learning with adaptive SVD +# Orthogonal Subspace Learning with Adaptive OSF This example shows how to wrap a pretrained model with SVD-decomposed weights to enable orthogonal subspace training. ```python import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import wrap_model_with_svd, optim_wrapper +from peft import wrap_model_with_osf, optim_wrapper model = AutoModelForCausalLM.from_pretrained("gpt2") -model = wrap_model_with_svd(model) # add trainable low-rank parameters +model = wrap_model_with_osf(model) # add trainable low-rank parameters optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) optimizer = optim_wrapper(optimizer, model) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 8d448c1ab5..5822871290 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -95,6 +95,8 @@ ShiraModel, TrainableTokensConfig, TrainableTokensModel, + OSFConfig, + OSFModel, VBLoRAConfig, VBLoRAModel, VeraConfig, @@ -108,10 +110,10 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, PeftType, TaskType, - auto_generate_target_svd_config, + auto_generate_target_osf_config, bloom_model_postprocess_past_key_value, cast_mixed_precision_params, - create_svd_model_class, + create_osf_model_class, decompose_weight_matrix, get_peft_model_state_dict, load_peft_weights, @@ -122,7 +124,7 @@ replace_lora_weights_loftq, set_peft_model_state_dict, shift_tokens_right, - wrap_model_with_svd, + wrap_model_with_osf, ) @@ -200,6 +202,8 @@ "TaskType", "TrainableTokensConfig", "TrainableTokensModel", + "OSFConfig", + "OSFModel", "VBLoRAConfig", "VBLoRAConfig", "VBLoRAModel", @@ -207,10 +211,10 @@ "VeraModel", "XLoraConfig", "XLoraModel", - "auto_generate_target_svd_config", + "auto_generate_target_osf_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", - "create_svd_model_class", + "create_osf_model_class", "decompose_weight_matrix", "get_eva_state_dict", "get_layer_status", @@ -228,5 +232,5 @@ "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", - "wrap_model_with_svd", + "wrap_model_with_osf", ] diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index f758499e12..3d0a94329a 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -100,6 +100,8 @@ "ShiraModel", "TrainableTokensConfig", "TrainableTokensModel", + "OSFConfig", + "OSFModel", "VBLoRAConfig", "VBLoRAModel", "VeraConfig", diff --git a/src/peft/tuners/osf/__init__.py b/src/peft/tuners/osf/__init__.py new file mode 100644 index 0000000000..4cf83ac38b --- /dev/null +++ b/src/peft/tuners/osf/__init__.py @@ -0,0 +1,13 @@ +from peft.utils import register_peft_method + +from .config import OSFConfig +from .model import OSFModel + +__all__ = ["OSFConfig", "OSFModel"] + 
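+# Register OSF in PEFT's method registry so that PeftType.OSF resolves to OSFConfig/OSFModel,
+# e.g. when calling get_peft_model or loading a saved adapter.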
+register_peft_method( + name="osf", + config_cls=OSFConfig, + model_cls=OSFModel, + is_mixed_compatible=False, +) \ No newline at end of file diff --git a/src/peft/tuners/osf/config.py b/src/peft/tuners/osf/config.py new file mode 100644 index 0000000000..5092eebd9d --- /dev/null +++ b/src/peft/tuners/osf/config.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Dict + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class OSFConfig(PeftConfig): + """Configuration for Orthogonal Subspace Fine-tuning (OSF).""" + + target_svd_config: Optional[Dict[str, int]] = field( + default=None, + metadata={"help": "Mapping from parameter names to top_k rank."}, + ) + + def __post_init__(self): + self.peft_type = PeftType.OSF \ No newline at end of file diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py new file mode 100644 index 0000000000..94c8bce48b --- /dev/null +++ b/src/peft/tuners/osf/model.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import torch +import torch.nn as nn + +from peft.tuners.tuners_utils import BaseTuner +from peft.utils.osf_utils import ( + auto_generate_target_osf_config, + create_osf_model_class, +) + + +class OSFModel(BaseTuner): + """A minimal tuner implementing Orthogonal Subspace Fine-tuning.""" + + prefix: str = "osf_" + + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False): + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + + def _prepare_adapter_config(self, peft_config, model_config): + return peft_config + + def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False) -> None: + svd_cfg = self.peft_config[adapter_name].target_svd_config + if svd_cfg is None: + svd_cfg = auto_generate_target_osf_config(model) + self.peft_config[adapter_name].target_svd_config = svd_cfg + OSFCls = create_osf_model_class(model.__class__) + osf_model = OSFCls(model.config, svd_config=svd_cfg, initialize_svd=False) + osf_model.load_state_dict(model.state_dict()) + osf_model.reinitialize_svd() + self.model = osf_model + + def _create_and_replace(self, *args, **kwargs): + pass + + def _check_target_module_exists(self, *args, **kwargs) -> bool: + return True + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if "svd_params" not in n and not n.endswith(("_U_low", "_S_low", "_V_low")): + p.requires_grad = False + + def _set_adapter_layers(self, enabled: bool = True) -> None: + pass + + def enable_adapter_layers(self) -> None: + self._set_adapter_layers(True) + + def disable_adapter_layers(self) -> None: + self._set_adapter_layers(False) + + def set_adapter(self, adapter_name): + self.active_adapter = adapter_name + + def unload(self): + raise NotImplementedError("OSF models cannot be unloaded yet") + + def merge_and_unload(self, *args, **kwargs): + raise NotImplementedError("OSF models do not support merging") \ No newline at end of file From 0e400893a5c1e5cdd45769c74a3ed1af989c720c Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Mon, 28 Jul 2025 21:06:24 +0000 Subject: [PATCH 05/19] naming changes and compatibility --- src/peft/utils/__init__.py | 14 ++++----- src/peft/utils/{svd_utils.py => osf_utils.py} | 30 +++++++++---------- src/peft/utils/peft_types.py | 2 ++ tests/test_custom_models.py | 5 ++++ .../{test_svd_utils.py => test_osf_utils.py} | 
10 +++---- 5 files changed, 34 insertions(+), 27 deletions(-) rename src/peft/utils/{svd_utils.py => osf_utils.py} (93%) rename tests/{test_svd_utils.py => test_osf_utils.py} (80%) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index f4b5f789e2..5a65eb8799 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -55,14 +55,14 @@ ) from .peft_types import PeftType, TaskType, register_peft_method from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict -from .svd_utils import ( - auto_generate_target_svd_config, - create_svd_model_class, +from .osf_utils import ( + auto_generate_target_osf_config, + create_osf_model_class, decompose_weight_matrix, optim_wrapper, project_gradient_to_orthogonal_space, reconstruct_weight_matrix, - wrap_model_with_svd, + wrap_model_with_osf, ) @@ -95,10 +95,10 @@ "_prepare_prompt_learning_config", "_set_adapter", "_set_trainable", - "auto_generate_target_svd_config", + "auto_generate_target_osf_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", - "create_svd_model_class", + "create_osf_model_class", "decompose_weight_matrix", "get_auto_gptq_quant_linear", "get_gptqmodel_quant_linear", @@ -118,5 +118,5 @@ "set_peft_model_state_dict", "shift_tokens_right", "transpose", - "wrap_model_with_svd", + "wrap_model_with_osf", ] diff --git a/src/peft/utils/svd_utils.py b/src/peft/utils/osf_utils.py similarity index 93% rename from src/peft/utils/svd_utils.py rename to src/peft/utils/osf_utils.py index d173ca7446..15e6b08a9b 100644 --- a/src/peft/utils/svd_utils.py +++ b/src/peft/utils/osf_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Utilities for Orthogonal Subspace Learning with adaptive SVD.""" +"""Utilities for Orthogonal Subspace Learning with Adaptive OSF.""" from __future__ import annotations @@ -24,13 +24,13 @@ __all__ = [ - "auto_generate_target_svd_config", - "create_svd_model_class", + "auto_generate_target_osf_config", + "create_osf_model_class", "decompose_weight_matrix", "optim_wrapper", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", - "wrap_model_with_svd", + "wrap_model_with_osf", ] @@ -117,7 +117,7 @@ def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: dV.copy_(local_dV) -def auto_generate_target_svd_config(model: nn.Module) -> dict[str, int]: +def auto_generate_target_osf_config(model: nn.Module) -> dict[str, int]: """Create a mapping from parameter names to ``top_k`` based on layer size.""" target_patterns = [ "self_attn.q_proj", @@ -139,10 +139,10 @@ def auto_generate_target_svd_config(model: nn.Module) -> dict[str, int]: return config -def create_svd_model_class(base_cls: type) -> type: +def create_osf_model_class(base_cls: type) -> type: """Create a subclass of ``base_cls`` where selected linear weights are replaced by SVD decompositions.""" - class ModelWithSVD(base_cls): + class ModelWithOSF(base_cls): def __init__(self, config, svd_config: dict[str, int] | None = None, initialize_svd: bool = True, **kwargs): super().__init__(config, **kwargs) self.svd_config = svd_config or {} @@ -162,7 +162,7 @@ def from_pretrained( **kwargs, ) if svd_config is None: - svd_config = auto_generate_target_svd_config(model) + svd_config = auto_generate_target_osf_config(model) model.svd_config = svd_config model.reinitialize_svd() return model @@ -270,8 +270,8 @@ def prepare_state_dict_for_save(self, state_dict: dict[str, torch.Tensor]) -> di state_dict[orig] = W return state_dict - ModelWithSVD.__name__ = f"{base_cls.__name__}WithSVD" - return ModelWithSVD + ModelWithOSF.__name__ = f"{base_cls.__name__}WithOSF" + return ModelWithOSF def optim_wrapper(optimizer: torch.optim.Optimizer, model: nn.Module) -> torch.optim.Optimizer: @@ -291,23 +291,23 @@ def step(self, *args, **kwargs): return optimizer -def wrap_model_with_svd(model: nn.Module, svd_config: dict[str, int] | None = None) -> nn.Module: +def wrap_model_with_osf(model: nn.Module, svd_config: dict[str, int] | None = None) -> nn.Module: """Return a version of ``model`` where selected weights are decomposed using SVD. Parameters ---------- model: The model to wrap. It must expose a ``config`` attribute that will be passed to the wrapped class' constructor. svd_config: A mapping from parameter names to ``top_k`` ranks. If not provided, it is automatically generated based on the - layer shapes using :func:`auto_generate_target_svd_config`. + layer shapes using :func:`auto_generate_target_osf_config`. Returns ------- ``nn.Module`` A new model instance with the same weights as ``model`` but with trainable low-rank parameters and frozen high-rank components. 
""" - svd_config = svd_config or auto_generate_target_svd_config(model) - SVDCls = create_svd_model_class(model.__class__) - wrapped = SVDCls(model.config, svd_config=svd_config, initialize_svd=False) + svd_config = svd_config or auto_generate_target_osf_config(model) + OSFCls = create_osf_model_class(model.__class__) + wrapped = OSFCls(model.config, svd_config=svd_config, initialize_svd=False) wrapped.load_state_dict(model.state_dict()) wrapped.reinitialize_svd() return wrapped diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 6e4aeae248..519756f1d9 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -43,6 +43,7 @@ class PeftType(str, enum.Enum): - RANDLORA - SHIRA - C3A + - OSF """ PROMPT_TUNING = "PROMPT_TUNING" @@ -70,6 +71,7 @@ class PeftType(str, enum.Enum): TRAINABLE_TOKENS = "TRAINABLE_TOKENS" SHIRA = "SHIRA" C3A = "C3A" + OSF = "OSF" class TaskType(str, enum.Enum): diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 4b5d254afa..735398033f 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -49,6 +49,7 @@ ShiraConfig, TaskType, TrainableTokensConfig, + OSFConfig, VBLoRAConfig, VeraConfig, get_peft_model, @@ -621,6 +622,10 @@ TrainableTokensConfig, {"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False}, ), + ################################ + # Orthogonal Subspace Learning # + ################################ + ("Vanilla MLP OSF", "MLP", OSFConfig, {}), ######## # RandLora # ######## diff --git a/tests/test_svd_utils.py b/tests/test_osf_utils.py similarity index 80% rename from tests/test_svd_utils.py rename to tests/test_osf_utils.py index c48222de85..2da13555d4 100644 --- a/tests/test_svd_utils.py +++ b/tests/test_osf_utils.py @@ -1,14 +1,14 @@ import torch from torch.testing import assert_close -from peft.utils.svd_utils import ( +from peft.utils.osf_utils import ( decompose_weight_matrix, reconstruct_weight_matrix, - wrap_model_with_svd, + wrap_model_with_osf, ) -def test_svd_roundtrip(): +def test_osf_roundtrip(): w = torch.randn(10, 8) svd = decompose_weight_matrix(w, top_k=4) w_rec = reconstruct_weight_matrix(svd) @@ -29,11 +29,11 @@ def forward(self, x): return self.linear(x) -def test_wrap_model_with_svd_preserves_output(): +def test_wrap_model_with_osf_preserves_output(): torch.manual_seed(0) model = DummyModel(DummyConfig()) x = torch.randn(2, 8) y_ref = model(x) - wrapped = wrap_model_with_svd(model) + wrapped = wrap_model_with_osf(model) y = wrapped(x) assert_close(y, y_ref, atol=1e-5, rtol=1e-5) From 8f77e5771871533f0b72ed5066ab7614bb7e90c7 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Thu, 31 Jul 2025 13:25:39 +0000 Subject: [PATCH 06/19] unifying implementation with other PEFT methods and added gradient hooks --- docs/source/package_reference/osf_utils.md | 5 +- .../orthogonal_subspace_learning/README.md | 5 +- src/peft/__init__.py | 6 +- src/peft/tuners/osf/config.py | 3 +- src/peft/tuners/osf/model.py | 7 +- src/peft/utils/__init__.py | 6 +- src/peft/utils/osf_utils.py | 68 +++++++------------ tests/test_osf_utils.py | 6 +- 8 files changed, 42 insertions(+), 64 deletions(-) diff --git a/docs/source/package_reference/osf_utils.md b/docs/source/package_reference/osf_utils.md index 4ea210a827..d54bb7e66c 100644 --- a/docs/source/package_reference/osf_utils.md +++ b/docs/source/package_reference/osf_utils.md @@ -19,8 +19,5 @@ Helper functions for orthogonal subspace learning with Adaptive OSF. 
[[autodoc]] utils.osf_utils.create_osf_model_class - all -[[autodoc]] utils.osf_utils.wrap_model_with_osf - - all - -[[autodoc]] utils.osf_utils.optim_wrapper +[[autodoc]] utils.osf_utils.attach_gradient_hooks - all \ No newline at end of file diff --git a/examples/orthogonal_subspace_learning/README.md b/examples/orthogonal_subspace_learning/README.md index 911fd59640..f411b084a5 100644 --- a/examples/orthogonal_subspace_learning/README.md +++ b/examples/orthogonal_subspace_learning/README.md @@ -5,13 +5,12 @@ This example shows how to wrap a pretrained model with SVD-decomposed weights to ```python import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import wrap_model_with_osf, optim_wrapper +from peft import OSFConfig, get_peft_model model = AutoModelForCausalLM.from_pretrained("gpt2") -model = wrap_model_with_osf(model) # add trainable low-rank parameters +model = get_peft_model(model, OSFConfig()) # add trainable low-rank parameters optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) -optimizer = optim_wrapper(optimizer, model) tokenizer = AutoTokenizer.from_pretrained("gpt2") input_ids = tokenizer("Hello world", return_tensors="pt").input_ids diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 5822871290..72cc76e9fd 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -110,6 +110,7 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, PeftType, TaskType, + attach_gradient_hooks, auto_generate_target_osf_config, bloom_model_postprocess_past_key_value, cast_mixed_precision_params, @@ -117,14 +118,12 @@ decompose_weight_matrix, get_peft_model_state_dict, load_peft_weights, - optim_wrapper, prepare_model_for_kbit_training, project_gradient_to_orthogonal_space, reconstruct_weight_matrix, replace_lora_weights_loftq, set_peft_model_state_dict, shift_tokens_right, - wrap_model_with_osf, ) @@ -211,6 +210,7 @@ "VeraModel", "XLoraConfig", "XLoraModel", + "attach_gradient_hooks", "auto_generate_target_osf_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", @@ -225,12 +225,10 @@ "initialize_lora_eva_weights", "inject_adapter_in_model", "load_peft_weights", - "optim_wrapper", "prepare_model_for_kbit_training", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", - "wrap_model_with_osf", ] diff --git a/src/peft/tuners/osf/config.py b/src/peft/tuners/osf/config.py index 5092eebd9d..3e3f14945b 100644 --- a/src/peft/tuners/osf/config.py +++ b/src/peft/tuners/osf/config.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional, Dict from peft.config import PeftConfig from peft.utils import PeftType @@ -11,7 +10,7 @@ class OSFConfig(PeftConfig): """Configuration for Orthogonal Subspace Fine-tuning (OSF).""" - target_svd_config: Optional[Dict[str, int]] = field( + target_svd_config: dict[str, int] | None = field( default=None, metadata={"help": "Mapping from parameter names to top_k rank."}, ) diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py index 94c8bce48b..f27098d47e 100644 --- a/src/peft/tuners/osf/model.py +++ b/src/peft/tuners/osf/model.py @@ -1,10 +1,10 @@ from __future__ import annotations -import torch import torch.nn as nn from peft.tuners.tuners_utils import BaseTuner from peft.utils.osf_utils import ( + attach_gradient_hooks, auto_generate_target_osf_config, create_osf_model_class, ) @@ -21,7 +21,9 @@ def 
__init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) def _prepare_adapter_config(self, peft_config, model_config): return peft_config - def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False) -> None: + def inject_adapter( + self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False + ) -> None: svd_cfg = self.peft_config[adapter_name].target_svd_config if svd_cfg is None: svd_cfg = auto_generate_target_osf_config(model) @@ -30,6 +32,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d osf_model = OSFCls(model.config, svd_config=svd_cfg, initialize_svd=False) osf_model.load_state_dict(model.state_dict()) osf_model.reinitialize_svd() + attach_gradient_hooks(osf_model) self.model = osf_model def _create_and_replace(self, *args, **kwargs): diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 5a65eb8799..fc7e9dc4f7 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -56,13 +56,12 @@ from .peft_types import PeftType, TaskType, register_peft_method from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict from .osf_utils import ( + attach_gradient_hooks, auto_generate_target_osf_config, create_osf_model_class, decompose_weight_matrix, - optim_wrapper, project_gradient_to_orthogonal_space, reconstruct_weight_matrix, - wrap_model_with_osf, ) @@ -95,6 +94,7 @@ "_prepare_prompt_learning_config", "_set_adapter", "_set_trainable", + "attach_gradient_hooks", "auto_generate_target_osf_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", @@ -108,7 +108,6 @@ "infer_device", "load_peft_weights", "map_cache_to_layer_device_map", - "optim_wrapper", "prepare_model_for_kbit_training", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", @@ -118,5 +117,4 @@ "set_peft_model_state_dict", "shift_tokens_right", "transpose", - "wrap_model_with_osf", ] diff --git a/src/peft/utils/osf_utils.py b/src/peft/utils/osf_utils.py index 15e6b08a9b..7d553b415f 100644 --- a/src/peft/utils/osf_utils.py +++ b/src/peft/utils/osf_utils.py @@ -24,13 +24,12 @@ __all__ = [ + "attach_gradient_hooks", "auto_generate_target_osf_config", "create_osf_model_class", "decompose_weight_matrix", - "optim_wrapper", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", - "wrap_model_with_osf", ] @@ -117,6 +116,29 @@ def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: dV.copy_(local_dV) +def attach_gradient_hooks(model: nn.Module) -> None: + """Attach gradient hooks to project OSF gradients automatically.""" + if not hasattr(model, "svd_params"): + return + for safe_name, module_svd in model.svd_params.items(): + svd_dict = { + "U_high": getattr(model, f"{safe_name}_U_high"), + "S_high": getattr(model, f"{safe_name}_S_high"), + "V_high": getattr(model, f"{safe_name}_V_high"), + "U_low": module_svd.U_low, + "S_low": module_svd.S_low, + "V_low": module_svd.V_low, + } + + def hook(grad, svd=svd_dict): + project_gradient_to_orthogonal_space(svd) + return grad + + module_svd.U_low.register_hook(hook) + module_svd.S_low.register_hook(hook) + module_svd.V_low.register_hook(hook) + + def auto_generate_target_osf_config(model: nn.Module) -> dict[str, int]: """Create a mapping from parameter names to ``top_k`` based on layer size.""" target_patterns = [ @@ -150,6 +172,7 @@ def __init__(self, config, 
svd_config: dict[str, int] | None = None, initialize_ self.svd_params = nn.ModuleDict() if initialize_svd: self._initialize_svd_parameters() + attach_gradient_hooks(self) @classmethod def from_pretrained( @@ -171,6 +194,7 @@ def reinitialize_svd(self) -> None: self.name_mapping = {} self.svd_params = nn.ModuleDict() self._initialize_svd_parameters() + attach_gradient_hooks(self) def _get_module_by_name(self, name: str): parts = name.split(".") @@ -218,7 +242,6 @@ def forward(x): mod.forward = make_forward(safe_name, bias) param.requires_grad = False - mod._parameters.pop(attr, None) def _reconstruct_weight_by_safe_name(self, safe_name: str) -> torch.Tensor: U_high = getattr(self, f"{safe_name}_U_high") @@ -272,42 +295,3 @@ def prepare_state_dict_for_save(self, state_dict: dict[str, torch.Tensor]) -> di ModelWithOSF.__name__ = f"{base_cls.__name__}WithOSF" return ModelWithOSF - - -def optim_wrapper(optimizer: torch.optim.Optimizer, model: nn.Module) -> torch.optim.Optimizer: - """Wrap ``optimizer.step`` to project gradients before each update.""" - if not hasattr(model, "project_gradients"): - return optimizer - - import types - - orig_step = optimizer.step - - def step(self, *args, **kwargs): - model.project_gradients() - return orig_step(*args, **kwargs) - - optimizer.step = types.MethodType(step, optimizer) - return optimizer - - -def wrap_model_with_osf(model: nn.Module, svd_config: dict[str, int] | None = None) -> nn.Module: - """Return a version of ``model`` where selected weights are decomposed using SVD. - - Parameters ---------- model: - The model to wrap. It must expose a ``config`` attribute that will be passed to the wrapped class' constructor. - svd_config: - A mapping from parameter names to ``top_k`` ranks. If not provided, it is automatically generated based on the - layer shapes using :func:`auto_generate_target_osf_config`. - - Returns ------- ``nn.Module`` - A new model instance with the same weights as ``model`` but with trainable low-rank parameters and frozen - high-rank components. 
- """ - - svd_config = svd_config or auto_generate_target_osf_config(model) - OSFCls = create_osf_model_class(model.__class__) - wrapped = OSFCls(model.config, svd_config=svd_config, initialize_svd=False) - wrapped.load_state_dict(model.state_dict()) - wrapped.reinitialize_svd() - return wrapped diff --git a/tests/test_osf_utils.py b/tests/test_osf_utils.py index 2da13555d4..74ebd0b60f 100644 --- a/tests/test_osf_utils.py +++ b/tests/test_osf_utils.py @@ -1,10 +1,10 @@ import torch from torch.testing import assert_close +from peft import OSFConfig, get_peft_model from peft.utils.osf_utils import ( decompose_weight_matrix, reconstruct_weight_matrix, - wrap_model_with_osf, ) @@ -29,11 +29,11 @@ def forward(self, x): return self.linear(x) -def test_wrap_model_with_osf_preserves_output(): +def test_osf_get_peft_model_preserves_output(): torch.manual_seed(0) model = DummyModel(DummyConfig()) x = torch.randn(2, 8) y_ref = model(x) - wrapped = wrap_model_with_osf(model) + wrapped = get_peft_model(model, OSFConfig()) y = wrapped(x) assert_close(y, y_ref, atol=1e-5, rtol=1e-5) From 49774d8d5e5e4004c7ed2db623d057bd20bd87e3 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 6 Aug 2025 14:22:56 +0000 Subject: [PATCH 07/19] adding test cases for various functionality of the method --- src/peft/tuners/osf/model.py | 9 ++++++- src/peft/utils/osf_utils.py | 13 ++++++++-- tests/test_config.py | 2 ++ tests/test_custom_models.py | 14 ++++++----- tests/test_osf_utils.py | 48 +++++++++++++++++++++++++++++++++++- 5 files changed, 76 insertions(+), 10 deletions(-) diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py index f27098d47e..03e22c5d1e 100644 --- a/src/peft/tuners/osf/model.py +++ b/src/peft/tuners/osf/model.py @@ -29,7 +29,8 @@ def inject_adapter( svd_cfg = auto_generate_target_osf_config(model) self.peft_config[adapter_name].target_svd_config = svd_cfg OSFCls = create_osf_model_class(model.__class__) - osf_model = OSFCls(model.config, svd_config=svd_cfg, initialize_svd=False) + base_cfg = getattr(model, "config", None) + osf_model = OSFCls(base_cfg, svd_config=svd_cfg, initialize_svd=False) osf_model.load_state_dict(model.state_dict()) osf_model.reinitialize_svd() attach_gradient_hooks(osf_model) @@ -61,5 +62,11 @@ def set_adapter(self, adapter_name): def unload(self): raise NotImplementedError("OSF models cannot be unloaded yet") + def merge_adapter(self, *args, **kwargs): + raise NotImplementedError("OSF models do not support merging") + + def unmerge_adapter(self, *args, **kwargs): + raise NotImplementedError("OSF models do not support merging") + def merge_and_unload(self, *args, **kwargs): raise NotImplementedError("OSF models do not support merging") \ No newline at end of file diff --git a/src/peft/utils/osf_utils.py b/src/peft/utils/osf_utils.py index 7d553b415f..b3eb67cca0 100644 --- a/src/peft/utils/osf_utils.py +++ b/src/peft/utils/osf_utils.py @@ -165,8 +165,17 @@ def create_osf_model_class(base_cls: type) -> type: """Create a subclass of ``base_cls`` where selected linear weights are replaced by SVD decompositions.""" class ModelWithOSF(base_cls): - def __init__(self, config, svd_config: dict[str, int] | None = None, initialize_svd: bool = True, **kwargs): - super().__init__(config, **kwargs) + def __init__( + self, config=None, svd_config: dict[str, int] | None = None, initialize_svd: bool = True, **kwargs + ): + if config is not None: + try: + super().__init__(config, **kwargs) + except TypeError: + super().__init__(**kwargs) + self.config = config + else: + 
super().__init__(**kwargs) self.svd_config = svd_config or {} self.name_mapping: dict[str, str] = {} self.svd_params = nn.ModuleDict() diff --git a/tests/test_config.py b/tests/test_config.py index 179496b6f3..65be51f89c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -33,6 +33,7 @@ LoraConfig, MultitaskPromptTuningConfig, OFTConfig, + OSFConfig, PeftConfig, PeftType, PolyConfig, @@ -60,6 +61,7 @@ (LoHaConfig, {}), (LoKrConfig, {}), (LoraConfig, {}), + (OSFConfig, {}), (MultitaskPromptTuningConfig, {}), (PolyConfig, {}), (PrefixTuningConfig, {}), diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 735398033f..5a9308b5be 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -44,6 +44,7 @@ LoKrConfig, LoraConfig, OFTConfig, + OSFConfig, PeftModel, RandLoraConfig, ShiraConfig, @@ -625,7 +626,8 @@ ################################ # Orthogonal Subspace Learning # ################################ - ("Vanilla MLP OSF", "MLP", OSFConfig, {}), + ("Vanilla MLP 1 OSF", "MLP", OSFConfig, {}), + ("Vanilla MLP 2 OSF", "MLP", OSFConfig, {"target_svd_config": {"lin0.weight": 5, "lin1.weight": 1}}), ######## # RandLora # ######## @@ -1313,7 +1315,7 @@ def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kw @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): pytest.skip( f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" ) @@ -1336,7 +1338,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): pytest.skip( f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" ) @@ -1351,7 +1353,7 @@ def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs) @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs): # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): pytest.skip( f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" ) @@ -1367,7 +1369,7 @@ def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, confi @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs): # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): pytest.skip( f"Skipping test for {model_id} as merging is not supported. 
(See https://github.com/huggingface/peft/pull/2403 for details)" ) @@ -1762,7 +1764,7 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs): # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): pytest.skip( f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" ) diff --git a/tests/test_osf_utils.py b/tests/test_osf_utils.py index 74ebd0b60f..5d1c1ddc9a 100644 --- a/tests/test_osf_utils.py +++ b/tests/test_osf_utils.py @@ -1,3 +1,6 @@ +from tempfile import TemporaryDirectory + +import pytest import torch from torch.testing import assert_close @@ -15,7 +18,7 @@ def test_osf_roundtrip(): assert_close(w_rec, w, atol=1e-5, rtol=1e-5) -class DummyConfig: +class DummyConfig(dict): pass @@ -37,3 +40,46 @@ def test_osf_get_peft_model_preserves_output(): wrapped = get_peft_model(model, OSFConfig()) y = wrapped(x) assert_close(y, y_ref, atol=1e-5, rtol=1e-5) + + +def test_osf_gradient_projection_hook(): + torch.manual_seed(0) + model = DummyModel(DummyConfig()) + cfg = OSFConfig(target_svd_config={"linear.weight": 2}) + wrapped = get_peft_model(model, cfg) + x = torch.randn(3, 8) + wrapped(x).sum().backward() + inner = wrapped.base_model.model + safe_name = next(iter(inner.svd_params)) + module_svd = inner.svd_params[safe_name] + U_high = getattr(inner, f"{safe_name}_U_high") + V_high = getattr(inner, f"{safe_name}_V_high") + assert_close( + U_high.T @ module_svd.U_low.grad, torch.zeros_like(U_high.T @ module_svd.U_low.grad), atol=1e-6, rtol=1e-6 + ) + assert_close( + module_svd.V_low.grad @ V_high.T, + torch.zeros_like(module_svd.V_low.grad @ V_high.T), + atol=1e-6, + rtol=1e-6, + ) + + +def test_osf_config_roundtrip(): + cfg = OSFConfig(target_svd_config={"linear.weight": 2}) + with TemporaryDirectory() as tmp: + cfg.save_pretrained(tmp) + loaded = OSFConfig.from_pretrained(tmp) + assert cfg.target_svd_config == loaded.target_svd_config + + +def test_osf_merge_unmerge_unsupported(): + model = DummyModel(DummyConfig()) + cfg = OSFConfig(target_svd_config={"linear.weight": 2}) + wrapped = get_peft_model(model, cfg) + with pytest.raises(NotImplementedError): + wrapped.merge_adapter() + with pytest.raises(NotImplementedError): + wrapped.unmerge_adapter() + with pytest.raises(NotImplementedError): + wrapped.merge_and_unload() \ No newline at end of file From 22df96d4921be5fc37d6c3490e225690914abe2a Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 6 Aug 2025 15:30:09 +0000 Subject: [PATCH 08/19] adding test for loading and saving OSFT model --- src/peft/tuners/__init__.py | 1 + tests/test_osf_utils.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 3d0a94329a..7fe83de048 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -36,6 +36,7 @@ from .mixed import MixedModel from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit from .oft import OFTConfig, OFTModel +from .osf import OSFConfig, OSFModel from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 
from .poly import PolyConfig, PolyModel from .prefix_tuning import PrefixEncoder, PrefixTuningConfig diff --git a/tests/test_osf_utils.py b/tests/test_osf_utils.py index 5d1c1ddc9a..1a8f27529f 100644 --- a/tests/test_osf_utils.py +++ b/tests/test_osf_utils.py @@ -4,7 +4,7 @@ import torch from torch.testing import assert_close -from peft import OSFConfig, get_peft_model +from peft import OSFConfig, PeftModel, get_peft_model from peft.utils.osf_utils import ( decompose_weight_matrix, reconstruct_weight_matrix, @@ -73,6 +73,39 @@ def test_osf_config_roundtrip(): assert cfg.target_svd_config == loaded.target_svd_config +def test_osf_save_and_load_model(): + torch.manual_seed(0) + model = DummyModel(DummyConfig()) + cfg = OSFConfig(target_svd_config={"linear.weight": 2}) + wrapped = get_peft_model(model, cfg) + x = torch.randn(2, 8) + y_ref = wrapped(x) + with TemporaryDirectory() as tmp: + wrapped.save_pretrained(tmp) + torch.manual_seed(0) + base = DummyModel(DummyConfig()) + loaded = PeftModel.from_pretrained(base, tmp) + y = loaded(x) + assert_close(y, y_ref, atol=1e-5, rtol=1e-5) + + +def test_osf_save_and_load_autogen_config(): + torch.manual_seed(0) + model = DummyModel(DummyConfig()) + wrapped = get_peft_model(model, OSFConfig()) + x = torch.randn(2, 8) + y_ref = wrapped(x) + original_cfg = wrapped.peft_config["default"].target_svd_config + with TemporaryDirectory() as tmp: + wrapped.save_pretrained(tmp) + torch.manual_seed(0) + base = DummyModel(DummyConfig()) + loaded = PeftModel.from_pretrained(base, tmp) + y = loaded(x) + assert_close(y, y_ref, atol=1e-5, rtol=1e-5) + assert original_cfg == loaded.peft_config["default"].target_svd_config + + def test_osf_merge_unmerge_unsupported(): model = DummyModel(DummyConfig()) cfg = OSFConfig(target_svd_config={"linear.weight": 2}) From aacbef2f3176f1649964dc8c74bf07b97e02e37d Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 14:54:35 +0000 Subject: [PATCH 09/19] adding more test cases for OSF method --- tests/test_custom_models.py | 32 ++++++---------------------- tests/test_decoder_models.py | 7 ++++++ tests/test_encoder_decoder_models.py | 7 ++++++ tests/testing_common.py | 22 +++++++++++++++++++ 4 files changed, 42 insertions(+), 26 deletions(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 5a9308b5be..f60ca54a36 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -58,7 +58,7 @@ from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import AuxiliaryTrainingWrapper, infer_device -from .testing_common import PeftCommonTester +from .testing_common import PeftCommonTester, _skip_if_merging_not_supported from .testing_utils import get_state_dict, require_non_cpu @@ -1314,11 +1314,7 @@ def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kw @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. 
(See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) config_kwargs = config_kwargs.copy() if issubclass(config_cls, LoraConfig): @@ -1337,11 +1333,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) config_kwargs = config_kwargs.copy() if issubclass(config_cls, LoraConfig): @@ -1352,11 +1344,7 @@ def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs) @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) # calling merge twice with the same arguments should not change the output config_kwargs = config_kwargs.copy() @@ -1368,11 +1356,7 @@ def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, confi @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) # calling merge twice with the same arguments should not change the output config_kwargs = config_kwargs.copy() @@ -1763,11 +1747,7 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs): - # https://github.com/huggingface/peft/pull/2403 - if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): - pytest.skip( - f"Skipping test for {model_id} as merging is not supported. 
(See https://github.com/huggingface/peft/pull/2403 for details)" - ) + _skip_if_merging_not_supported(model_id, config_cls) # same as test_disable_adapters, but with merging X = self.prepare_inputs_for_testing() diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 3e756c2f43..0780aac467 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -36,6 +36,7 @@ IA3Config, LoraConfig, OFTConfig, + OSFConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, @@ -221,6 +222,12 @@ "target_modules": None, }, ), + ( + OSFConfig, + { + "task_type": "CAUSAL_LM", + }, + ), ] diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index 3fca67683d..dff4e724f2 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -27,6 +27,7 @@ IA3Config, LoraConfig, OFTConfig, + OSFConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, @@ -186,6 +187,12 @@ "target_modules": None, }, ), + ( + OSFConfig, + { + "task_type": "SEQ_2_SEQ_LM", + }, + ), ] diff --git a/tests/testing_common.py b/tests/testing_common.py index 5602199d52..3f441ecd2c 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -43,6 +43,7 @@ LoKrConfig, LoraConfig, OFTConfig, + OSFConfig, PeftModel, PeftType, PrefixTuningConfig, @@ -234,6 +235,15 @@ def test_something(model_id, config_kwargs): raise +def _skip_if_merging_not_supported(model_id, config_cls): + """Skip tests for Conv2dGroups models or OSF configs where merging is not supported.""" + if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): + pytest.skip( + f"Skipping test for {model_id} with {config_cls} as merging is not supported. " + "(See https://github.com/huggingface/peft/pull/2403 for details)" + ) + + class PeftCommonTester: r""" A large testing suite for testing common functionality of the PEFT models. 
@@ -629,6 +639,8 @@ def _test_load_multiple_adapters(self, model_id, config_cls, config_kwargs): assert load_result2.missing_keys == [] def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VBLoRAConfig): # Merge layers only supported for LoRA and IA³ return pytest.skip(f"Test not applicable for {config_cls}") @@ -654,6 +666,8 @@ def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): _ = model.merge_and_unload() def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + if config_cls not in ( LoraConfig, IA3Config, @@ -737,6 +751,8 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): model = model.merge_and_unload(safe_merge=True) def _test_merge_layers(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + if issubclass(config_cls, PromptLearningConfig): return pytest.skip(f"Test not applicable for {config_cls}") @@ -819,6 +835,8 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol) def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + supported_peft_types = [ PeftType.LORA, PeftType.LOHA, @@ -899,6 +917,8 @@ def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged_adapter_default, logits_adapter_1, atol=1e-3, rtol=1e-3) def _test_merge_layers_is_idempotent(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + with hub_online_once(model_id): model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -921,6 +941,8 @@ def _test_merge_layers_is_idempotent(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_0, logits_1, atol=1e-6, rtol=1e-6) def _test_safe_merge(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) with hub_online_once(model_id): model = self.transformers_class.from_pretrained(model_id) From fdb6d73f8c3d05764d08b79faefaacf4332b6c72 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 18:08:58 +0000 Subject: [PATCH 10/19] moving OSF util methods to the appropriate directory --- src/peft/__init__.py | 12 ------------ src/peft/tuners/osf/model.py | 2 +- .../{utils/osf_utils.py => tuners/osf/utils.py} | 4 ++-- src/peft/utils/__init__.py | 14 -------------- tests/test_osf_utils.py | 2 +- 5 files changed, 4 insertions(+), 30 deletions(-) rename src/peft/{utils/osf_utils.py => tuners/osf/utils.py} (99%) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 72cc76e9fd..af953bbcca 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -110,17 +110,11 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, PeftType, TaskType, - attach_gradient_hooks, - auto_generate_target_osf_config, bloom_model_postprocess_past_key_value, cast_mixed_precision_params, - create_osf_model_class, - decompose_weight_matrix, get_peft_model_state_dict, load_peft_weights, prepare_model_for_kbit_training, - project_gradient_to_orthogonal_space, - reconstruct_weight_matrix, replace_lora_weights_loftq, set_peft_model_state_dict, shift_tokens_right, @@ -210,12 +204,8 @@ "VeraModel", 
"XLoraConfig", "XLoraModel", - "attach_gradient_hooks", - "auto_generate_target_osf_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", - "create_osf_model_class", - "decompose_weight_matrix", "get_eva_state_dict", "get_layer_status", "get_model_status", @@ -226,8 +216,6 @@ "inject_adapter_in_model", "load_peft_weights", "prepare_model_for_kbit_training", - "project_gradient_to_orthogonal_space", - "reconstruct_weight_matrix", "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py index 03e22c5d1e..9b88122c12 100644 --- a/src/peft/tuners/osf/model.py +++ b/src/peft/tuners/osf/model.py @@ -3,7 +3,7 @@ import torch.nn as nn from peft.tuners.tuners_utils import BaseTuner -from peft.utils.osf_utils import ( +from .utils import ( attach_gradient_hooks, auto_generate_target_osf_config, create_osf_model_class, diff --git a/src/peft/utils/osf_utils.py b/src/peft/tuners/osf/utils.py similarity index 99% rename from src/peft/utils/osf_utils.py rename to src/peft/tuners/osf/utils.py index b3eb67cca0..332b5f3889 100644 --- a/src/peft/utils/osf_utils.py +++ b/src/peft/tuners/osf/utils.py @@ -1,4 +1,4 @@ -# Copyright 2024-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -303,4 +303,4 @@ def prepare_state_dict_for_save(self, state_dict: dict[str, torch.Tensor]) -> di return state_dict ModelWithOSF.__name__ = f"{base_cls.__name__}WithOSF" - return ModelWithOSF + return ModelWithOSF \ No newline at end of file diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index fc7e9dc4f7..3b992d8aac 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -55,14 +55,6 @@ ) from .peft_types import PeftType, TaskType, register_peft_method from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict -from .osf_utils import ( - attach_gradient_hooks, - auto_generate_target_osf_config, - create_osf_model_class, - decompose_weight_matrix, - project_gradient_to_orthogonal_space, - reconstruct_weight_matrix, -) __all__ = [ @@ -94,12 +86,8 @@ "_prepare_prompt_learning_config", "_set_adapter", "_set_trainable", - "attach_gradient_hooks", - "auto_generate_target_osf_config", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", - "create_osf_model_class", - "decompose_weight_matrix", "get_auto_gptq_quant_linear", "get_gptqmodel_quant_linear", "get_peft_model_state_dict", @@ -109,8 +97,6 @@ "load_peft_weights", "map_cache_to_layer_device_map", "prepare_model_for_kbit_training", - "project_gradient_to_orthogonal_space", - "reconstruct_weight_matrix", "register_peft_method", "replace_lora_weights_loftq", "set_additional_trainable_modules", diff --git a/tests/test_osf_utils.py b/tests/test_osf_utils.py index 1a8f27529f..2efb44d12b 100644 --- a/tests/test_osf_utils.py +++ b/tests/test_osf_utils.py @@ -5,7 +5,7 @@ from torch.testing import assert_close from peft import OSFConfig, PeftModel, get_peft_model -from peft.utils.osf_utils import ( +from peft.tuners.osf.utils import ( decompose_weight_matrix, reconstruct_weight_matrix, ) From 7533cdfe59add59f0bccc754a55474b7ce04ed5b Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 19:02:41 +0000 Subject: [PATCH 11/19] removed redundant check while generating osf config --- 
src/peft/tuners/osf/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index 332b5f3889..e070b6b71f 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -153,10 +153,7 @@ def auto_generate_target_osf_config(model: nn.Module) -> dict[str, int]: config: dict[str, int] = {} for name, param in model.named_parameters(): if any(pat in name for pat in target_patterns) and len(param.shape) == 2: - top_k = int(math.floor(min(param.shape) * 0.5)) - full_rank = min(param.shape) - if top_k >= full_rank: - top_k = full_rank - 1 + top_k = min(param.shape) // 2 config[name] = top_k return config From 5b87c7d64929f90a592cb6b1203d233460edbc4a Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 19:42:25 +0000 Subject: [PATCH 12/19] handle async calls in OSF gradient projection --- src/peft/tuners/osf/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index e070b6b71f..49dbb124fc 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -33,6 +33,13 @@ ] +def _wait_if_async(tensor): + """Wait for AsyncCollectiveTensor if needed, otherwise return tensor as-is.""" + if hasattr(tensor, "wait"): + return tensor.wait() + return tensor + + def decompose_weight_matrix(weight: torch.Tensor, top_k: int) -> dict[str, Any]: """Perform an SVD of ``weight`` and split it into frozen and trainable parts.""" device_local = weight.device @@ -85,8 +92,8 @@ def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: if svd_dict["U_low"].grad is not None: dU = svd_dict["U_low"].grad - local_U_high = getattr(U_high, "to_local", lambda: U_high)() - local_dU = getattr(dU, "to_local", lambda: dU)() + local_U_high = _wait_if_async(getattr(U_high, "to_local", lambda: U_high)()) + local_dU = _wait_if_async(getattr(dU, "to_local", lambda: dU)()) if local_U_high.size(0) != local_dU.size(0): rank = torch.distributed.get_rank() start = rank * local_dU.size(0) @@ -101,8 +108,8 @@ def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: if svd_dict["V_low"].grad is not None: dV = svd_dict["V_low"].grad - local_V_high = getattr(V_high, "to_local", lambda: V_high)() - local_dV = getattr(dV, "to_local", lambda: dV)() + local_V_high = _wait_if_async(getattr(V_high, "to_local", lambda: V_high)()) + local_dV = _wait_if_async(getattr(dV, "to_local", lambda: dV)()) if local_V_high.size(1) != local_dV.size(1): rank = torch.distributed.get_rank() start = rank * local_dV.size(1) From 1dd5c687f37e70ac664987e5705825e95b38db82 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 19:49:33 +0000 Subject: [PATCH 13/19] removed unnecessary DTensor distinction --- src/peft/tuners/osf/utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index 49dbb124fc..a8f9c7b474 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -101,10 +101,7 @@ def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: local_U_high = local_U_high[start:end] proj = local_U_high @ (local_U_high.transpose(0, 1) @ local_dU) local_dU.sub_(proj) - if hasattr(dU, "_local_tensor"): - dU._local_tensor.copy_(local_dU) - else: - dU.copy_(local_dU) + dU.copy_(local_dU) if svd_dict["V_low"].grad is not None: dV = svd_dict["V_low"].grad @@ -117,10 +114,7 @@ def 
project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: local_V_high = local_V_high[:, start:end] proj = (local_dV @ local_V_high.transpose(0, 1)) @ local_V_high local_dV.sub_(proj) - if hasattr(dV, "_local_tensor"): - dV._local_tensor.copy_(local_dV) - else: - dV.copy_(local_dV) + dV.copy_(local_dV) def attach_gradient_hooks(model: nn.Module) -> None: From 3cb88fb2950f396ee09b922aeeea875a8ba5af34 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 19:58:00 +0000 Subject: [PATCH 14/19] simplifying gradient hook method --- src/peft/tuners/osf/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index a8f9c7b474..a5b9f5c68e 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -131,8 +131,8 @@ def attach_gradient_hooks(model: nn.Module) -> None: "V_low": module_svd.V_low, } - def hook(grad, svd=svd_dict): - project_gradient_to_orthogonal_space(svd) + def hook(grad): + project_gradient_to_orthogonal_space(svd_dict) return grad module_svd.U_low.register_hook(hook) From a0e445ea10d73b7466d2cc8fca357691cec1e918 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 20:07:20 +0000 Subject: [PATCH 15/19] fix: implement proper gradient hook management for OSF tuner --- src/peft/tuners/osf/utils.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index a5b9f5c68e..6cd1adea41 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -28,6 +28,7 @@ "auto_generate_target_osf_config", "create_osf_model_class", "decompose_weight_matrix", + "detach_gradient_hooks", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", ] @@ -121,6 +122,11 @@ def attach_gradient_hooks(model: nn.Module) -> None: """Attach gradient hooks to project OSF gradients automatically.""" if not hasattr(model, "svd_params"): return + + # Initialize hook_handles if not exists + if not hasattr(model, "hook_handles"): + model.hook_handles = [] + for safe_name, module_svd in model.svd_params.items(): svd_dict = { "U_high": getattr(model, f"{safe_name}_U_high"), @@ -135,9 +141,20 @@ def hook(grad): project_gradient_to_orthogonal_space(svd_dict) return grad - module_svd.U_low.register_hook(hook) - module_svd.S_low.register_hook(hook) - module_svd.V_low.register_hook(hook) + # Store hook handles for later cleanup + handle_u = module_svd.U_low.register_hook(hook) + handle_s = module_svd.S_low.register_hook(hook) + handle_v = module_svd.V_low.register_hook(hook) + + model.hook_handles.extend([handle_u, handle_s, handle_v]) + + +def detach_gradient_hooks(model: nn.Module) -> None: + """Remove all gradient hooks that were attached by attach_gradient_hooks.""" + if hasattr(model, "hook_handles"): + for handle in model.hook_handles: + handle.remove() + model.hook_handles.clear() def auto_generate_target_osf_config(model: nn.Module) -> dict[str, int]: @@ -198,6 +215,8 @@ def from_pretrained( return model def reinitialize_svd(self) -> None: + # Clean up existing hooks before reinitializing + detach_gradient_hooks(self) self.name_mapping = {} self.svd_params = nn.ModuleDict() self._initialize_svd_parameters() From d28b9d7e8330f56902f253bdb0c211312b111255 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 13 Aug 2025 20:19:20 +0000 Subject: [PATCH 16/19] adding model-specific constants for OSF target modules --- src/peft/tuners/osf/utils.py | 32 
+++++++++++++++++++++++--------- src/peft/utils/constants.py | 20 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index 6cd1adea41..91f66ef384 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -22,6 +22,8 @@ from torch import nn from torch.nn import functional as F +from peft.utils.constants import TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING + __all__ = [ "attach_gradient_hooks", @@ -159,15 +161,27 @@ def detach_gradient_hooks(model: nn.Module) -> None: def auto_generate_target_osf_config(model: nn.Module) -> dict[str, int]: """Create a mapping from parameter names to ``top_k`` based on layer size.""" - target_patterns = [ - "self_attn.q_proj", - "self_attn.k_proj", - "self_attn.v_proj", - "self_attn.o_proj", - "mlp.gate_proj", - "mlp.down_proj", - "mlp.up_proj", - ] + # Get model type and corresponding target modules + model_type = getattr(model.config, "model_type", None) + target_modules = TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING.get(model_type, []) + + # Fallback to hardcoded patterns if model type not found + if not target_modules: + target_patterns = [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.down_proj", + "mlp.up_proj", + ] + else: + # Convert module names to patterns that match the full parameter names + target_patterns = [f"self_attn.{mod}" if mod in ["q_proj", "k_proj", "v_proj", "o_proj"] + else f"mlp.{mod}" if mod in ["gate_proj", "down_proj", "up_proj"] + else mod for mod in target_modules] + config: dict[str, int] = {} for name, param in model.named_parameters(): if any(pat in name for pat in target_patterns) and len(param.shape) == 2: diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index a765e7b1f7..31948d351e 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -392,6 +392,26 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen3": ["q_proj", "v_proj"], } +TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING = { + "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "llama4": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "mixtral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gemma2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gemma3_text": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "qwen2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "qwen3": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "phi": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], + "gpt2": ["c_attn", "c_proj"], + "bloom": ["query_key_value", "dense_4h_to_h"], + "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + "gptj": ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out"], + "gpt_neox": ["query_key_value", "dense_4h_to_h"], + "falcon": ["query_key_value", "dense_4h_to_h"], + "gpt_bigcode": ["c_attn", "c_proj"], +} + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING = ( TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING # Leaving this for now but 
RandLoRA is flexible ) From 2eba741746d7536aff5bb01ce8a34c0c2d282e48 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Thu, 14 Aug 2025 15:40:31 +0000 Subject: [PATCH 17/19] refactor: implement minimal PEFT integration for OSF tuner --- src/peft/tuners/osf/__init__.py | 3 +- src/peft/tuners/osf/config.py | 25 ++- src/peft/tuners/osf/layer.py | 274 ++++++++++++++++++++++++++++++++ src/peft/tuners/osf/model.py | 52 ++++-- 4 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 src/peft/tuners/osf/layer.py diff --git a/src/peft/tuners/osf/__init__.py b/src/peft/tuners/osf/__init__.py index 4cf83ac38b..66e2517f46 100644 --- a/src/peft/tuners/osf/__init__.py +++ b/src/peft/tuners/osf/__init__.py @@ -1,9 +1,10 @@ from peft.utils import register_peft_method from .config import OSFConfig +from .layer import OSFLayer, Linear from .model import OSFModel -__all__ = ["OSFConfig", "OSFModel"] +__all__ = ["OSFConfig", "OSFModel", "OSFLayer", "Linear"] register_peft_method( name="osf", diff --git a/src/peft/tuners/osf/config.py b/src/peft/tuners/osf/config.py index 3e3f14945b..09dacfabe7 100644 --- a/src/peft/tuners/osf/config.py +++ b/src/peft/tuners/osf/config.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Optional, Union from peft.config import PeftConfig from peft.utils import PeftType @@ -8,11 +9,29 @@ @dataclass class OSFConfig(PeftConfig): - """Configuration for Orthogonal Subspace Fine-tuning (OSF).""" + """ + Configuration for Orthogonal Subspace Fine-tuning (OSF). + + Args: + effective_rank (`int`, *optional*): + The effective rank for OSF decomposition. If None, defaults to 50% of min(weight.shape). + target_modules (`Union[list[str], str]`, *optional*): + The names of the modules to apply OSF to. Can be a list of module names or 'all-linear'. + rank_pattern (`dict[str, int]`, *optional*): + A dictionary of regex patterns to override effective_rank for specific modules. + """ - target_svd_config: dict[str, int] | None = field( + effective_rank: Optional[int] = field( default=None, - metadata={"help": "Mapping from parameter names to top_k rank."}, + metadata={"help": "The effective rank for OSF decomposition. If None, defaults to 50% of min(weight.shape)."} + ) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={"help": "The names of the modules to apply OSF to. Can be a list of module names or 'all-linear'."} + ) + rank_pattern: Optional[dict[str, int]] = field( + default=None, + metadata={"help": "A dictionary of regex patterns to override effective_rank for specific modules."} ) def __post_init__(self): diff --git a/src/peft/tuners/osf/layer.py b/src/peft/tuners/osf/layer.py new file mode 100644 index 0000000000..0f0a4756eb --- /dev/null +++ b/src/peft/tuners/osf/layer.py @@ -0,0 +1,274 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import warnings +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from peft.tuners.tuners_utils import BaseTunerLayer + +from .utils import ( + decompose_weight_matrix, + project_gradient_to_orthogonal_space, + reconstruct_weight_matrix, +) + + +class OSFLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names: tuple[str, ...] = ("osf_svd_params",) + # All names of other parameters that may contain adapter-related parameters + other_param_names: tuple[str, ...] = ("effective_rank",) + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + self.base_layer = base_layer + self.effective_rank = {} + self.osf_svd_params = nn.ModuleDict({}) + # Store high-rank (frozen) components as buffers + self._osf_U_high = {} + self._osf_S_high = {} + self._osf_V_high = {} + # Track hook handles for cleanup + self.hook_handles = [] + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + # Get layer dimensions + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): + # QuantLinear + in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"): + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + in_features, out_features = None, None + warnings.warn( + f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning + ) + + self.in_features = in_features + self.out_features = out_features + + def update_layer(self, adapter_name: str, effective_rank: int, **kwargs): + """Update layer to add a new OSF adapter.""" + if effective_rank <= 0: + raise ValueError(f"`effective_rank` should be a positive integer value but the value passed is {effective_rank}") + + # Store the rank for this adapter + self.effective_rank[adapter_name] = effective_rank + + # Perform SVD decomposition on the base layer weight + base_layer = self.get_base_layer() + weight = base_layer.weight.data + svd_dict = decompose_weight_matrix(weight, top_k=effective_rank) + + # Store high-rank (frozen) components as buffers + self._osf_U_high[adapter_name] = svd_dict["U_high"] + self._osf_S_high[adapter_name] = svd_dict["S_high"] + self._osf_V_high[adapter_name] = svd_dict["V_high"] + + # Create module for trainable low-rank components + svd_module = nn.Module() + svd_module.U_low = svd_dict["U_low"] + svd_module.S_low = svd_dict["S_low"] + svd_module.V_low = svd_dict["V_low"] + svd_module.rank_high = svd_dict["rank_high"] + + self.osf_svd_params[adapter_name] = svd_module + + # Attach gradient hooks for orthogonal projection + self._attach_hooks(adapter_name) + + # Set the adapter as active + self.set_adapter(self.active_adapters) + + def _attach_hooks(self, adapter_name: str): + """Attach gradient hooks for the given adapter.""" + if adapter_name not in self.osf_svd_params: + return + + svd_module = self.osf_svd_params[adapter_name] + svd_dict = { + "U_high": self._osf_U_high[adapter_name], + "S_high": 
self._osf_S_high[adapter_name], + "V_high": self._osf_V_high[adapter_name], + "U_low": svd_module.U_low, + "S_low": svd_module.S_low, + "V_low": svd_module.V_low, + } + + def hook(grad): + project_gradient_to_orthogonal_space(svd_dict) + return grad + + # Store hook handles for later cleanup + handle_u = svd_module.U_low.register_hook(hook) + handle_s = svd_module.S_low.register_hook(hook) + handle_v = svd_module.V_low.register_hook(hook) + + self.hook_handles.extend([handle_u, handle_s, handle_v]) + + def _detach_hooks(self): + """Remove all gradient hooks.""" + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def _reconstruct_weight(self, adapter_name: str) -> torch.Tensor: + """Reconstruct weight matrix from SVD components for given adapter.""" + if adapter_name not in self.osf_svd_params: + return self.get_base_layer().weight + + svd_module = self.osf_svd_params[adapter_name] + svd_dict = { + "U_high": self._osf_U_high[adapter_name], + "S_high": self._osf_S_high[adapter_name], + "V_high": self._osf_V_high[adapter_name], + "U_low": svd_module.U_low, + "S_low": svd_module.S_low, + "V_low": svd_module.V_low, + } + return reconstruct_weight_matrix(svd_dict) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + if adapter_names is None: + adapter_names = self.active_adapters + + for active_adapter in adapter_names: + if active_adapter in self.osf_svd_params.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + new_weight = self._reconstruct_weight(active_adapter) + + if not torch.isfinite(new_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = new_weight.to(orig_weight.dtype) + else: + new_weight = self._reconstruct_weight(active_adapter) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + # For OSF, unmerging means restoring the original weight + # Since we modify the weight in-place, we need to store the original weight + # This is a limitation of the current OSF implementation + warnings.warn("OSF does not support unmerging. 
Original weights are permanently modified.") + + def __del__(self): + """Cleanup hooks on deletion.""" + self._detach_hooks() + + +class Linear(nn.Module, OSFLayer): + # OSF implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + effective_rank: int = None, + **kwargs, + ) -> None: + super().__init__() + OSFLayer.__init__(self, base_layer, **kwargs) + + # Set default effective_rank if not provided + if effective_rank is None: + # Default to 50% of min dimension + effective_rank = min(self.in_features, self.out_features) // 2 + + self._active_adapter = adapter_name + self.update_layer(adapter_name, effective_rank, **kwargs) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + # Use reconstructed weight for forward pass + base_layer = self.get_base_layer() + bias = base_layer.bias + + # Use the active adapter's reconstructed weight + active_adapter = self.active_adapters[0] if self.active_adapters else None + if active_adapter and active_adapter in self.osf_svd_params: + weight = self._reconstruct_weight(active_adapter) + if weight.dtype != x.dtype: + weight = weight.to(x.dtype) + result = F.linear(x, weight, bias) + else: + result = self.base_layer(x, *args, **kwargs) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "osf." + rep + + +def dispatch_default( + target: torch.nn.Module, + adapter_name: str, + osf_config, + **kwargs, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + new_module = Linear(target, adapter_name, **kwargs) + + return new_module \ No newline at end of file diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py index 9b88122c12..4ae4efbdeb 100644 --- a/src/peft/tuners/osf/model.py +++ b/src/peft/tuners/osf/model.py @@ -1,8 +1,12 @@ from __future__ import annotations +import re import torch.nn as nn from peft.tuners.tuners_utils import BaseTuner +from peft.utils.constants import TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING + +from .layer import OSFLayer, Linear, dispatch_default from .utils import ( attach_gradient_hooks, auto_generate_target_osf_config, @@ -21,13 +25,11 @@ def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) def _prepare_adapter_config(self, peft_config, model_config): return peft_config - def inject_adapter( - self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False - ) -> None: - svd_cfg = self.peft_config[adapter_name].target_svd_config - if svd_cfg is None: - svd_cfg = auto_generate_target_osf_config(model) - self.peft_config[adapter_name].target_svd_config = svd_cfg + def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False) -> None: + # For now, keep using the legacy approach + # TODO: Refactor to use _create_and_replace pattern + svd_cfg = auto_generate_target_osf_config(model) + OSFCls = create_osf_model_class(model.__class__) base_cfg = getattr(model, "config", None) osf_model = OSFCls(base_cfg, svd_config=svd_cfg, initialize_svd=False) @@ -36,8 +38,40 @@ def inject_adapter( 
attach_gradient_hooks(osf_model) self.model = osf_model - def _create_and_replace(self, *args, **kwargs): - pass + def _create_and_replace( + self, + osf_config, + adapter_name: str, + target: nn.Module, + target_name: str, + parent: nn.Module, + current_key: str, + ): + # OSF only works on 2D weight matrices + if not hasattr(target, 'weight') or len(target.weight.shape) != 2: + return None + + # Determine effective rank for this target + effective_rank = osf_config.effective_rank + if effective_rank is None: + # Default to 50% of min dimension + effective_rank = min(target.weight.shape) // 2 + + # Check for per-module rank overrides + if hasattr(osf_config, 'rank_pattern') and osf_config.rank_pattern: + for pattern, rank in osf_config.rank_pattern.items(): + if re.search(pattern, current_key): + effective_rank = rank + break + + kwargs = { + "effective_rank": effective_rank, + } + + # Create new OSF layer + new_module = dispatch_default(target, adapter_name, osf_config, **kwargs) + + return new_module def _check_target_module_exists(self, *args, **kwargs) -> bool: return True From 240267849261b6bd4122f5cdc46521e3b4bd76a6 Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Tue, 19 Aug 2025 15:12:03 +0000 Subject: [PATCH 18/19] documentation updates --- docs/source/package_reference/osf.md | 240 ++++++++++++++++++ docs/source/package_reference/osf_utils.md | 23 -- .../orthogonal_subspace_learning/README.md | 24 +- 3 files changed, 260 insertions(+), 27 deletions(-) create mode 100644 docs/source/package_reference/osf.md delete mode 100644 docs/source/package_reference/osf_utils.md diff --git a/docs/source/package_reference/osf.md b/docs/source/package_reference/osf.md new file mode 100644 index 0000000000..a00e403276 --- /dev/null +++ b/docs/source/package_reference/osf.md @@ -0,0 +1,240 @@ + + +# OSF (Orthogonal Subspace Fine-tuning) + +Orthogonal Subspace Fine-tuning ([OSF](https://arxiv.org/abs/2504.07097)) is a PEFT method designed for continual learning that constrains parameter updates to be orthogonal to previously important directions. This approach enables full fine-tuning while preventing catastrophic forgetting without requiring additional parameters or storing previous gradients. + +The abstract from the paper is: + +*Continual learning in large language models (LLMs) is prone to catastrophic forgetting, where adapting to new tasks significantly degrades performance on previously learned ones. Existing methods typically rely on low-rank, parameter-efficient updates that limit the model's expressivity and introduce additional parameters per task, leading to scalability issues. To address these limitations, we propose a novel continual full fine-tuning approach leveraging adaptive singular value decomposition (SVD). Our method dynamically identifies task-specific low-rank parameter subspaces and constrains updates to be orthogonal to critical directions associated with prior tasks, thus effectively minimizing interference without additional parameter overhead or storing previous task gradients. We evaluate our approach extensively on standard continual learning benchmarks using both encoder-decoder (T5-Large) and decoder-only (LLaMA-2 7B) models, spanning diverse tasks including classification, generation, and reasoning. 
Empirically, our method achieves state-of-the-art results, up to 7% higher average accuracy than recent baselines like O-LoRA, and notably maintains the model's general linguistic capabilities, instruction-following accuracy, and safety throughout the continual learning process by reducing forgetting to near-negligible levels. Our adaptive SVD framework effectively balances model plasticity and knowledge retention, providing a practical, theoretically grounded, and computationally scalable solution for continual learning scenarios in large language models.* + +## How OSF Works + +OSF decomposes each weight matrix into high-rank (frozen) and low-rank (trainable) components using SVD: + +``` +W = U_high * S_high * V_high^T + U_low * S_low * V_low^T +``` + +Where: +- `U_high, S_high, V_high`: Preserve important directions from previous tasks (frozen) +- `U_low, S_low, V_low`: Allow adaptation to new tasks (trainable) + +During training, gradients are projected to be orthogonal to the high-rank subspace, ensuring updates don't interfere with previously learned knowledge. + +## Basic Usage + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import OSFConfig, get_peft_model + +# Load base model +model = AutoModelForCausalLM.from_pretrained("gpt2") + +# Configure OSF +config = OSFConfig( + target_modules=["c_attn", "c_proj"], # Target attention layers + effective_rank=8, # Default rank for decomposition + rank_pattern={"c_attn": 16} # Override rank for specific modules +) + +# Apply OSF +model = get_peft_model(model, config) + +# Train as usual +optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +inputs = tokenizer("Hello world", return_tensors="pt", padding=True) +loss = model(**inputs, labels=inputs.input_ids).loss +loss.backward() +optimizer.step() +optimizer.zero_grad() +``` + +## Configuration Options + +### Target Modules + +You can specify target modules in several ways: + +```python +# Specific module names +config = OSFConfig(target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]) + +# All linear layers +config = OSFConfig(target_modules="all-linear") + +# Model-specific defaults (automatically detected) +config = OSFConfig() # Uses model-appropriate defaults +``` + +### Effective Rank Configuration + +Control the decomposition rank: + +```python +# Global rank (applies to all target modules) +config = OSFConfig(effective_rank=16) + +# Automatic rank (50% of min dimension) +config = OSFConfig(effective_rank=None) + +# Per-module rank overrides +config = OSFConfig( + effective_rank=8, + rank_pattern={ + "q_proj": 16, # Higher rank for query projection + "gate_proj": 4 # Lower rank for gate projection + } +) +``` + +## Training Advice for Continual Learning + +### Sequential Task Learning + +OSF is specifically designed for learning tasks sequentially: + +```python +# Task 1: Train on domain A +model = get_peft_model(base_model, OSFConfig(effective_rank=8)) +train_task(model, task_1_data) + +# Task 2: Continue training on domain B +# OSF automatically preserves Task 1 knowledge +train_task(model, task_2_data) + +# Task 3: Continue with domain C +train_task(model, task_3_data) +``` + +### Budget Allocation for Task Sequences + +When training on a known sequence of n tasks, one effective strategy is to progressively allocate model capacity to balance learning new tasks while preserving previous knowledge: + +- **Task 1**: Use full capacity 
(train everything) +- **Task 2**: Freeze 1/n of model capacity, train remaining (n-1)/n capacity +- **Task 3**: Freeze 2/n of model capacity, train remaining (n-2)/n capacity +- **Task n**: Freeze (n-1)/n of model capacity, use 1/n capacity for final task + +This approach ensures each task gets adequate learning capacity while progressively preserving more knowledge from previous tasks. + +```python +# Example: 4-task sequence with progressive budget allocation +n_tasks = 4 +base_rank = 32 # Starting rank for full capacity + +for task_id in range(n_tasks): + # Calculate remaining capacity for current task + freeze_fraction = task_id / n_tasks + remaining_capacity = 1.0 - freeze_fraction + current_rank = int(base_rank * remaining_capacity) + + config = OSFConfig( + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + effective_rank=current_rank + ) + + print(f"Task {task_id + 1}: Using rank {current_rank} " + f"({remaining_capacity:.1%} of full capacity)") + + # Train on current task + model = get_peft_model(base_model, config) + train_task(model, task_data[task_id]) +``` + +### Best Practices + +1. **Effective Rank Selection**: Start with `effective_rank=None` (automatic 50% rank) and adjust based on task complexity +2. **Learning Rate**: Use smaller learning rates (1e-5 to 1e-4) compared to standard fine-tuning +3. **Task Importance**: Use `rank_pattern` to allocate more capacity to critical modules +4. **Model Architecture**: OSF works best with transformer architectures having clear attention and MLP separations +5. **Capacity Planning**: For known task sequences, use progressive budget allocation (1/n, 2/n, ..., (n-1)/n freezing) to balance plasticity and stability + +### Memory Considerations + +OSF modifies weights in-place and doesn't add parameters, making it memory-efficient: + +```python +# Memory usage remains close to base model +print(f"Base model parameters: {base_model.num_parameters():,}") +print(f"OSF model parameters: {osf_model.num_parameters():,}") # Similar count +``` + +## Advanced Usage + +### Custom Target Modules + +For models with non-standard architectures: + +```python +config = OSFConfig( + target_modules=["dense", "intermediate.dense"], # Custom layer names + effective_rank=12, + rank_pattern={"dense": 8, "intermediate.dense": 16} +) +``` + +### Integration with Other Methods + +OSF can be combined with other techniques: + +```python +# Use with gradient checkpointing for memory efficiency +model.gradient_checkpointing_enable() + +# Apply weight decay selectively +optimizer = torch.optim.AdamW([ + {"params": [p for n, p in model.named_parameters() if "U_low" in n], "weight_decay": 0.01}, + {"params": [p for n, p in model.named_parameters() if "S_low" in n], "weight_decay": 0.001}, + {"params": [p for n, p in model.named_parameters() if "V_low" in n], "weight_decay": 0.01}, +], lr=1e-4) +``` + +## OSFConfig + +[[autodoc]] tuners.osf.config.OSFConfig + +## OSFModel + +[[autodoc]] tuners.osf.model.OSFModel + +## Utility Functions + +### Weight Decomposition + +[[autodoc]] tuners.osf.utils.decompose_weight_matrix + +[[autodoc]] tuners.osf.utils.reconstruct_weight_matrix + +### Gradient Projection + +[[autodoc]] tuners.osf.utils.project_gradient_to_orthogonal_space + +### Hook Management + +[[autodoc]] tuners.osf.utils.attach_gradient_hooks + +[[autodoc]] tuners.osf.utils.detach_gradient_hooks + +### Configuration Helpers + +[[autodoc]] tuners.osf.utils.auto_generate_target_osf_config \ No newline at end of file diff --git 
a/docs/source/package_reference/osf_utils.md b/docs/source/package_reference/osf_utils.md deleted file mode 100644 index d54bb7e66c..0000000000 --- a/docs/source/package_reference/osf_utils.md +++ /dev/null @@ -1,23 +0,0 @@ - - -# OSF utilities - -Helper functions for orthogonal subspace learning with Adaptive OSF. - -[[autodoc]] utils.osf_utils.decompose_weight_matrix - - all - -[[autodoc]] utils.osf_utils.reconstruct_weight_matrix - - all - -[[autodoc]] utils.osf_utils.project_gradient_to_orthogonal_space - - all - -[[autodoc]] utils.osf_utils.auto_generate_target_osf_config - - all - -[[autodoc]] utils.osf_utils.create_osf_model_class - - all - -[[autodoc]] utils.osf_utils.attach_gradient_hooks - - all \ No newline at end of file diff --git a/examples/orthogonal_subspace_learning/README.md b/examples/orthogonal_subspace_learning/README.md index f411b084a5..0e262ccf8b 100644 --- a/examples/orthogonal_subspace_learning/README.md +++ b/examples/orthogonal_subspace_learning/README.md @@ -1,6 +1,20 @@ # Orthogonal Subspace Learning with Adaptive OSF -This example shows how to wrap a pretrained model with SVD-decomposed weights to enable orthogonal subspace training. +## TODO: Runnable Example Needed + +This folder is a placeholder for a comprehensive OSF example. As suggested in the review feedback: + +> "If you can, provide a runnable example in this folder instead, you can take a look at the EVA example for inspiration. A runnable example can be a good place to showcase the different features. Jupyter notebooks are fine as well." + +### Planned Example Features: +- Complete continual learning scenario with multiple tasks +- Demonstration of OSF's catastrophic forgetting prevention +- Configuration examples (target_modules, effective_rank, rank_pattern) +- Performance comparison with baseline methods +- Memory usage analysis + +### Current Basic Usage: +For basic usage examples and API documentation, see the [OSF documentation](../../docs/source/package_reference/osf.md). 
```python import torch @@ -8,13 +22,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from peft import OSFConfig, get_peft_model model = AutoModelForCausalLM.from_pretrained("gpt2") -model = get_peft_model(model, OSFConfig()) # add trainable low-rank parameters +config = OSFConfig(target_modules=["c_attn", "c_proj"], effective_rank=8) +model = get_peft_model(model, config) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) tokenizer = AutoTokenizer.from_pretrained("gpt2") -input_ids = tokenizer("Hello world", return_tensors="pt").input_ids -loss = model(input_ids, labels=input_ids).loss +tokenizer.pad_token = tokenizer.eos_token +inputs = tokenizer("Hello world", return_tensors="pt", padding=True) +loss = model(**inputs, labels=inputs.input_ids).loss loss.backward() optimizer.step() optimizer.zero_grad() From 845479e2eabeb26da93a0e6465f2e9e0eab09abc Mon Sep 17 00:00:00 2001 From: Nikhil Nayak Date: Wed, 10 Sep 2025 20:06:42 +0000 Subject: [PATCH 19/19] OSF refactor + docs/tests cleanup --- docs/source/package_reference/osf.md | 36 ++--- src/peft/tuners/osf/layer.py | 76 +++++---- src/peft/tuners/osf/model.py | 90 +++++++---- src/peft/tuners/osf/utils.py | 223 +-------------------------- tests/test_osf.py | 65 ++++++++ tests/test_osf_utils.py | 118 -------------- tests/testing_common.py | 16 +- 7 files changed, 202 insertions(+), 422 deletions(-) create mode 100644 tests/test_osf.py delete mode 100644 tests/test_osf_utils.py diff --git a/docs/source/package_reference/osf.md b/docs/source/package_reference/osf.md index a00e403276..04db5d0c28 100644 --- a/docs/source/package_reference/osf.md +++ b/docs/source/package_reference/osf.md @@ -16,7 +16,7 @@ rendered properly in your Markdown viewer. # OSF (Orthogonal Subspace Fine-tuning) -Orthogonal Subspace Fine-tuning ([OSF](https://arxiv.org/abs/2504.07097)) is a PEFT method designed for continual learning that constrains parameter updates to be orthogonal to previously important directions. This approach enables full fine-tuning while preventing catastrophic forgetting without requiring additional parameters or storing previous gradients. +Orthogonal Subspace Fine-tuning ([OSF](https://huggingface.co/papers/2504.07097)) is a PEFT method designed for continual learning that constrains parameter updates to be orthogonal to previously important directions. This approach enables full fine-tuning while preventing catastrophic forgetting without requiring additional parameters or storing previous gradients. The abstract from the paper is: @@ -94,7 +94,7 @@ Control the decomposition rank: # Global rank (applies to all target modules) config = OSFConfig(effective_rank=16) -# Automatic rank (50% of min dimension) +# Automatic rank (50% of the smaller matrix dimension per target) config = OSFConfig(effective_rank=None) # Per-module rank overrides @@ -111,18 +111,24 @@ config = OSFConfig( ### Sequential Task Learning -OSF is specifically designed for learning tasks sequentially: +OSF is specifically designed for learning tasks sequentially. Between tasks, recompute the SVD so the preserved subspace reflects the latest weights. 
One simple way is to re-wrap the updated base model with OSF again: ```python -# Task 1: Train on domain A -model = get_peft_model(base_model, OSFConfig(effective_rank=8)) +# Task 1: train on domain A with initial preserved subspace +r = 8 # initial effective rank to preserve +model = get_peft_model(base_model, OSFConfig(effective_rank=r)) train_task(model, task_1_data) -# Task 2: Continue training on domain B -# OSF automatically preserves Task 1 knowledge +# Task 2: recompute SVD on updated weights and increase preserved subspace +base_model = model.base_model.model # unwrap updated base +r += 4 # grow preserved subspace to include Task 1 knowledge +model = get_peft_model(base_model, OSFConfig(effective_rank=r)) train_task(model, task_2_data) -# Task 3: Continue with domain C +# Task 3: recompute again and expand preserved subspace further +base_model = model.base_model.model +r += 4 +model = get_peft_model(base_model, OSFConfig(effective_rank=r)) train_task(model, task_3_data) ``` @@ -163,7 +169,7 @@ for task_id in range(n_tasks): ### Best Practices -1. **Effective Rank Selection**: Start with `effective_rank=None` (automatic 50% rank) and adjust based on task complexity +1. **Effective Rank Selection**: Start with `effective_rank=None` (auto sets rank to 50% of the smaller weight dimension per target module) and adjust based on task complexity 2. **Learning Rate**: Use smaller learning rates (1e-5 to 1e-4) compared to standard fine-tuning 3. **Task Importance**: Use `rank_pattern` to allocate more capacity to critical modules 4. **Model Architecture**: OSF works best with transformer architectures having clear attention and MLP separations @@ -201,7 +207,7 @@ OSF can be combined with other techniques: # Use with gradient checkpointing for memory efficiency model.gradient_checkpointing_enable() -# Apply weight decay selectively +# Apply weight decay selectively (regularizes low-rank factors to limit drift/overfitting in continual updates; keep small) optimizer = torch.optim.AdamW([ {"params": [p for n, p in model.named_parameters() if "U_low" in n], "weight_decay": 0.01}, {"params": [p for n, p in model.named_parameters() if "S_low" in n], "weight_decay": 0.001}, @@ -228,13 +234,3 @@ optimizer = torch.optim.AdamW([ ### Gradient Projection [[autodoc]] tuners.osf.utils.project_gradient_to_orthogonal_space - -### Hook Management - -[[autodoc]] tuners.osf.utils.attach_gradient_hooks - -[[autodoc]] tuners.osf.utils.detach_gradient_hooks - -### Configuration Helpers - -[[autodoc]] tuners.osf.utils.auto_generate_target_osf_config \ No newline at end of file diff --git a/src/peft/tuners/osf/layer.py b/src/peft/tuners/osf/layer.py index 0f0a4756eb..bb58026e1e 100644 --- a/src/peft/tuners/osf/layer.py +++ b/src/peft/tuners/osf/layer.py @@ -15,16 +15,17 @@ import warnings from typing import Any, Optional +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from peft.tuners.tuners_utils import BaseTunerLayer +from peft.tuners._buffer_dict import BufferDict from .utils import ( decompose_weight_matrix, - project_gradient_to_orthogonal_space, reconstruct_weight_matrix, ) @@ -33,16 +34,17 @@ class OSFLayer(BaseTunerLayer): # All names of layers that may contain (trainable) adapter weights adapter_layer_names: tuple[str, ...] = ("osf_svd_params",) # All names of other parameters that may contain adapter-related parameters - other_param_names: tuple[str, ...] = ("effective_rank",) + other_param_names: tuple[str, ...] 
= ("_osf_U_high", "_osf_S_high", "_osf_V_high") def __init__(self, base_layer: nn.Module, **kwargs) -> None: self.base_layer = base_layer self.effective_rank = {} + # Map adapter_name -> ParameterDict{"U_low", "S_low", "V_low"} self.osf_svd_params = nn.ModuleDict({}) - # Store high-rank (frozen) components as buffers - self._osf_U_high = {} - self._osf_S_high = {} - self._osf_V_high = {} + # Store high-rank (frozen) components as buffers that track device moves + self._osf_U_high = BufferDict({}) + self._osf_S_high = BufferDict({}) + self._osf_V_high = BufferDict({}) # Track hook handles for cleanup self.hook_handles = [] # Mark the weight as unmerged @@ -51,20 +53,24 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Get layer dimensions base_layer = self.get_base_layer() - if isinstance(base_layer, nn.Linear): + # Prefer the universally available weight shape when possible. + if hasattr(base_layer, "weight") and isinstance(base_layer.weight, torch.Tensor) and base_layer.weight.ndim == 2: + # For Linear-like modules, weight is [out_features, in_features] + out_features, in_features = base_layer.weight.shape + elif isinstance(base_layer, nn.Linear): in_features, out_features = base_layer.in_features, base_layer.out_features elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): # QuantLinear in_features, out_features = base_layer.infeatures, base_layer.outfeatures elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): - # Megatron ColumnParallelLinear,RowParallelLinear + # Megatron ColumnParallelLinear, RowParallelLinear in_features, out_features = base_layer.input_size, base_layer.output_size elif hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"): in_features, out_features = base_layer.in_features, base_layer.out_features else: in_features, out_features = None, None warnings.warn( - f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning + f"Unsupported layer type '{type(base_layer)}' encountered; could not infer in/out features.", UserWarning ) self.in_features = in_features @@ -88,14 +94,15 @@ def update_layer(self, adapter_name: str, effective_rank: int, **kwargs): self._osf_S_high[adapter_name] = svd_dict["S_high"] self._osf_V_high[adapter_name] = svd_dict["V_high"] - # Create module for trainable low-rank components - svd_module = nn.Module() - svd_module.U_low = svd_dict["U_low"] - svd_module.S_low = svd_dict["S_low"] - svd_module.V_low = svd_dict["V_low"] - svd_module.rank_high = svd_dict["rank_high"] - - self.osf_svd_params[adapter_name] = svd_module + # Create ParameterDict for trainable low-rank components + svd_params = nn.ParameterDict( + { + "U_low": svd_dict["U_low"], + "S_low": svd_dict["S_low"], + "V_low": svd_dict["V_low"], + } + ) + self.osf_svd_params[adapter_name] = svd_params # Attach gradient hooks for orthogonal projection self._attach_hooks(adapter_name) @@ -113,21 +120,28 @@ def _attach_hooks(self, adapter_name: str): "U_high": self._osf_U_high[adapter_name], "S_high": self._osf_S_high[adapter_name], "V_high": self._osf_V_high[adapter_name], - "U_low": svd_module.U_low, - "S_low": svd_module.S_low, - "V_low": svd_module.V_low, + "U_low": svd_module["U_low"], + "S_low": svd_module["S_low"], + "V_low": svd_module["V_low"], } - def hook(grad): - project_gradient_to_orthogonal_space(svd_dict) + def hook(grad, name: str): + # Project gradient to be orthogonal to high-rank subspace for U_low/V_low + if name == "U_low": + U_high = 
svd_dict["U_high"] + proj = U_high @ (U_high.transpose(0, 1) @ grad) + return grad - proj + elif name == "V_low": + V_high = svd_dict["V_high"] + proj = (grad @ V_high.transpose(0, 1)) @ V_high + return grad - proj return grad # Store hook handles for later cleanup - handle_u = svd_module.U_low.register_hook(hook) - handle_s = svd_module.S_low.register_hook(hook) - handle_v = svd_module.V_low.register_hook(hook) - - self.hook_handles.extend([handle_u, handle_s, handle_v]) + handle_u = svd_module["U_low"].register_hook(partial(hook, name="U_low")) + handle_v = svd_module["V_low"].register_hook(partial(hook, name="V_low")) + + self.hook_handles.extend([handle_u, handle_v]) def _detach_hooks(self): """Remove all gradient hooks.""" @@ -145,9 +159,9 @@ def _reconstruct_weight(self, adapter_name: str) -> torch.Tensor: "U_high": self._osf_U_high[adapter_name], "S_high": self._osf_S_high[adapter_name], "V_high": self._osf_V_high[adapter_name], - "U_low": svd_module.U_low, - "S_low": svd_module.S_low, - "V_low": svd_module.V_low, + "U_low": svd_module["U_low"], + "S_low": svd_module["S_low"], + "V_low": svd_module["V_low"], } return reconstruct_weight_matrix(svd_dict) @@ -271,4 +285,4 @@ def dispatch_default( if isinstance(target_base_layer, torch.nn.Linear): new_module = Linear(target, adapter_name, **kwargs) - return new_module \ No newline at end of file + return new_module diff --git a/src/peft/tuners/osf/model.py b/src/peft/tuners/osf/model.py index 4ae4efbdeb..44449f2a65 100644 --- a/src/peft/tuners/osf/model.py +++ b/src/peft/tuners/osf/model.py @@ -3,15 +3,10 @@ import re import torch.nn as nn -from peft.tuners.tuners_utils import BaseTuner +from peft.tuners.tuners_utils import BaseTuner, check_target_module_exists from peft.utils.constants import TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING from .layer import OSFLayer, Linear, dispatch_default -from .utils import ( - attach_gradient_hooks, - auto_generate_target_osf_config, - create_osf_model_class, -) class OSFModel(BaseTuner): @@ -23,20 +18,30 @@ def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def _prepare_adapter_config(self, peft_config, model_config): + # Infer default target modules from mapping if not provided + if getattr(peft_config, "target_modules", None) is None: + model_type = model_config.get("model_type") + if model_type not in TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING[model_type] + ) return peft_config - def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False) -> None: - # For now, keep using the legacy approach - # TODO: Refactor to use _create_and_replace pattern - svd_cfg = auto_generate_target_osf_config(model) - - OSFCls = create_osf_model_class(model.__class__) - base_cfg = getattr(model, "config", None) - osf_model = OSFCls(base_cfg, svd_config=svd_cfg, initialize_svd=False) - osf_model.load_state_dict(model.state_dict()) - osf_model.reinitialize_svd() - attach_gradient_hooks(osf_model) - self.model = osf_model + def inject_adapter( + self, + model: nn.Module, + adapter_name: str, + autocast_adapter_dtype: bool = True, + low_cpu_mem_usage: bool = False, + ) -> None: + # Delegate to BaseTuner to perform standard target discovery and replacement + return 
super().inject_adapter( + model, + adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) def _create_and_replace( self, @@ -68,17 +73,29 @@ def _create_and_replace( "effective_rank": effective_rank, } - # Create new OSF layer - new_module = dispatch_default(target, adapter_name, osf_config, **kwargs) - - return new_module - - def _check_target_module_exists(self, *args, **kwargs) -> bool: - return True + # Create a new or update an existing OSF layer in place + if isinstance(target, OSFLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = dispatch_default(target, adapter_name, osf_config, **kwargs) + if new_module is None: + return None + # If adding an additional adapter, keep it frozen initially + if adapter_name not in self.active_adapters: + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _check_target_module_exists(osf_config, key): + return check_target_module_exists(osf_config, key) def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: for n, p in model.named_parameters(): - if "svd_params" not in n and not n.endswith(("_U_low", "_S_low", "_V_low")): + if ( + self.prefix not in n + and "svd_params" not in n + and not n.endswith(("_U_low", "_S_low", "_V_low")) + ): p.requires_grad = False def _set_adapter_layers(self, enabled: bool = True) -> None: @@ -103,4 +120,23 @@ def unmerge_adapter(self, *args, **kwargs): raise NotImplementedError("OSF models do not support merging") def merge_and_unload(self, *args, **kwargs): - raise NotImplementedError("OSF models do not support merging") \ No newline at end of file + raise NotImplementedError("OSF models do not support merging") + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # child layer may wrap the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + # If new module is a simple wrapper, ensure weight/bias/state stay aligned + if not hasattr(new_module, "base_layer") and hasattr(child, "weight"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) diff --git a/src/peft/tuners/osf/utils.py b/src/peft/tuners/osf/utils.py index 91f66ef384..0d81d9270d 100644 --- a/src/peft/tuners/osf/utils.py +++ b/src/peft/tuners/osf/utils.py @@ -22,15 +22,11 @@ from torch import nn from torch.nn import functional as F -from peft.utils.constants import TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING +# Note: OSF now relies on OSFLayer + BaseTuner; no model-level helpers required here. 
__all__ = [ - "attach_gradient_hooks", - "auto_generate_target_osf_config", - "create_osf_model_class", "decompose_weight_matrix", - "detach_gradient_hooks", "project_gradient_to_orthogonal_space", "reconstruct_weight_matrix", ] @@ -118,220 +114,3 @@ def project_gradient_to_orthogonal_space(svd_dict: dict[str, Any]) -> None: proj = (local_dV @ local_V_high.transpose(0, 1)) @ local_V_high local_dV.sub_(proj) dV.copy_(local_dV) - - -def attach_gradient_hooks(model: nn.Module) -> None: - """Attach gradient hooks to project OSF gradients automatically.""" - if not hasattr(model, "svd_params"): - return - - # Initialize hook_handles if not exists - if not hasattr(model, "hook_handles"): - model.hook_handles = [] - - for safe_name, module_svd in model.svd_params.items(): - svd_dict = { - "U_high": getattr(model, f"{safe_name}_U_high"), - "S_high": getattr(model, f"{safe_name}_S_high"), - "V_high": getattr(model, f"{safe_name}_V_high"), - "U_low": module_svd.U_low, - "S_low": module_svd.S_low, - "V_low": module_svd.V_low, - } - - def hook(grad): - project_gradient_to_orthogonal_space(svd_dict) - return grad - - # Store hook handles for later cleanup - handle_u = module_svd.U_low.register_hook(hook) - handle_s = module_svd.S_low.register_hook(hook) - handle_v = module_svd.V_low.register_hook(hook) - - model.hook_handles.extend([handle_u, handle_s, handle_v]) - - -def detach_gradient_hooks(model: nn.Module) -> None: - """Remove all gradient hooks that were attached by attach_gradient_hooks.""" - if hasattr(model, "hook_handles"): - for handle in model.hook_handles: - handle.remove() - model.hook_handles.clear() - - -def auto_generate_target_osf_config(model: nn.Module) -> dict[str, int]: - """Create a mapping from parameter names to ``top_k`` based on layer size.""" - # Get model type and corresponding target modules - model_type = getattr(model.config, "model_type", None) - target_modules = TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING.get(model_type, []) - - # Fallback to hardcoded patterns if model type not found - if not target_modules: - target_patterns = [ - "self_attn.q_proj", - "self_attn.k_proj", - "self_attn.v_proj", - "self_attn.o_proj", - "mlp.gate_proj", - "mlp.down_proj", - "mlp.up_proj", - ] - else: - # Convert module names to patterns that match the full parameter names - target_patterns = [f"self_attn.{mod}" if mod in ["q_proj", "k_proj", "v_proj", "o_proj"] - else f"mlp.{mod}" if mod in ["gate_proj", "down_proj", "up_proj"] - else mod for mod in target_modules] - - config: dict[str, int] = {} - for name, param in model.named_parameters(): - if any(pat in name for pat in target_patterns) and len(param.shape) == 2: - top_k = min(param.shape) // 2 - config[name] = top_k - return config - - -def create_osf_model_class(base_cls: type) -> type: - """Create a subclass of ``base_cls`` where selected linear weights are replaced by SVD decompositions.""" - - class ModelWithOSF(base_cls): - def __init__( - self, config=None, svd_config: dict[str, int] | None = None, initialize_svd: bool = True, **kwargs - ): - if config is not None: - try: - super().__init__(config, **kwargs) - except TypeError: - super().__init__(**kwargs) - self.config = config - else: - super().__init__(**kwargs) - self.svd_config = svd_config or {} - self.name_mapping: dict[str, str] = {} - self.svd_params = nn.ModuleDict() - if initialize_svd: - self._initialize_svd_parameters() - attach_gradient_hooks(self) - - @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path, *model_args, svd_config: 
dict[str, int] | None = None, **kwargs - ): - model = super().from_pretrained( - pretrained_model_name_or_path, - *model_args, - svd_config=svd_config or {}, - **kwargs, - ) - if svd_config is None: - svd_config = auto_generate_target_osf_config(model) - model.svd_config = svd_config - model.reinitialize_svd() - return model - - def reinitialize_svd(self) -> None: - # Clean up existing hooks before reinitializing - detach_gradient_hooks(self) - self.name_mapping = {} - self.svd_params = nn.ModuleDict() - self._initialize_svd_parameters() - attach_gradient_hooks(self) - - def _get_module_by_name(self, name: str): - parts = name.split(".") - attr = parts[-1] - mod = self - for p in parts[:-1]: - if hasattr(mod, p): - mod = getattr(mod, p) - elif p.isdigit(): - mod = mod[int(p)] - else: - return None, None - return mod, attr - - def _initialize_svd_parameters(self) -> None: - for name, param in list(self.named_parameters()): - if len(param.shape) == 2 and name in self.svd_config and self.svd_config[name] > 0: - top_k = self.svd_config[name] - svd_dict = decompose_weight_matrix(param.data, top_k=top_k) - safe_name = name.replace(".", "_") - self.name_mapping[name] = safe_name - self.register_buffer(f"{safe_name}_U_high", svd_dict["U_high"]) - self.register_buffer(f"{safe_name}_S_high", svd_dict["S_high"]) - self.register_buffer(f"{safe_name}_V_high", svd_dict["V_high"]) - - module_svd = nn.Module() - module_svd.U_low = svd_dict["U_low"] - module_svd.S_low = svd_dict["S_low"] - module_svd.V_low = svd_dict["V_low"] - module_svd.rank_high = svd_dict["rank_high"] - module_svd.safe_name = safe_name - self.svd_params[safe_name] = module_svd - - mod, attr = self._get_module_by_name(name) - bias = mod.bias if hasattr(mod, "bias") else None - - def make_forward(sn: str, bias: torch.Tensor | None): - def forward(x): - W = self._reconstruct_weight_by_safe_name(sn) - if W.dtype != x.dtype: - W = W.to(x.dtype) - return F.linear(x, W, bias) - - return forward - - mod.forward = make_forward(safe_name, bias) - param.requires_grad = False - - def _reconstruct_weight_by_safe_name(self, safe_name: str) -> torch.Tensor: - U_high = getattr(self, f"{safe_name}_U_high") - S_high = getattr(self, f"{safe_name}_S_high") - V_high = getattr(self, f"{safe_name}_V_high") - module_svd = self.svd_params[safe_name] - svd_dict = { - "U_high": U_high, - "S_high": S_high, - "V_high": V_high, - "U_low": module_svd.U_low, - "S_low": module_svd.S_low, - "V_low": module_svd.V_low, - } - return reconstruct_weight_matrix(svd_dict) - - def project_gradients(self) -> None: - for safe_name, module_svd in self.svd_params.items(): - svd_dict = { - "U_high": getattr(self, f"{safe_name}_U_high"), - "S_high": getattr(self, f"{safe_name}_S_high"), - "V_high": getattr(self, f"{safe_name}_V_high"), - "U_low": module_svd.U_low, - "S_low": module_svd.S_low, - "V_low": module_svd.V_low, - } - project_gradient_to_orthogonal_space(svd_dict) - - def prepare_state_dict_for_save(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - if not hasattr(self, "name_mapping"): - return state_dict - for orig, safe in self.name_mapping.items(): - U_high = state_dict.pop(f"{safe}_U_high") - S_high = state_dict.pop(f"{safe}_S_high") - V_high = state_dict.pop(f"{safe}_V_high") - U_low = state_dict.pop(f"svd_params.{safe}.U_low") - S_low = state_dict.pop(f"svd_params.{safe}.S_low") - V_low = state_dict.pop(f"svd_params.{safe}.V_low") - W = reconstruct_weight_matrix( - { - "U_high": U_high, - "S_high": S_high, - "V_high": V_high, - "U_low": U_low, - 
"S_low": S_low, - "V_low": V_low, - } - ) - state_dict[orig] = W - return state_dict - - ModelWithOSF.__name__ = f"{base_cls.__name__}WithOSF" - return ModelWithOSF \ No newline at end of file diff --git a/tests/test_osf.py b/tests/test_osf.py new file mode 100644 index 0000000000..7688804cb8 --- /dev/null +++ b/tests/test_osf.py @@ -0,0 +1,65 @@ +from tempfile import TemporaryDirectory + +import pytest +import torch +from torch.testing import assert_close + +from peft import OSFConfig, PeftModel, get_peft_model +from peft.tuners.osf.utils import ( + decompose_weight_matrix, + reconstruct_weight_matrix, +) + + +def test_osf_roundtrip(): + w = torch.randn(10, 8) + svd = decompose_weight_matrix(w, top_k=4) + w_rec = reconstruct_weight_matrix(svd) + assert_close(w_rec, w, atol=1e-5, rtol=1e-5) + + +class DummyConfig(dict): + pass + + +class DummyModel(torch.nn.Module): + def __init__(self, config=None): + super().__init__() + self.config = config + self.linear = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.linear(x) + + +def test_osf_gradient_projection_hook(): + torch.manual_seed(0) + model = DummyModel(DummyConfig()) + # Specify target module explicitly for DummyModel + cfg = OSFConfig(target_modules=["linear"], effective_rank=2) + wrapped = get_peft_model(model, cfg) + x = torch.randn(3, 8) + wrapped(x).sum().backward() + # Access the injected OSF layer + osf_linear = wrapped.base_model.model.linear + adapter = wrapped.base_model.active_adapters[0] + U_high = osf_linear._osf_U_high[adapter] + V_high = osf_linear._osf_V_high[adapter] + svd_params = osf_linear.osf_svd_params[adapter] + # Check orthogonality of gradients after projection + proj_u = U_high.T @ svd_params["U_low"].grad + proj_v = svd_params["V_low"].grad @ V_high.T + assert_close(proj_u, torch.zeros_like(proj_u), atol=1e-6, rtol=1e-6) + assert_close(proj_v, torch.zeros_like(proj_v), atol=1e-6, rtol=1e-6) + + +def test_osf_merge_unmerge_unsupported(): + model = DummyModel(DummyConfig()) + cfg = OSFConfig(target_modules=["linear"], effective_rank=2) + wrapped = get_peft_model(model, cfg) + with pytest.raises(NotImplementedError): + wrapped.merge_adapter() + with pytest.raises(NotImplementedError): + wrapped.unmerge_adapter() + with pytest.raises(NotImplementedError): + wrapped.merge_and_unload() diff --git a/tests/test_osf_utils.py b/tests/test_osf_utils.py deleted file mode 100644 index 2efb44d12b..0000000000 --- a/tests/test_osf_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -from tempfile import TemporaryDirectory - -import pytest -import torch -from torch.testing import assert_close - -from peft import OSFConfig, PeftModel, get_peft_model -from peft.tuners.osf.utils import ( - decompose_weight_matrix, - reconstruct_weight_matrix, -) - - -def test_osf_roundtrip(): - w = torch.randn(10, 8) - svd = decompose_weight_matrix(w, top_k=4) - w_rec = reconstruct_weight_matrix(svd) - assert_close(w_rec, w, atol=1e-5, rtol=1e-5) - - -class DummyConfig(dict): - pass - - -class DummyModel(torch.nn.Module): - def __init__(self, config=None): - super().__init__() - self.config = config - self.linear = torch.nn.Linear(8, 4) - - def forward(self, x): - return self.linear(x) - - -def test_osf_get_peft_model_preserves_output(): - torch.manual_seed(0) - model = DummyModel(DummyConfig()) - x = torch.randn(2, 8) - y_ref = model(x) - wrapped = get_peft_model(model, OSFConfig()) - y = wrapped(x) - assert_close(y, y_ref, atol=1e-5, rtol=1e-5) - - -def test_osf_gradient_projection_hook(): - torch.manual_seed(0) - model = 
DummyModel(DummyConfig()) - cfg = OSFConfig(target_svd_config={"linear.weight": 2}) - wrapped = get_peft_model(model, cfg) - x = torch.randn(3, 8) - wrapped(x).sum().backward() - inner = wrapped.base_model.model - safe_name = next(iter(inner.svd_params)) - module_svd = inner.svd_params[safe_name] - U_high = getattr(inner, f"{safe_name}_U_high") - V_high = getattr(inner, f"{safe_name}_V_high") - assert_close( - U_high.T @ module_svd.U_low.grad, torch.zeros_like(U_high.T @ module_svd.U_low.grad), atol=1e-6, rtol=1e-6 - ) - assert_close( - module_svd.V_low.grad @ V_high.T, - torch.zeros_like(module_svd.V_low.grad @ V_high.T), - atol=1e-6, - rtol=1e-6, - ) - - -def test_osf_config_roundtrip(): - cfg = OSFConfig(target_svd_config={"linear.weight": 2}) - with TemporaryDirectory() as tmp: - cfg.save_pretrained(tmp) - loaded = OSFConfig.from_pretrained(tmp) - assert cfg.target_svd_config == loaded.target_svd_config - - -def test_osf_save_and_load_model(): - torch.manual_seed(0) - model = DummyModel(DummyConfig()) - cfg = OSFConfig(target_svd_config={"linear.weight": 2}) - wrapped = get_peft_model(model, cfg) - x = torch.randn(2, 8) - y_ref = wrapped(x) - with TemporaryDirectory() as tmp: - wrapped.save_pretrained(tmp) - torch.manual_seed(0) - base = DummyModel(DummyConfig()) - loaded = PeftModel.from_pretrained(base, tmp) - y = loaded(x) - assert_close(y, y_ref, atol=1e-5, rtol=1e-5) - - -def test_osf_save_and_load_autogen_config(): - torch.manual_seed(0) - model = DummyModel(DummyConfig()) - wrapped = get_peft_model(model, OSFConfig()) - x = torch.randn(2, 8) - y_ref = wrapped(x) - original_cfg = wrapped.peft_config["default"].target_svd_config - with TemporaryDirectory() as tmp: - wrapped.save_pretrained(tmp) - torch.manual_seed(0) - base = DummyModel(DummyConfig()) - loaded = PeftModel.from_pretrained(base, tmp) - y = loaded(x) - assert_close(y, y_ref, atol=1e-5, rtol=1e-5) - assert original_cfg == loaded.peft_config["default"].target_svd_config - - -def test_osf_merge_unmerge_unsupported(): - model = DummyModel(DummyConfig()) - cfg = OSFConfig(target_svd_config={"linear.weight": 2}) - wrapped = get_peft_model(model, cfg) - with pytest.raises(NotImplementedError): - wrapped.merge_adapter() - with pytest.raises(NotImplementedError): - wrapped.unmerge_adapter() - with pytest.raises(NotImplementedError): - wrapped.merge_and_unload() \ No newline at end of file diff --git a/tests/testing_common.py b/tests/testing_common.py index 3f441ecd2c..97428466c9 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -236,11 +236,19 @@ def test_something(model_id, config_kwargs): def _skip_if_merging_not_supported(model_id, config_cls): - """Skip tests for Conv2dGroups models or OSF configs where merging is not supported.""" - if model_id in ["Conv2dGroups", "Conv2dGroups2"] or issubclass(config_cls, OSFConfig): + """Skip tests for cases where adapter merge is unavailable. + + - Conv2dGroups: merge is not supported (by design) — see PR #2403. + - OSF: merge/unload are not implemented yet in the tuner. + """ + if model_id in ["Conv2dGroups", "Conv2dGroups2"]: + pytest.skip( + f"Skipping test for {model_id} as adapter merging is not supported for Conv2dGroups. " + "(See https://github.com/huggingface/peft/pull/2403)" + ) + if issubclass(config_cls, OSFConfig): pytest.skip( - f"Skipping test for {model_id} with {config_cls} as merging is not supported. 
" - "(See https://github.com/huggingface/peft/pull/2403 for details)" + f"Skipping test for {model_id} with {config_cls} as OSF adapter merge/unload are not implemented." )