2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -114,6 +114,8 @@
  title: VeRA
- local: package_reference/fourierft
  title: FourierFT
- local: package_reference/glora
  title: GLoRA
- local: package_reference/vblora
  title: VB-LoRA
- local: package_reference/hra
81 changes: 81 additions & 0 deletions docs/source/package_reference/glora.md
@@ -0,0 +1,81 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# GLoRA

Generalized Low-Rank Adaptation (**GLoRA**) is a highly flexible PEFT method that generalizes LoRA and related approaches. GLoRA decomposes the weight update into several configurable paths (A, B, C, D, E) that can be low-rank matrices, vectors, or constants, giving it a superset of LoRA's expressivity. The paths are configured in groups (A/B, C, and D/E), enabling a wide range of adaptation strategies.

GLoRA is especially useful for research and advanced applications where you want to experiment with different low-rank or structured update patterns, or combine several adaptation mechanisms in a single layer.

## GLoraConfig

[[autodoc]] tuners.glora.config.GLoraConfig

### Key Configuration Options
- `r`: The rank of the low-rank matrices (default: 4).
- `target_modules`: List or regex of module names to adapt (e.g., `["q_proj", "v_proj"]`).
- `config_A_B`: Path type for A and B (`"LoRA"`, `"vector"`, `"constant"`, or `"none"`).
- `config_C`: Path type for C (`"LoRA"`, `"vector"`, or `"none"`).
- `config_D_E`: Path type for D and E (`"constant"`, `"vector"`, or `"none"`).

A and B share one setting, D and E share another, and C is configured on its own, so the path types can still be combined freely for highly customized adaptation (see the sketch below).
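
For example, here is a minimal sketch of a configuration that trains only vector and constant paths, with no low-rank component. The module names are placeholders for your model's attention projections:

```python
from peft import GLoraConfig

# Vector/constant-only adaptation: no low-rank matrices are trained.
config = GLoraConfig(
    r=4,                                  # only relevant for paths set to "LoRA"
    target_modules=["q_proj", "v_proj"],  # placeholder module names
    config_A_B="vector",
    config_C="none",
    config_D_E="constant",
)
```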

## GLoraModel

[[autodoc]] tuners.glora.model.GLoraModel

- Wraps a base model and injects GLora adapters into the specified modules.
- Supports multiple adapters, adapter switching, merging/unmerging, and mixed-batch inference.
- Use `set_adapter`, `merge_and_unload`, and related methods for adapter management (see the sketch below).
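
For instance, here is a minimal sketch of managing two adapters on one model. The model id and adapter names are placeholders, and it assumes GLoRA follows the standard `PeftModel` adapter API:

```python
from transformers import AutoModelForCausalLM
from peft import GLoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("your-model-id")  # placeholder id
cfg_small = GLoraConfig(r=4, target_modules=["q_proj", "v_proj"])
cfg_large = GLoraConfig(r=16, target_modules=["q_proj", "v_proj"], config_C="none")

model = get_peft_model(base, cfg_small, adapter_name="small")
model.add_adapter("large", cfg_large)

model.set_adapter("small")        # only the "small" adapter is active
# ... train or evaluate ...
model.set_adapter("large")        # switch to the "large" adapter

model = model.merge_and_unload()  # fold the active adapter into the base weights
```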

## GLoraLayer and GLoraLinear

[[autodoc]] tuners.glora.layer.GLoraLayer
[[autodoc]] tuners.glora.layer.GLoraLinear

- `GLoraLayer` implements the core logic for generalized low-rank adaptation, supporting multiple adapters and flexible path configurations.
- `GLoraLinear` is a drop-in replacement for `nn.Linear` with GLoRA support (a quick way to check which modules were replaced is shown below).
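
As a sanity check, you can list the modules that were swapped out. This sketch assumes `model` is a GLoRA-wrapped `PeftModel`, as created in the usage example in the next section:

```python
from peft.tuners.glora import GLoraLinear

# Collect the names of all modules that GLoRA has wrapped.
adapted = [name for name, module in model.named_modules() if isinstance(module, GLoraLinear)]
print(f"{len(adapted)} GLoRA layers, e.g. {adapted[:3]}")
```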

## Example Usage

```python
from transformers import AutoModelForCausalLM
from peft import GLoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("your-model-id")
glora_config = GLoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    config_A_B="LoRA",
    config_C="vector",
    config_D_E="constant",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, glora_config)
model.print_trainable_parameters()

# Switch adapters, merge, etc.
model.set_adapter("default")
model = model.merge_and_unload()
```
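
Assuming GLoRA hooks into PEFT's standard checkpointing, saving and reloading the adapter should follow the usual pattern. The output directory is a placeholder, and `save_pretrained` should be called on the `PeftModel` before merging:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Save only the (small) adapter weights.
model.save_pretrained("glora-adapter")

# Later: reload them on top of a fresh base model.
base = AutoModelForCausalLM.from_pretrained("your-model-id")
restored = PeftModel.from_pretrained(base, "glora-adapter")
```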

## Notes
- GLoRA is a superset of LoRA: its low-rank paths subsume the standard LoRA update, while the vector and constant paths extend it (see the sketch below).
- You can mix different path types for A/B, C, and D/E to experiment with new adaptation strategies.
- GLoRA supports the standard PEFT adapter management features (add, delete, switch, merge, etc.).
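
As a concrete illustration of the first note, a configuration that keeps only the low-rank A and B paths (a LoRA-style setup) could look like this; module names are again placeholders:

```python
from peft import GLoraConfig

# LoRA-style setup: only the low-rank A/B path is active.
lora_like = GLoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    config_A_B="LoRA",
    config_C="none",
    config_D_E="none",
)
```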

## See Also
- [Adapter conceptual guide](../conceptual_guides/adapter.md)
- [LoRA reference](./lora.md)
4 changes: 4 additions & 0 deletions src/peft/__init__.py
@@ -62,6 +62,8 @@
EvaConfig,
FourierFTConfig,
FourierFTModel,
GLoraConfig,
GLoraModel,
HRAConfig,
HRAModel,
IA3Config,
@@ -155,6 +157,8 @@
"EvaConfig",
"FourierFTConfig",
"FourierFTModel",
"GLoraConfig",
"GLoraModel",
"HRAConfig",
"HRAModel",
"IA3Config",
3 changes: 3 additions & 0 deletions src/peft/tuners/__init__.py
@@ -19,6 +19,7 @@
from .c3a import C3AConfig, C3AModel
from .cpt import CPTConfig, CPTEmbedding
from .fourierft import FourierFTConfig, FourierFTModel
from .glora import GLoraConfig, GLoraModel
from .hra import HRAConfig, HRAModel
from .ia3 import IA3Config, IA3Model
from .ln_tuning import LNTuningConfig, LNTuningModel
@@ -69,6 +70,8 @@
"EvaConfig",
"FourierFTConfig",
"FourierFTModel",
"GLoraConfig",
"GLoraModel",
"HRAConfig",
"HRAModel",
"IA3Config",
25 changes: 25 additions & 0 deletions src/peft/tuners/glora/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import GLoraConfig
from .layer import GLoraLayer, GLoraLinear
from .model import GLoraModel


__all__ = ["GLoraConfig", "GLoraLayer", "GLoraLinear", "GLoraModel"]

register_peft_method(name="glora", config_cls=GLoraConfig, model_cls=GLoraModel)
117 changes: 117 additions & 0 deletions src/peft/tuners/glora/config.py
@@ -0,0 +1,117 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional, Union

from peft.config import PeftConfig
from peft.utils.peft_types import PeftType


@dataclass
class GLoraConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`GLoraModel`].

Args:
r (`int`): GLora attention dimension (rank of the LoRA matrices).
target_modules (`Optional[Union[List[str], str]]`): The names of the modules to apply GLora to.
config_A_B (`str`): Configuration for A and B matrices. Valid values: 'LoRA', 'vector', 'constant', 'none'.
config_C (`str`): Configuration for C matrix. Valid values: 'LoRA', 'vector', 'none'.
config_D_E (`str`): Configuration for D and E matrices. Valid values: 'constant', 'none', 'vector'.
"""

_VALID_A_B_CONFIGS = {"LoRA", "vector", "constant", "none"}
_VALID_C_CONFIGS = {"LoRA", "vector", "none"}
_VALID_D_E_CONFIGS = {"constant", "none", "vector"}

r: int = field(
default=4, metadata={"help": "Default rank of the LoRA matrices if the config contains LoRA parametrization."}
)
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with Lora."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)

config_A_B: str = field(
default="LoRA",
metadata={
"help": "Configuration for A and B matrices in GLora."
f"Valid values: {', '.join(_VALID_A_B_CONFIGS)}. "
"For LoRA, it will be post-processed to LoRA_<rank>."
},
)

config_C: str = field(
default="LoRA",
metadata={
"help": "Configuration for C matrix in GLora."
f"Valid values: {', '.join(_VALID_C_CONFIGS)}. "
"For LoRA, it will be post-processed to LoRA_<rank>."
},
)

config_D_E: str = field(
default="constant",
metadata={
"help": f"Configuration for D and E matrices in GLora. Valid values: {', '.join(_VALID_D_E_CONFIGS)}."
},
)

def _validate_and_process_config(
self, config_value: str, valid_configs: set, config_name: str, allow_lora: bool = True
) -> str:
"""
Validate and process a configuration value.

Args:
config_value: The configuration value to validate
valid_configs: Set of valid configuration values
config_name: Name of the configuration (for error messages)
allow_lora: Whether LoRA configuration is allowed

Returns:
Processed configuration value

Raises:
ValueError: If the configuration value is invalid
"""
if config_value and "LoRA" in config_value:
if not allow_lora:
raise ValueError(
f"Invalid {config_name} value: {config_value}. LoRA is not supported for {config_name}."
)
return f"LoRA_{self.r}"

if config_value not in valid_configs:
raise ValueError(
f"Invalid {config_name} value: {config_value}. Valid values are: {', '.join(sorted(valid_configs))}."
)

return config_value

def __post_init__(self):
self.peft_type = PeftType.GLORA

# Validate and process each configuration
self.config_A_B = self._validate_and_process_config(self.config_A_B, self._VALID_A_B_CONFIGS, "config_A_B")

self.config_C = self._validate_and_process_config(self.config_C, self._VALID_C_CONFIGS, "config_C")

self.config_D_E = self._validate_and_process_config(
self.config_D_E, self._VALID_D_E_CONFIGS, "config_D_E", allow_lora=False
)
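
The validation above post-processes any `LoRA` path so that it also encodes the configured rank, and rejects unsupported combinations. A minimal sketch of that behaviour, assuming the package from this PR is installed:

```python
from peft import GLoraConfig

cfg = GLoraConfig(r=8, config_A_B="LoRA", config_C="vector", config_D_E="constant")
print(cfg.config_A_B)  # "LoRA_8" -- the LoRA path is rewritten to include the rank
print(cfg.config_C)    # "vector"

# The D/E path does not accept LoRA, so this raises a ValueError.
try:
    GLoraConfig(config_D_E="LoRA")
except ValueError as err:
    print(err)
```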