diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 85439fd042..5ce6145ae1 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -114,6 +114,8 @@
    title: VeRA
  - local: package_reference/fourierft
    title: FourierFT
+  - local: package_reference/glora
+    title: GLoRA
  - local: package_reference/vblora
    title: VB-LoRA
  - local: package_reference/hra
diff --git a/docs/source/package_reference/glora.md b/docs/source/package_reference/glora.md
new file mode 100644
index 0000000000..7c6c69eb98
--- /dev/null
+++ b/docs/source/package_reference/glora.md
@@ -0,0 +1,81 @@

# GLoRA

Generalized Low-Rank Adaptation (**GLoRA**) is a flexible PEFT method that generalizes LoRA and related approaches. GLoRA decomposes the weight update into several configurable paths, each of which can be a low-rank matrix, a vector, or a constant, making it a superset of LoRA. Each path (A, B, C, D, E) can be configured independently, enabling a wide range of adaptation strategies.

GLoRA is especially useful for research and advanced applications where you want to experiment with different low-rank or structured update patterns, or combine several adaptation mechanisms in a single layer.

## GLoraConfig

[[autodoc]] tuners.glora.config.GLoraConfig

### Key Configuration Options
- `r`: The rank of the low-rank matrices (default: 4).
- `target_modules`: List or regex of module names to adapt (e.g., `["q_proj", "v_proj"]`).
- `config_A_B`: Path type for A and B (`"LoRA"`, `"vector"`, `"constant"`, `"none"`).
- `config_C`: Path type for C (`"LoRA"`, `"vector"`, `"none"`).
- `config_D_E`: Path type for D and E (`"constant"`, `"vector"`, `"none"`).

Each path can be set independently, allowing for highly customized adaptation.

## GLoraModel

[[autodoc]] tuners.glora.model.GLoraModel

- Wraps a base model and injects GLoRA adapters into the specified modules.
- Supports multiple adapters, adapter switching, merging/unmerging, and mixed-batch inference.
- Use `set_adapter`, `merge_and_unload`, and related methods for adapter management.

## GLoraLayer and GLoraLinear

[[autodoc]] tuners.glora.layer.GLoraLayer
[[autodoc]] tuners.glora.layer.GLoraLinear

- `GLoraLayer` implements the core logic for generalized low-rank adaptation, supporting multiple adapters and flexible path configurations.
- `GLoraLinear` is a drop-in replacement for `nn.Linear` with GLoRA support.

## Example Usage

```python
from transformers import AutoModelForCausalLM
from peft import GLoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("your-model-id")
glora_config = GLoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    config_A_B="LoRA",
    config_C="vector",
    config_D_E="constant",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, glora_config)
model.print_trainable_parameters()

# Switch adapters, merge, etc.
model.set_adapter("default")
model.merge_and_unload()
```

## Notes
- GLoRA is a superset of LoRA: setting all paths to "LoRA" recovers standard LoRA-style updates.
- You can use different path types for A/B/C/D/E to experiment with new adaptation strategies.
- GLoRA supports all standard PEFT adapter management features (add, delete, switch, merge, etc.); the sketch below shows how the paths combine into the effective weight and bias.
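For reference, here is a minimal sketch of how the five paths combine into the effective weight and bias of a wrapped linear layer, mirroring the `merge`/`forward` logic in `GLoraLayer` (the shapes and path choices below are illustrative):

```python
import torch

out_features, in_features, r = 8, 16, 4

W = torch.randn(out_features, in_features)  # frozen base weight of the wrapped nn.Linear
b = torch.randn(out_features)               # frozen base bias

# Path tensors as produced by `prepare_path`: A and B as low-rank products,
# C as a column vector, D and E as scalar constants.
A = torch.randn(out_features, r) @ torch.randn(r, in_features)
B = torch.randn(out_features, r) @ torch.randn(r, in_features)
C = torch.randn(in_features, 1)
D = torch.zeros(())
E = torch.zeros(())

W_eff = W + W * A + B                        # A rescales the base weight elementwise, B is additive
b_eff = b + b * D + E + (W @ C).squeeze(-1)  # the bias additionally picks up a W @ C term

x = torch.randn(2, in_features)
y = torch.nn.functional.linear(x, W_eff, b_eff)  # what GLoraLinear computes per active adapter
print(y.shape)  # torch.Size([2, 8])
```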
+ +## See Also +- [Adapter conceptual guide](../conceptual_guides/adapter.md) +- [LoRA reference](./lora.md) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 03580a56ee..2192bd6c97 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -62,6 +62,8 @@ EvaConfig, FourierFTConfig, FourierFTModel, + GLoraConfig, + GLoraModel, HRAConfig, HRAModel, IA3Config, @@ -155,6 +157,8 @@ "EvaConfig", "FourierFTConfig", "FourierFTModel", + "GLoraConfig", + "GLoraModel", "HRAConfig", "HRAModel", "IA3Config", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 4c9bcc99ac..96cedf5716 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -19,6 +19,7 @@ from .c3a import C3AConfig, C3AModel from .cpt import CPTConfig, CPTEmbedding from .fourierft import FourierFTConfig, FourierFTModel +from .glora import GLoraConfig, GLoraModel from .hra import HRAConfig, HRAModel from .ia3 import IA3Config, IA3Model from .ln_tuning import LNTuningConfig, LNTuningModel @@ -69,6 +70,8 @@ "EvaConfig", "FourierFTConfig", "FourierFTModel", + "GLoraConfig", + "GLoraModel", "HRAConfig", "HRAModel", "IA3Config", diff --git a/src/peft/tuners/glora/__init__.py b/src/peft/tuners/glora/__init__.py new file mode 100644 index 0000000000..af887026c9 --- /dev/null +++ b/src/peft/tuners/glora/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025-present the HuggingFace Inc. team. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from peft.utils import register_peft_method + +from .config import GLoraConfig +from .layer import GLoraLayer, GLoraLinear +from .model import GLoraModel + + +__all__ = ["GLoraConfig", "GLoraLayer", "GLoraLinear", "GLoraModel"] + +register_peft_method(name="glora", config_cls=GLoraConfig, model_cls=GLoraModel) diff --git a/src/peft/tuners/glora/config.py b/src/peft/tuners/glora/config.py new file mode 100644 index 0000000000..df8c5cddf9 --- /dev/null +++ b/src/peft/tuners/glora/config.py @@ -0,0 +1,117 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils.peft_types import PeftType + + +@dataclass +class GLoraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`GLoraModel`]. + + Args: + r (`int`): GLora attention dimension (rank of the LoRA matrices). 
+ target_modules (`Optional[Union[List[str], str]]`): The names of the modules to apply GLora to. + config_A_B (`str`): Configuration for A and B matrices. Valid values: 'LoRA', 'vector', 'constant', 'none'. + config_C (`str`): Configuration for C matrix. Valid values: 'LoRA', 'vector', 'none'. + config_D_E (`str`): Configuration for D and E matrices. Valid values: 'constant', 'none', 'vector'. + """ + + _VALID_A_B_CONFIGS = {"LoRA", "vector", "constant", "none"} + _VALID_C_CONFIGS = {"LoRA", "vector", "none"} + _VALID_D_E_CONFIGS = {"constant", "none", "vector"} + + r: int = field( + default=4, metadata={"help": "Default rank of the LoRA matrices if the config contains LoRA parametrization."} + ) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with Lora." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + + config_A_B: str = field( + default="LoRA", + metadata={ + "help": "Configuration for A and B matrices in GLora." + f"Valid values: {', '.join(_VALID_A_B_CONFIGS)}. " + "For LoRA, it will be post-processed to LoRA_." + }, + ) + + config_C: str = field( + default="LoRA", + metadata={ + "help": "Configuration for C matrix in GLora." + f"Valid values: {', '.join(_VALID_C_CONFIGS)}. " + "For LoRA, it will be post-processed to LoRA_." + }, + ) + + config_D_E: str = field( + default="constant", + metadata={ + "help": f"Configuration for D and E matrices in GLora. Valid values: {', '.join(_VALID_D_E_CONFIGS)}." + }, + ) + + def _validate_and_process_config( + self, config_value: str, valid_configs: set, config_name: str, allow_lora: bool = True + ) -> str: + """ + Validate and process a configuration value. + + Args: + config_value: The configuration value to validate + valid_configs: Set of valid configuration values + config_name: Name of the configuration (for error messages) + allow_lora: Whether LoRA configuration is allowed + + Returns: + Processed configuration value + + Raises: + ValueError: If the configuration value is invalid + """ + if config_value and "LoRA" in config_value: + if not allow_lora: + raise ValueError( + f"Invalid {config_name} value: {config_value}. LoRA is not supported for {config_name}." + ) + return f"LoRA_{self.r}" + + if config_value not in valid_configs: + raise ValueError( + f"Invalid {config_name} value: {config_value}. Valid values are: {', '.join(sorted(valid_configs))}." + ) + + return config_value + + def __post_init__(self): + self.peft_type = PeftType.GLORA + + # Validate and process each configuration + self.config_A_B = self._validate_and_process_config(self.config_A_B, self._VALID_A_B_CONFIGS, "config_A_B") + + self.config_C = self._validate_and_process_config(self.config_C, self._VALID_C_CONFIGS, "config_C") + + self.config_D_E = self._validate_and_process_config( + self.config_D_E, self._VALID_D_E_CONFIGS, "config_D_E", allow_lora=False + ) diff --git a/src/peft/tuners/glora/layer.py b/src/peft/tuners/glora/layer.py new file mode 100644 index 0000000000..96d21e240c --- /dev/null +++ b/src/peft/tuners/glora/layer.py @@ -0,0 +1,335 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GLoraLayer(nn.Module): + def __init__(self, in_features: int, out_features: int): + # Only call nn.Module.__init__ if not already initialized (i.e., not an nn.Linear) + if not isinstance(self, nn.Linear): + nn.Module.__init__(self) + self.in_features: int = in_features + self.out_features: int = out_features + self.r: dict[str, int] = {} + self.glora_Ad: nn.ParameterDict = nn.ParameterDict() + self.glora_Au: nn.ParameterDict = nn.ParameterDict() + self.glora_Bd: nn.ParameterDict = nn.ParameterDict() + self.glora_Bu: nn.ParameterDict = nn.ParameterDict() + self.glora_Cd: nn.ParameterDict = nn.ParameterDict() + self.glora_Cu: nn.ParameterDict = nn.ParameterDict() + self.glora_D: nn.ParameterDict = nn.ParameterDict() + self.glora_E: nn.ParameterDict = nn.ParameterDict() + self.eval_config: dict[str, dict[str, object]] = {} + self.merged_adapters: list[str] = [] + self._disable_adapters: bool = False + self.active_adapters: list[str] = [] + self.kwargs: dict[str, object] = {} + + def add_adapter(self, adapter_name: str, r: int, config_A_B: str, config_C: str, config_D_E: str): + self.r[adapter_name] = r + Ad, Au = self.make_param((self.out_features, self.in_features), f"LoRA_{r}") + Bd, Bu = self.make_param((self.out_features, self.in_features), f"LoRA_{r}") + Cd, Cu = self.make_param((self.in_features, 1), f"LoRA_{r}") + D = nn.Parameter(torch.zeros(self.out_features)) + E = nn.Parameter(torch.zeros(self.out_features)) + self.glora_Ad[adapter_name] = Ad + self.glora_Au[adapter_name] = Au + self.glora_Bd[adapter_name] = Bd + self.glora_Bu[adapter_name] = Bu + self.glora_Cd[adapter_name] = Cd + self.glora_Cu[adapter_name] = Cu + self.glora_D[adapter_name] = D + self.glora_E[adapter_name] = E + self.eval_config[adapter_name] = { + "A": config_A_B, + "B": config_A_B, + "C": config_C, + "D": config_D_E, + "E": config_D_E, + } + self.reset_glora_parameters(adapter_name) + if adapter_name not in self.active_adapters: + self.active_adapters.append(adapter_name) + + def reset_glora_parameters(self, adapter_name): + nn.init.kaiming_uniform_(self.glora_Au[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.glora_Bu[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.glora_Cu[adapter_name], a=math.sqrt(5)) + + def make_param(self, shape, config=None): + if config is not None and "LoRA" in config: + out_feature = shape[0] + in_feature = shape[1] + try: + rank = int(config.split("_")[1]) + except Exception: + rank = 4 + return nn.Parameter(torch.zeros(out_feature, rank)), nn.Parameter(torch.zeros(rank, in_feature)) + return nn.Parameter(torch.zeros(*shape)), nn.Parameter(torch.zeros(1, 1)) + + def set_adapter(self, adapter_name_or_list, inference_mode: Optional[bool] = None, **kwargs): + # Accepts inference_mode and other kwargs for PEFT compatibility + # If inference_mode is set, enable/disable adapters accordingly + if isinstance(adapter_name_or_list, str): + self.active_adapters = [adapter_name_or_list] + else: + self.active_adapters = 
list(adapter_name_or_list) + if inference_mode is not None: + if inference_mode: + self.disable_adapters() + else: + self.enable_adapters() + + def delete_adapter(self, adapter_name): + for d in [ + self.glora_Ad, + self.glora_Au, + self.glora_Bd, + self.glora_Bu, + self.glora_Cd, + self.glora_Cu, + self.glora_D, + self.glora_E, + ]: + if adapter_name in d: + del d[adapter_name] + if adapter_name in self.r: + del self.r[adapter_name] + if adapter_name in self.eval_config: + del self.eval_config[adapter_name] + if adapter_name in self.active_adapters: + self.active_adapters.remove(adapter_name) + if adapter_name in self.merged_adapters: + self.merged_adapters.remove(adapter_name) + + def enable_adapters(self): + self._disable_adapters = False + + def disable_adapters(self): + self._disable_adapters = True + + @property + def merged(self): + return len(self.merged_adapters) > 0 + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None): + if adapter_names is None: + adapter_names = self.active_adapters + for adapter_name in adapter_names: + if adapter_name in self.merged_adapters: + continue + path_config = self.eval_config[adapter_name] + # Ensure self.weight and self.bias are tensors + if not isinstance(self.weight, torch.Tensor): + raise TypeError(f"self.weight must be a torch.Tensor, got {type(self.weight)}") + weight = self.weight + if self.bias is not None and not isinstance(self.bias, torch.Tensor): + raise TypeError(f"self.bias must be a torch.Tensor or None, got {type(self.bias)}") + bias = self.bias + device, dtype = weight.device, weight.dtype + A = self.prepare_path( + path_config["A"], self.glora_Ad[adapter_name], self.glora_Au[adapter_name], device=device, dtype=dtype + ) + B = self.prepare_path( + path_config["B"], self.glora_Bd[adapter_name], self.glora_Bu[adapter_name], device=device, dtype=dtype + ) + C = self.prepare_path( + path_config["C"], self.glora_Cd[adapter_name], self.glora_Cu[adapter_name], device=device, dtype=dtype + ) + D = self.prepare_path(path_config["D"], self.glora_D[adapter_name], device=device, dtype=dtype) + E = self.prepare_path(path_config["E"], self.glora_E[adapter_name], device=device, dtype=dtype) + if safe_merge: + orig_weight = weight.clone() + orig_bias = bias.clone() if bias is not None else None + merged_weight = orig_weight + orig_weight * A + B + if not torch.isfinite(merged_weight).all(): + raise ValueError(f"NaNs detected in merged weights for adapter {adapter_name}") + self.weight.data = merged_weight + if bias is not None: + merged_bias = orig_bias + orig_bias * D + E + torch.matmul(weight, C).squeeze(-1) + if not torch.isfinite(merged_bias).all(): + raise ValueError(f"NaNs detected in merged bias for adapter {adapter_name}") + self.bias.data = merged_bias + else: + self.weight.data += (weight * A) + B + if bias is not None: + self.bias.data += (bias * D) + E + torch.matmul(weight, C).squeeze(-1) + elif E.numel() > 0 or C.numel() > 0: + new_bias_val = E + torch.matmul(weight, C).squeeze(-1) + if not torch.all(new_bias_val == 0): + self.bias = nn.Parameter(new_bias_val) + self.merged_adapters.append(adapter_name) + + def unmerge(self, adapter_names: Optional[list[str]] = None): + if adapter_names is None: + adapter_names = list(self.merged_adapters) + for adapter_name in adapter_names: + if adapter_name not in self.merged_adapters: + continue + path_config = self.eval_config[adapter_name] + if not isinstance(self.weight, torch.Tensor): + raise TypeError(f"self.weight must be a torch.Tensor, got 
{type(self.weight)}") + weight = self.weight + if self.bias is not None and not isinstance(self.bias, torch.Tensor): + raise TypeError(f"self.bias must be a torch.Tensor or None, got {type(self.bias)}") + bias = self.bias + device, dtype = weight.device, weight.dtype + A = self.prepare_path( + path_config["A"], self.glora_Ad[adapter_name], self.glora_Au[adapter_name], device=device, dtype=dtype + ) + B = self.prepare_path( + path_config["B"], self.glora_Bd[adapter_name], self.glora_Bu[adapter_name], device=device, dtype=dtype + ) + C = self.prepare_path( + path_config["C"], self.glora_Cd[adapter_name], self.glora_Cu[adapter_name], device=device, dtype=dtype + ) + D = self.prepare_path(path_config["D"], self.glora_D[adapter_name], device=device, dtype=dtype) + E = self.prepare_path(path_config["E"], self.glora_E[adapter_name], device=device, dtype=dtype) + self.weight.data -= (weight * A) + B + if bias is not None: + self.bias.data -= (bias * D) + E + torch.matmul(weight, C).squeeze(-1) + self.merged_adapters.remove(adapter_name) + + def forward(self, x: torch.Tensor, adapter_names: Optional[list[str]] = None) -> torch.Tensor: + if self._disable_adapters or not self.active_adapters: + return F.linear(x, self.weight, self.bias) + if adapter_names is not None: + result = F.linear(x, self.weight, self.bias) + unique_adapters = set(adapter_names) + sub_batch_indices_list = [ + [i for i, a in enumerate(adapter_names) if a == adapter] for adapter in unique_adapters + ] + for i, active_adapter in enumerate(unique_adapters): + if active_adapter not in self.glora_Ad: + continue + path_config = self.eval_config[active_adapter] + device, dtype = self.weight.device, self.weight.dtype + A = self.prepare_path( + path_config["A"], + self.glora_Ad[active_adapter], + self.glora_Au[active_adapter], + device=device, + dtype=dtype, + ) + B = self.prepare_path( + path_config["B"], + self.glora_Bd[active_adapter], + self.glora_Bu[active_adapter], + device=device, + dtype=dtype, + ) + C = self.prepare_path( + path_config["C"], + self.glora_Cd[active_adapter], + self.glora_Cu[active_adapter], + device=device, + dtype=dtype, + ) + D = self.prepare_path(path_config["D"], self.glora_D[active_adapter], device=device, dtype=dtype) + E = self.prepare_path(path_config["E"], self.glora_E[active_adapter], device=device, dtype=dtype) + sub_batch = x[sub_batch_indices_list[i]] + weight_eff = self.weight + self.weight * A + B + bias_eff = self.bias + if bias_eff is not None: + bias_eff = bias_eff + bias_eff * D + E + torch.matmul(self.weight, C).squeeze(-1) + else: + new_bias_val = E + torch.matmul(self.weight, C).squeeze(-1) + if not torch.all(new_bias_val == 0): + bias_eff = new_bias_val + result[sub_batch_indices_list[i]] = F.linear(sub_batch, weight_eff, bias=bias_eff) + return result + result = F.linear(x, self.weight, self.bias) + for active_adapter in self.active_adapters: + if active_adapter not in self.glora_Ad: + continue + path_config = self.eval_config[active_adapter] + device, dtype = self.weight.device, self.weight.dtype + A = self.prepare_path( + path_config["A"], + self.glora_Ad[active_adapter], + self.glora_Au[active_adapter], + device=device, + dtype=dtype, + ) + B = self.prepare_path( + path_config["B"], + self.glora_Bd[active_adapter], + self.glora_Bu[active_adapter], + device=device, + dtype=dtype, + ) + C = self.prepare_path( + path_config["C"], + self.glora_Cd[active_adapter], + self.glora_Cu[active_adapter], + device=device, + dtype=dtype, + ) + D = self.prepare_path(path_config["D"], 
self.glora_D[active_adapter], device=device, dtype=dtype) + E = self.prepare_path(path_config["E"], self.glora_E[active_adapter], device=device, dtype=dtype) + weight_eff = self.weight + self.weight * A + B + bias_eff = self.bias + if bias_eff is not None: + bias_eff = bias_eff + bias_eff * D + E + torch.matmul(self.weight, C).squeeze(-1) + else: + new_bias_val = E + torch.matmul(self.weight, C).squeeze(-1) + if not torch.all(new_bias_val == 0): + bias_eff = new_bias_val + result = F.linear(x, weight_eff, bias=bias_eff) + return result + + def prepare_path(self, config: str, Xd: nn.Parameter, Xu: Optional[nn.Parameter] = None, device=None, dtype=None): + device = device or Xd.device + dtype = dtype or Xd.dtype + if Xu is not None: + if "LoRA" in config: + rank = int(config.split("_")[1]) + X = torch.matmul(Xd[:, :rank], Xu[:rank, :]) + elif "vector" in config: + X = Xd[:, 0].unsqueeze(1) + elif "constant" in config: + X = Xd[0, 0] + elif "none" in config: + X = torch.zeros(Xd.shape[0], Xu.shape[1], device=device, dtype=dtype) + else: + raise ValueError(f"Unknown config choice: {config} for decomposable path") + else: + if "vector" in config: + X = Xd + elif "constant" in config: + X = Xd[0] + elif "none" in config: + X = torch.zeros(1, device=device, dtype=dtype) + else: + raise ValueError(f"Unknown config choice: {config} for non-decomposable path") + return X.to(device=device, dtype=dtype) + + +# Refactored GLoraLinear for PEFT compatibility +class GLoraLinear(GLoraLayer, nn.Linear): + def __init__(self, in_features, out_features, bias=True, **kwargs): + nn.Linear.__init__(self, in_features, out_features, bias=bias) + GLoraLayer.__init__(self, in_features=in_features, out_features=out_features) + self.weight.requires_grad = False + if self.bias is not None: + self.bias.requires_grad = False + self._disable_adapters = False + self.active_adapters = [] + self.merged_adapters = [] diff --git a/src/peft/tuners/glora/model.py b/src/peft/tuners/glora/model.py new file mode 100644 index 0000000000..104a97b4ac --- /dev/null +++ b/src/peft/tuners/glora/model.py @@ -0,0 +1,304 @@ +import re +from dataclasses import asdict +from enum import Enum +from typing import Any, Optional, Union + +import torch.nn as nn +from tqdm import tqdm + +from peft.tuners.tuners_utils import BaseTuner +from peft.utils import ( + TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _freeze_adapter, + _get_submodules, +) +from peft.utils.peft_types import PeftType + +from .config import GLoraConfig +from .layer import GLoraLinear + + +def mark_only_glora_as_trainable(model: nn.Module, bias: str = "none") -> None: + """ + Freezes all parameters of the model except the GLORA parameters. + If bias is 'glora_only', 'all', or 'some_other_custom', it handles bias terms as well. + """ + for n, p in model.named_parameters(): + if "glora_" not in n: + p.requires_grad = False + + if bias == "none": + return + elif bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "glora_only": + for m in model.modules(): + if isinstance(m, GLoraLinear) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + + +class GLoraModel(BaseTuner): + """ + Creates Generalized Low Rank Adapter (GLora) model from a pretrained transformers model. 
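    Example (a minimal sketch; the model id and target modules below are illustrative, not required values):

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import GLoraConfig, get_peft_model

        >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        >>> config = GLoraConfig(r=4, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
        >>> model = get_peft_model(base_model, config)
        >>> model.print_trainable_parameters()
        ```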
+ """ + + def __init__(self, model: nn.Module, config: GLoraConfig, adapter_name: str = "default"): + super().__init__(model, config, adapter_name) + self.model = model + self.forward = self.model.forward + + self.peft_config: dict[str, GLoraConfig] = {} + self.active_adapter: Union[str, list[str]] = adapter_name + self.peft_type = PeftType.GLORA + self.adapters_config_history: dict[str, Any] = {} + + # Accept both single config and dict of configs + if isinstance(config, GLoraConfig): + self.peft_config[adapter_name] = config + elif isinstance(config, dict): + for name, cfg in config.items(): + self.peft_config[name] = cfg + else: + raise TypeError(f"Unsupported config type: {type(config)}") + + # Add all adapters after peft_config is set + for name, cfg in self.peft_config.items(): + self.add_adapter(name, cfg) + + def add_adapter(self, adapter_name: str, config: GLoraConfig): + # Avoid re-adding if already present + if hasattr(self, "_added_adapters") and adapter_name in self._added_adapters: + return + if not hasattr(self, "_added_adapters"): + self._added_adapters = set() + + # Prepare config (resolve target_modules if needed) + model_config_dict = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config + current_config = self._prepare_glora_config(config, model_config_dict) + self.peft_config[adapter_name] = current_config + + # Replace or add adapters in target modules + self._find_and_replace(adapter_name) + + # Mark only GLora params as trainable + mark_only_glora_as_trainable(self.model, bias=getattr(current_config, "bias", "none")) + + # Optionally freeze for inference + if getattr(current_config, "inference_mode", False): + _freeze_adapter(self.model, adapter_name) + + self._added_adapters.add(adapter_name) + + def _check_target_module_exists(self, glora_config: GLoraConfig, key: str) -> bool: + if isinstance(glora_config.target_modules, str): + return bool(re.fullmatch(glora_config.target_modules, key)) + elif isinstance(glora_config.target_modules, list): + return any(key.endswith(target_key) for target_key in glora_config.target_modules) + return False + + def _create_new_module(self, glora_config: GLoraConfig, adapter_name: str, target: nn.Module) -> GLoraLinear: + bias = hasattr(target, "bias") and target.bias is not None + if not isinstance(target, nn.Linear): + raise ValueError( + f"Target module {target} is not a nn.Linear layer, which is required for GLORA replacement." 
+ ) + + in_features, out_features = target.in_features, target.out_features + kwargs_glora = { + "config_A_B": glora_config.config_A_B, + "config_C": glora_config.config_C, + "config_D_E": glora_config.config_D_E, + } + new_module = GLoraLinear(in_features, out_features, bias=bias, **kwargs_glora) + # Add the adapter to the new module + new_module.add_adapter( + adapter_name, + glora_config.r, + glora_config.config_A_B, + glora_config.config_C, + glora_config.config_D_E, + ) + return new_module + + def _find_and_replace(self, adapter_name: str): + glora_config = self.peft_config[adapter_name] + is_target_modules_in_base_model = False + key_list = [key for key, _ in self.model.named_modules()] # Cache keys + + for key in key_list: + if not self._check_target_module_exists(glora_config, key): + continue + + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(self.model, key) + + if isinstance(target, GLoraLinear): + # Add adapter to existing GLoraLinear + target.add_adapter( + adapter_name, + glora_config.r, + glora_config.config_A_B, + glora_config.config_C, + glora_config.config_D_E, + ) + elif isinstance(target, nn.Linear): + new_module = self._create_new_module(glora_config, adapter_name, target) + self._replace_module(parent, target_name, new_module, target) + + if not is_target_modules_in_base_model: + raise ValueError( + f"Target modules {glora_config.target_modules} not found in the base model. " + f"Please check the target modules and try again." + ) + + def _replace_module(self, parent_module: nn.Module, child_name: str, new_module: nn.Module, old_module: nn.Module): + setattr(parent_module, child_name, new_module) + # Copy weights and bias + if hasattr(old_module, "weight") and hasattr(new_module, "weight"): + new_module.weight = old_module.weight + if hasattr(old_module, "bias") and hasattr(new_module, "bias") and old_module.bias is not None: + new_module.bias = old_module.bias + # Copy state if present + if getattr(old_module, "state", None) is not None: + new_module.state = old_module.state + new_module.to(old_module.weight.device) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + if name == "model": + raise + return getattr(self.model, name) + + @staticmethod + def _prepare_glora_config(peft_config: GLoraConfig, model_config: dict) -> GLoraConfig: + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING: + raise ValueError( + f"Please specify `target_modules` in `GLoraConfig` for model_type {model_config['model_type']}" + ) + peft_config.target_modules = TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING[ + model_config["model_type"] + ] + return peft_config + + def set_adapter(self, adapter_name_or_list, **kwargs): + if self.active_adapter == adapter_name_or_list: + print("Adapter already active, no change made.") + return + + print("adapter_name_or_list:", adapter_name_or_list) + + for module in self.model.modules(): + if hasattr(module, "set_adapter"): + print("module:", module.__class__) + + for module in self.model.modules(): + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name_or_list, **kwargs) + self.active_adapter = adapter_name_or_list + + def enable_adapter_layers(self): + for module in self.model.modules(): + if hasattr(module, "enable_adapters"): + module.enable_adapters() + + def disable_adapter_layers(self): + for module in self.model.modules(): + if hasattr(module, 
"disable_adapters"): + module.disable_adapters() + + def delete_adapter(self, adapter_name: str): + if adapter_name not in self.peft_config: + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + for module in self.model.modules(): + if hasattr(module, "delete_adapter"): + module.delete_adapter(adapter_name) + # Update active_adapter if needed + if self.active_adapter == adapter_name: + self.active_adapter = next(iter(self.peft_config.keys()), None) + + def merge_and_unload(self, progressbar: bool = False, adapter_names: Optional[list[str]] = None): + """ + This method merges the GLora layers into the base model. + """ + if getattr(self, "hf_device_map", None): + raise ValueError("Merging LoRA weights is not supported when using HF device map.") + + key_list = [key for key, _ in self.model.named_modules()] + desc = "Merging GLORA layers" + + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if isinstance(target, GLoraLinear): + if not target.active_adapters: + continue + # Merge all or specified adapters + merge_adapters = adapter_names if adapter_names is not None else target.active_adapters + target.merge(adapter_names=merge_adapters) + new_module = nn.Linear(target.in_features, target.out_features, bias=(target.bias is not None)) + new_module.weight.data = target.weight.data.clone() # Get merged weight + if target.bias is not None: + new_module.bias.data = target.bias.data.clone() # Get merged bias + self._replace_module(parent, target_name, new_module.to(target.weight.device), target) + + if isinstance(target, ModulesToSaveWrapper): + pass + return self.model + + def set_adapter_eval_config(self, adapter_name: str, eval_config: dict[str, str]): + """ + Sets the evaluation configuration for all GLoraLinear layers associated with a given adapter. + The eval_config dictionary should specify the path choices for A, B, C, D, E. 
+ Example: {'A':'LoRA_4', 'B':'none', 'C':'vector', 'D':'constant', 'E':'none'} + """ + if adapter_name not in self.peft_config: + raise ValueError(f"Adapter {adapter_name} not found.") + + for module in self.model.modules(): + if isinstance(module, GLoraLinear): + if adapter_name in module.eval_config: + module.eval_config[adapter_name] = eval_config + self.adapters_config_history[adapter_name] = eval_config + + def get_peft_config_as_dict(self, inference: bool = False) -> dict[str, Any]: + config_dict = {} + for adapter_name, peft_config_obj in self.peft_config.items(): + config = asdict(peft_config_obj) + if inference: + config["inference_mode"] = True + for k, v in config.items(): + if isinstance(v, Enum): + config[k] = v.value + config_dict[adapter_name] = config + return config_dict + + def _create_and_replace( + self, + peft_config, + adapter_name, + target, + target_name, + parent, + current_key, + parameter_name: Optional[str] = None, + ): + new_module = self._create_new_module(peft_config, adapter_name, target) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _mark_only_adapters_as_trainable(model: nn.Module): + mark_only_glora_as_trainable(model) + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + return GLoraModel._prepare_glora_config(peft_config, model_config) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 36cb7b0611..5acfd1638c 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -21,6 +21,7 @@ TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, @@ -68,6 +69,7 @@ "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 21d2b3ab4b..946bc890cf 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -136,6 +136,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen3": ["q_proj", "v_proj"], } +TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() # need to check this later TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 0d0b24b8c1..6d8e40af17 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -44,6 +44,7 @@ TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, 
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, @@ -76,6 +77,7 @@ "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index e13f8ac5f7..3d715aa4d4 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -38,6 +38,7 @@ class PeftType(str, enum.Enum): - LN_TUNING - VERA - FOURIERFT + - GLORA - HRA - BONE - MISS @@ -63,6 +64,7 @@ class PeftType(str, enum.Enum): LN_TUNING = "LN_TUNING" VERA = "VERA" FOURIERFT = "FOURIERFT" + GLORA = "GLORA" XLORA = "XLORA" HRA = "HRA" VBLORA = "VBLORA" diff --git a/tests/test_glora.py b/tests/test_glora.py new file mode 100644 index 0000000000..70006eedb3 --- /dev/null +++ b/tests/test_glora.py @@ -0,0 +1,221 @@ +import gc +import tempfile +import unittest + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +from peft import ( + GLoraConfig, + PeftModel, + TaskType, + get_peft_model, + prepare_model_for_kbit_training, +) +from peft.tuners.glora.layer import Linear as GLoraLinear + + +# A very simple model for testing +class SimpleTransformer(torch.nn.Module): + def __init__(self, vocab_size=100, hidden_size=16, num_layers=2, num_heads=2): + super().__init__() + self.embedding = torch.nn.Embedding(vocab_size, hidden_size) + self.layers = torch.nn.ModuleList( + [ + torch.nn.TransformerEncoderLayer( + d_model=hidden_size, + nhead=num_heads, + dim_feedforward=hidden_size * 2, # Simple FFN + batch_first=True, + ) + for _ in range(num_layers) + ] + ) + self.linear = torch.nn.Linear(hidden_size, hidden_size) # A targetable linear layer + self.lm_head = torch.nn.Linear(hidden_size, vocab_size) + + # Add a config attribute similar to HF models for _prepare_glora_config + class SimpleConfig: + model_type = "simple_transformer" # Needs a mapping in constants.py or explicit target_modules + + self.config = SimpleConfig() + + def forward(self, input_ids): + x = self.embedding(input_ids) + for layer in self.layers: + x = layer(x) + x = self.linear(x) # Pass through the targetable linear layer + logits = self.lm_head(x) + return logits + + +class DummyTokenizer: + pad_token = 0 + eos_token = 0 + + def __call__(self, *args, **kwargs): + return {"input_ids": torch.randint(0, 100, (2, 5))} + + def batch_decode(self, *args, **kwargs): + return ["decoded text"] + + +class GLORATester(unittest.TestCase): + def setUp(self): + self.model_id = "HuggingFaceM4/tiny-random-Llama-3" + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.base_model = AutoModelForCausalLM.from_pretrained(self.model_id) + except Exception: + print("Failed to load HF tiny model, using SimpleTransformer for tests.") + self.base_model = SimpleTransformer() + self.tokenizer = DummyTokenizer() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def _get_target_modules(self): + if isinstance(self.base_model, SimpleTransformer): + return ["linear", "lm_head"] + else: + return [ + "q_proj", + "v_proj", + "o_proj", + "down_proj", + "up_proj", + "gate_proj", + ] + + def test_glora_model_creation_and_forward(self): + target_modules = self._get_target_modules() + glora_config = 
GLoraConfig(r=4, target_modules=target_modules, task_type=TaskType.CAUSAL_LM) + peft_model = get_peft_model(self.base_model, glora_config) + assert isinstance(peft_model, PeftModel) + assert isinstance(peft_model.base_model, type(self.base_model)) + + if hasattr(self.tokenizer, "pad_token") and getattr(self.tokenizer, "pad_token", None) is None: + self.tokenizer.pad_token = getattr(self.tokenizer, "eos_token", 0) + + if isinstance(self.base_model, SimpleTransformer): + dummy_input = torch.randint(0, 100, (2, 10)) + else: + dummy_input = self.tokenizer("This is a test prompt", return_tensors="pt")["input_ids"] + + peft_model.eval() + + # Set deterministic eval_config for all GLoraLinear layers + for module in peft_model.modules(): + if isinstance(module, GLoraLinear): + chosen_eval_config = module.configs[0] + module.eval_config = chosen_eval_config + + with torch.no_grad(): + output_peft = peft_model(dummy_input) + output_base = self.base_model(dummy_input) + assert isinstance(output_peft.shape, output_base.shape) + + def test_save_and_load_glora_adapter(self): + target_modules = self._get_target_modules() + glora_config = GLoraConfig(r=4, target_modules=target_modules, task_type=TaskType.CAUSAL_LM) + peft_model = get_peft_model(self.base_model, glora_config, adapter_name="test_adapter") + + with tempfile.TemporaryDirectory() as tmp_dirname: + peft_model.save_pretrained(tmp_dirname, safe_serialization=False) + if isinstance(self.base_model, SimpleTransformer): + loaded_base_model = SimpleTransformer() + else: + loaded_base_model = AutoModelForCausalLM.from_pretrained(self.model_id) + loaded_peft_model = PeftModel.from_pretrained( + loaded_base_model, tmp_dirname, adapter_name="test_adapter_loaded" + ) + assert isinstance(loaded_peft_model, PeftModel) + # Compare GLORA parameters + original_glora_params = { + k: v for k, v in peft_model.named_parameters() if "glora_" in k and v.requires_grad + } + loaded_glora_params = { + k: v for k, v in loaded_peft_model.named_parameters() if "glora_" in k and v.requires_grad + } + assert len(original_glora_params) == len(loaded_glora_params) + for (k_orig, v_orig), (k_load, v_load) in zip(original_glora_params.items(), loaded_glora_params.items()): + assert torch.allclose(v_orig, v_load) + + def test_merge_and_unload_glora(self): + target_modules = self._get_target_modules() + glora_config = GLoraConfig(r=4, target_modules=target_modules, task_type=TaskType.CAUSAL_LM) + peft_model = get_peft_model(self.base_model, glora_config) + # Set deterministic eval_config for all GLoraLinear layers + for module in peft_model.modules(): + if isinstance(module, GLoraLinear): + chosen_eval_config = module.configs[0] + module.eval_config = chosen_eval_config + # Store original weights for comparison + target_layer_name = ( + glora_config.target_modules[0] if isinstance(glora_config.target_modules, list) else "linear" + ) + module_ptr = peft_model.model + for part in target_layer_name.split("."): + module_ptr = getattr(module_ptr, part) + if isinstance(module_ptr, GLoraLinear): + original_weight = module_ptr.weight.data.clone() + else: + self.skipTest(f"Target module {target_layer_name} is not a GLoraLinear layer after PEFT application.") + merged_model = peft_model.merge_and_unload() + assert not isinstance(merged_model, PeftModel) + assert isinstance(merged_model, type(self.base_model)) + merged_weight_module_ptr = merged_model + for part in target_layer_name.split("."): + merged_weight_module_ptr = getattr(merged_weight_module_ptr, part) + merged_weight = 
merged_weight_module_ptr.weight.data + assert not torch.allclose(original_weight, merged_weight) + if isinstance(self.base_model, SimpleTransformer): + dummy_input = torch.randint(0, 100, (2, 10)) + else: + dummy_input = self.tokenizer("This is a test prompt after merging", return_tensors="pt")["input_ids"] + with torch.no_grad(): + merged_model.eval() + _ = merged_model(dummy_input) + + @unittest.skipIf( + not torch.cuda.is_available() + or not hasattr(torch.cuda, "is_bf16_supported") + or not torch.cuda.is_bf16_supported(), + "BF16 not supported or no CUDA", + ) + def test_glora_with_kbit_training(self): + try: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model_4bit = AutoModelForCausalLM.from_pretrained( + self.model_id, + quantization_config=quantization_config, + device_map="auto", + ) + model_4bit = prepare_model_for_kbit_training(model_4bit) + except Exception as e: + self.skipTest(f"bitsandbytes or quantized model loading failed: {e}") + glora_config = GLoraConfig( + r=4, + target_modules=["q_proj", "v_proj"], + task_type=TaskType.CAUSAL_LM, + ) + peft_model = get_peft_model(model_4bit, glora_config) + assert isinstance(peft_model, PeftModel) + dummy_input = self.tokenizer("Test with 4-bit GLORA", return_tensors="pt")["input_ids"].to(peft_model.device) + for module in peft_model.modules(): + if isinstance(module, GLoraLinear): + chosen_eval_config = module.configs[0] + module.eval_config = chosen_eval_config + with torch.no_grad(): + peft_model.eval() + output = peft_model(dummy_input) + assert output is not None + + +if __name__ == "__main__": + unittest.main(verbosity=2)
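The tests above exercise the full adapter lifecycle; the `GLoraConfig` post-processing on its own can also be sanity-checked with a few lines (a sketch, not part of the diff above, based on the `config.py` shown earlier):

```python
# GLoraConfig rewrites "LoRA" path choices to "LoRA_<r>" and rejects LoRA on the D/E path.
from peft import GLoraConfig

cfg = GLoraConfig(r=8, target_modules=["q_proj"], config_A_B="LoRA", config_C="vector")
assert cfg.config_A_B == "LoRA_8"  # rank folded into the path spec
assert cfg.config_C == "vector"    # non-LoRA choices pass through unchanged

try:
    GLoraConfig(r=8, target_modules=["q_proj"], config_D_E="LoRA")
except ValueError as err:
    print(f"Rejected as expected: {err}")
```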