Commits (25)
bc16e34
- initial commit for hira adapter
hqsiswiliam Jun 8, 2025
3c27937
- This initial modification of HiRA's config
hqsiswiliam Jun 8, 2025
aeb3d54
- update HiRA Model
hqsiswiliam Jun 10, 2025
d290008
- update HiRA Layer
hqsiswiliam Jun 30, 2025
dcdbe27
- update HiRA Layer partially
hqsiswiliam Jun 30, 2025
8f48e2c
- update HiRA Layer partially (Embedding Layer)
hqsiswiliam Jul 2, 2025
86e5195
- update HiRA Layer partially (ConvNd Layer)
hqsiswiliam Jul 2, 2025
da12aab
- update HiRA Layer partially (ConvNd Layer)
hqsiswiliam Jul 4, 2025
69ace05
- update HiRA Layer partially (Conv1/2/3d Layer)
hqsiswiliam Jul 4, 2025
2c53c8d
- update HiRA Layer partially (MultiheadAttention)
hqsiswiliam Jul 4, 2025
32f6a4d
- remove HiRA Layer partially (MultiheadAttention)
hqsiswiliam Jul 4, 2025
f86c9a9
- update HiRA `layer`, `model`, and `config`
hqsiswiliam Jul 4, 2025
54c8de7
- add bnb implementation and __init__.py
hqsiswiliam Jul 4, 2025
ef18d9f
- add HiRA's Linear8bitLt implementation
hqsiswiliam Jul 4, 2025
7c4718b
- update HiRA's layer comment
hqsiswiliam Jul 4, 2025
8506413
- add HiRA's Linear4bit
hqsiswiliam Jul 4, 2025
9e8c017
- complete HiRA's Linear4bit
hqsiswiliam Jul 4, 2025
71907b4
- add test_hira
hqsiswiliam Jul 4, 2025
ce782b6
- HiRA: updates to peft init, tuners, types, and GPU tests
hqsiswiliam Jul 4, 2025
d20332e
Merge remote-tracking branch 'upstream/main'
hqsiswiliam Jul 4, 2025
d76e328
- HiRA: updates to HiRA layer, and HiRA testing
hqsiswiliam Jul 4, 2025
e933f2a
- HiRA: formatting hira
hqsiswiliam Jul 5, 2025
0a4b3aa
- HiRA: formatting hira
hqsiswiliam Jul 5, 2025
6b4092a
- HiRA: add document
hqsiswiliam Jul 5, 2025
aab9204
- apply merge
hqsiswiliam Jul 24, 2025
90 changes: 90 additions & 0 deletions docs/source/package_reference/hira.md
@@ -0,0 +1,90 @@
# HiRA

Hadamard High-Rank Adaptation ([HiRA](https://openreview.net/pdf?id=TwJrTz9cRS)) is a PEFT method that extends LoRA by applying an element-wise (Hadamard) modulation to the original weight matrix. Instead of adding a low-rank update directly, HiRA computes:

$$
W' = W_0 + W_0 \odot (B A)
$$

where $W_0$ is the frozen base weight and $A, B$ are low-rank factors with rank $r \ll \min(\text{in\_features}, \text{out\_features})$. Although $BA$ itself is low-rank, its Hadamard product with $W_0$ yields a high-rank update, so HiRA adapts the existing weights with a multiplicative modulation that is more expressive than a purely additive low-rank update, often improving fine-tuning quality on downstream tasks.
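
As a minimal sketch of this update rule (a hypothetical `hira_forward` helper, not the PEFT implementation):

```python
import torch

def hira_forward(x, W0, A, B):
    # W0: frozen base weight, shape (out_features, in_features)
    # A: (r, in_features), B: (out_features, r) trainable low-rank factors
    delta = W0 * (B @ A)       # Hadamard modulation: W0 ⊙ (B A)
    return x @ (W0 + delta).T  # effective weight W' = W0 + W0 ⊙ (B A)

W0 = torch.randn(16, 32)
B = torch.zeros(16, 4)  # zero-init B so training starts from the base model
A = torch.randn(4, 32)
x = torch.randn(2, 32)
assert torch.allclose(hira_forward(x, W0, A, B), x @ W0.T)
```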

The abstract from the HiRA paper is:

> *We propose Hadamard High-Rank Adaptation (HiRA), a parameter-efficient fine-tuning (PEFT) method that enhances the adaptability of Large Language Models (LLMs). While Low-rank Adaptation (LoRA) is widely used to reduce resource demands, its low-rank updates may limit its expressiveness for new tasks. HiRA addresses this by using a Hadamard product to retain high-rank update parameters, improving the model capacity. Empirically, HiRA outperforms LoRA and its variants on several tasks, with extensive ablation studies validating its effectiveness.*


## Examples

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model
from peft.tuners.hira import HiRAConfig

# Example 1: HiRA on opt-125m for causal language modeling
model_id = "facebook/opt-125m"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Define HiRA configuration: apply to the attention projections and MLP dense layers in each transformer block
hira_config = HiRAConfig(
    r=32,
    target_modules=["k_proj", "q_proj", "v_proj", "fc1", "fc2"],
    hira_dropout=0.0,
    init_hira_weights=True,
)
peft_model = get_peft_model(base_model, hira_config)

peft_model.print_trainable_parameters()
# trainable params: 4,718,592 || all params: 129,957,888 || trainable%: 3.6309
```
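
After training, only the adapter weights need to be saved; they can be reloaded onto a fresh base model with the generic PEFT API (a minimal sketch; the directory name is illustrative):

```python
# Save only the HiRA adapter weights
peft_model.save_pretrained("opt-125m-hira")

# Later: reload the adapter onto a freshly loaded base model
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = PeftModel.from_pretrained(base_model, "opt-125m-hira")
```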

## HiRAConfig

[[autodoc]] tuners.hira.config.HiRAConfig

## Core Layers

### HiRALayer

[[autodoc]] tuners.hira.layer.HiRALayer

### Linear Adapter

[[autodoc]] tuners.hira.layer.Linear

### Embedding Adapter

[[autodoc]] tuners.hira.layer.Embedding

### Convolutional Adapters

[[autodoc]] tuners.hira.layer.Conv1d

[[autodoc]] tuners.hira.layer.Conv2d

[[autodoc]] tuners.hira.layer.ConvNd

## BitsAndBytes Integration

### 8-bit Quantized Linear

[[autodoc]] tuners.hira.bnb.Linear8bitLt

### 4-bit Quantized Linear

[[autodoc]] tuners.hira.bnb.Linear4bit

### Dispatch Utilities

[[autodoc]] tuners.hira.bnb.dispatch_bnb_8bit

[[autodoc]] tuners.hira.bnb.dispatch_bnb_4bit
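
Like other PEFT adapters, HiRA can be applied on top of a bitsandbytes-quantized base model, in which case the dispatchers above swap in the quantized HiRA layers. A minimal 4-bit sketch, reusing the model and parameters from the example above:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model
from peft.tuners.hira import HiRAConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config)

hira_config = HiRAConfig(r=32, target_modules=["q_proj", "v_proj"], init_hira_weights=True)
peft_model = get_peft_model(base_model, hira_config)
peft_model.print_trainable_parameters()
```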

## Dispatch Handler

Default layer replacement for HiRA adapters:

[[autodoc]] tuners.hira.dispatch.dispatch_default


## Citation
If you find HiRA useful, please cite it as:
```
@inproceedings{huang2025hira,
  title={Hi{RA}: Parameter-Efficient Hadamard High-Rank Adaptation for Large Language Models},
  author={Qiushi Huang and Tom Ko and Zhan Zhuang and Lilian Tang and Yu Zhang},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=TwJrTz9cRS}
}
```
16 changes: 8 additions & 8 deletions method_comparison/sanitizer.py
@@ -26,14 +26,14 @@ def _evaluate_node(df, node):
raise ValueError("Right side of comparison must be a literal (number, string, list).")

operator_map = {
ast.Gt: lambda c, v: df[c] > v,
ast.GtE: lambda c, v: df[c] >= v,
ast.Lt: lambda c, v: df[c] < v,
ast.LtE: lambda c, v: df[c] <= v,
ast.Eq: lambda c, v: df[c] == v,
ast.Gt: lambda c, v: df[c] > v,
ast.GtE: lambda c, v: df[c] >= v,
ast.Lt: lambda c, v: df[c] < v,
ast.LtE: lambda c, v: df[c] <= v,
ast.Eq: lambda c, v: df[c] == v,
ast.NotEq: lambda c, v: df[c] != v,
ast.In: lambda c, v: df[c].isin(v),
ast.NotIn: lambda c, v: ~df[c].isin(v)
ast.In: lambda c, v: df[c].isin(v),
ast.NotIn: lambda c, v: ~df[c].isin(v),
}
op_type = type(op_node)
if op_type not in operator_map:
@@ -90,7 +90,7 @@ def parse_and_filter(df, filter_str):

    try:
        # 'eval' mode ensures the source is a single expression.
        tree = ast.parse(filter_str, mode="eval")
        expression_node = tree.body
    except (SyntaxError, ValueError) as e:
        raise ValueError(f"Invalid filter syntax: {e}")
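For context, `parse_and_filter` evaluates a restricted filter expression against a DataFrame and returns a boolean mask, as the tests below exercise. A usage sketch (the DataFrame and the import path are illustrative assumptions):

```python
import pandas as pd

from sanitizer import parse_and_filter  # import path is an assumption

df = pd.DataFrame({"price": [19.99, 799.99], "category": ["Books", "Electronics"]})
mask = parse_and_filter(df, "price < 50 and category == 'Books'")
print(df[mask])  # rows where both conditions hold
```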
38 changes: 20 additions & 18 deletions method_comparison/test_sanitizer.py
@@ -7,32 +7,34 @@
@pytest.fixture
def df_products():
    data = {
        "product_id": [101, 102, 103, 104, 105, 106],
        "category": ["Electronics", "Books", "Electronics", "Home Goods", "Books", "Electronics"],
        "price": [799.99, 19.99, 49.50, 120.00, 24.99, 150.00],
        "stock": [15, 300, 50, 25, 150, 0],
    }
    return pd.DataFrame(data)


def test_exploit_fails(df_products):
    with pytest.raises(ValueError) as e:
        mask1 = parse_and_filter(df_products, """price < 50 and @os.system("/bin/echo password")""")
    assert "Invalid filter syntax" in str(e)


@pytest.mark.parametrize(
    "expression,ids",
    [
        ("price < 50", [102, 103, 105]),
        ("product_id in [101, 102]", [101, 102]),
        ("price < 50 and category == 'Electronics'", [103]),
        ("stock < 100 or category == 'Home Goods'", [101, 103, 104, 106]),
        ("(price > 100 and stock < 20) or category == 'Books'", [101, 102, 105, 106]),
        ("not (price > 50 or stock > 100)", [103]),
        ("not price > 50", [102, 103, 105]),
        ("(price < 50) & (category == 'Electronics')", [103]),
        ("(stock < 100) | (category == 'Home Goods')", [101, 103, 104, 106]),
    ],
)
def test_operations(df_products, expression, ids):
    mask1 = parse_and_filter(df_products, expression)
    assert sorted(df_products[mask1].product_id) == sorted(ids)
6 changes: 6 additions & 0 deletions src/peft/__init__.py
@@ -61,6 +61,9 @@
    EvaConfig,
    FourierFTConfig,
    FourierFTModel,
    HiRAConfig,
    HiRAModel,
    HiRARuntimeConfig,
    HRAConfig,
    HRAModel,
    IA3Config,
@@ -149,6 +152,9 @@
"FourierFTModel",
"HRAConfig",
"HRAModel",
"HiRAConfig",
"HiRAModel",
"HiRARuntimeConfig",
"IA3Config",
"IA3Model",
"LNTuningConfig",
8 changes: 8 additions & 0 deletions src/peft/tuners/__init__.py
@@ -19,6 +19,11 @@
from .c3a import C3AConfig, C3AModel
from .cpt import CPTConfig, CPTEmbedding
from .fourierft import FourierFTConfig, FourierFTModel
from .hira import (
    HiRAConfig,
    HiRAModel,
    HiRARuntimeConfig,
)
from .hra import HRAConfig, HRAModel
from .ia3 import IA3Config, IA3Model
from .ln_tuning import LNTuningConfig, LNTuningModel
@@ -66,6 +71,9 @@
"FourierFTModel",
"HRAConfig",
"HRAModel",
"HiRAConfig",
"HiRAModel",
"HiRARuntimeConfig",
"IA3Config",
"IA3Model",
"LNTuningConfig",
55 changes: 55 additions & 0 deletions src/peft/tuners/hira/__init__.py
@@ -0,0 +1,55 @@
# Copyright 2023-present the HuggingFace Inc. team.
[Review comment (Member)]
Suggested change:
- # Copyright 2023-present the HuggingFace Inc. team.
+ # Copyright 2025-present the HuggingFace Inc. team.

Let's update every date to 2025.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.utils import register_peft_method

from .config import HiRAConfig, HiRARuntimeConfig
from .layer import Conv2d, Conv3d, Embedding, HiRALayer, Linear
from .model import HiRAModel


__all__ = [
    "Conv2d",
    "Conv3d",
    "Embedding",
    "HiRAConfig",
    "HiRALayer",
    "HiRAModel",
    "HiRARuntimeConfig",
    "Linear",
]

register_peft_method(name="hira", config_cls=HiRAConfig, model_cls=HiRAModel, is_mixed_compatible=True)


def __getattr__(name):
    if (name == "Linear8bitLt") and is_bnb_available():
        from .bnb import Linear8bitLt

        return Linear8bitLt

    if (name == "Linear4bit") and is_bnb_4bit_available():
        from .bnb import Linear4bit

        return Linear4bit

    raise AttributeError(f"module {__name__} has no attribute {name}")
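
For reference, this module-level `__getattr__` follows PEP 562: the quantized classes resolve lazily and only exist when bitsandbytes is installed, e.g.:

```python
# Succeeds only when bitsandbytes is available; otherwise AttributeError is raised.
from peft.tuners.hira import Linear8bitLt
```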