19 commits
d1569d2
Add orthogonal subspace learning via SVD
NikhilNayak-debug Jul 15, 2025
0a30d28
Add wrapper for SVD models
NikhilNayak-debug Jul 15, 2025
af65172
docs: add adaptive SVD utilities
NikhilNayak-debug Jul 15, 2025
bc3bb88
changes for unified OSF method implementation in PEFT and documentati…
NikhilNayak-debug Jul 23, 2025
0e40089
naming changes and compatibility
NikhilNayak-debug Jul 28, 2025
8f77e57
unifying implementation with other PEFT methods and added gradient hooks
NikhilNayak-debug Jul 31, 2025
49774d8
adding test cases for various functionality of the method
NikhilNayak-debug Aug 6, 2025
22df96d
adding test for loading and saving OSFT model
NikhilNayak-debug Aug 6, 2025
aacbef2
adding more test cases for OSF method
NikhilNayak-debug Aug 13, 2025
fdb6d73
moving OSF util methods to the appropriate directory
NikhilNayak-debug Aug 13, 2025
7533cdf
removed redundant check while generating osf config
NikhilNayak-debug Aug 13, 2025
5b87c7d
handle async calls in OSF gradient projection
NikhilNayak-debug Aug 13, 2025
1dd5c68
removed unnecessary DTensor distinction
NikhilNayak-debug Aug 13, 2025
3cb88fb
simplifying gradient hook method
NikhilNayak-debug Aug 13, 2025
a0e445e
fix: implement proper gradient hook management for OSF tuner
NikhilNayak-debug Aug 13, 2025
d28b9d7
adding model-specific constants for OSF target modules
NikhilNayak-debug Aug 13, 2025
2eba741
refactor: implement minimal PEFT integration for OSF tuner
NikhilNayak-debug Aug 14, 2025
2402678
documentation updates
NikhilNayak-debug Aug 19, 2025
845479e
OSF refactor + docs/tests cleanup
NikhilNayak-debug Sep 10, 2025
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -137,6 +137,8 @@
      title: Model merge
    - local: package_reference/helpers
      title: Helpers
    - local: package_reference/osf_utils
      title: OSF utilities
    - local: package_reference/hotswap
      title: Hotswapping adapters
    title: Utilities
236 changes: 236 additions & 0 deletions docs/source/package_reference/osf.md
@@ -0,0 +1,236 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# OSF (Orthogonal Subspace Fine-tuning)

Orthogonal Subspace Fine-tuning ([OSF](https://huggingface.co/papers/2504.07097)) is a PEFT method designed for continual learning that constrains parameter updates to be orthogonal to directions that were important for previously learned tasks. This enables full fine-tuning on new tasks while preventing catastrophic forgetting, without adding parameters or storing past gradients.

The abstract from the paper is:

*Continual learning in large language models (LLMs) is prone to catastrophic forgetting, where adapting to new tasks significantly degrades performance on previously learned ones. Existing methods typically rely on low-rank, parameter-efficient updates that limit the model's expressivity and introduce additional parameters per task, leading to scalability issues. To address these limitations, we propose a novel continual full fine-tuning approach leveraging adaptive singular value decomposition (SVD). Our method dynamically identifies task-specific low-rank parameter subspaces and constrains updates to be orthogonal to critical directions associated with prior tasks, thus effectively minimizing interference without additional parameter overhead or storing previous task gradients. We evaluate our approach extensively on standard continual learning benchmarks using both encoder-decoder (T5-Large) and decoder-only (LLaMA-2 7B) models, spanning diverse tasks including classification, generation, and reasoning. Empirically, our method achieves state-of-the-art results, up to 7% higher average accuracy than recent baselines like O-LoRA, and notably maintains the model's general linguistic capabilities, instruction-following accuracy, and safety throughout the continual learning process by reducing forgetting to near-negligible levels. Our adaptive SVD framework effectively balances model plasticity and knowledge retention, providing a practical, theoretically grounded, and computationally scalable solution for continual learning scenarios in large language models.*

## How OSF Works

OSF decomposes each weight matrix into high-rank (frozen) and low-rank (trainable) components using SVD:

```
W = U_high * S_high * V_high^T + U_low * S_low * V_low^T
```

Where:
- `U_high, S_high, V_high`: Preserve important directions from previous tasks (frozen)
- `U_low, S_low, V_low`: Allow adaptation to new tasks (trainable)

During training, gradients are projected to be orthogonal to the high-rank subspace, ensuring updates don't interfere with previously learned knowledge.
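
The following is a minimal, self-contained sketch of this idea in plain PyTorch. It is not the PEFT implementation: the matrix `W`, the split point `k` (playing the role of `effective_rank`), and the gradient `G` are made up for illustration.

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 32)
k = 16  # number of preserved (frozen) singular directions

# SVD split: high-rank part is preserved, low-rank part is free to adapt
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
U_high, U_low = U[:, :k], U[:, k:]
S_high, S_low = S[:k], S[k:]
V_high, V_low = Vh[:k, :].T, Vh[k:, :].T

# The original weight is exactly the sum of both components
W_rebuilt = U_high @ torch.diag(S_high) @ V_high.T + U_low @ torch.diag(S_low) @ V_low.T
assert torch.allclose(W, W_rebuilt, atol=1e-4)

# Project a gradient so it cannot move the weight inside the preserved subspace:
# remove the components lying in span(U_high) on the left and span(V_high) on the right.
G = torch.randn_like(W)  # stand-in for a gradient w.r.t. W
G_proj = G - U_high @ (U_high.T @ G)            # orthogonal to preserved left singular vectors
G_proj = G_proj - (G_proj @ V_high) @ V_high.T  # orthogonal to preserved right singular vectors

# Sanity check: the projected update has (numerically) no overlap with the frozen directions
print(torch.norm(U_high.T @ G_proj).item(), torch.norm(G_proj @ V_high).item())
```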

## Basic Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import OSFConfig, get_peft_model

# Load base model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Configure OSF
config = OSFConfig(
    target_modules=["c_attn", "c_proj"],  # Target attention layers
    effective_rank=8,  # Default rank for decomposition
    rank_pattern={"c_attn": 16},  # Override rank for specific modules
)

# Apply OSF
model = get_peft_model(model, config)

# Train as usual
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("Hello world", return_tensors="pt", padding=True)
loss = model(**inputs, labels=inputs.input_ids).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

## Configuration Options

### Target Modules

You can specify target modules in several ways:

```python
# Specific module names
config = OSFConfig(target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])

# All linear layers
config = OSFConfig(target_modules="all-linear")

# Model-specific defaults (automatically detected)
config = OSFConfig() # Uses model-appropriate defaults
```

### Effective Rank Configuration

Control the decomposition rank:

```python
# Global rank (applies to all target modules)
config = OSFConfig(effective_rank=16)

# Automatic rank (50% of the smaller matrix dimension per target)
config = OSFConfig(effective_rank=None)

# Per-module rank overrides
config = OSFConfig(
    effective_rank=8,
    rank_pattern={
        "q_proj": 16,   # Higher rank for query projection
        "gate_proj": 4,  # Lower rank for gate projection
    },
)
```

## Training Advice for Continual Learning

### Sequential Task Learning

OSF is specifically designed for learning tasks sequentially. Between tasks, recompute the SVD so the preserved subspace reflects the latest weights. One simple way is to re-wrap the updated base model with OSF again:

```python
# Task 1: train on domain A with initial preserved subspace
r = 8 # initial effective rank to preserve
model = get_peft_model(base_model, OSFConfig(effective_rank=r))
train_task(model, task_1_data)

# Task 2: recompute SVD on updated weights and increase preserved subspace
base_model = model.base_model.model # unwrap updated base
r += 4 # grow preserved subspace to include Task 1 knowledge
model = get_peft_model(base_model, OSFConfig(effective_rank=r))
train_task(model, task_2_data)

# Task 3: recompute again and expand preserved subspace further
base_model = model.base_model.model
r += 4
model = get_peft_model(base_model, OSFConfig(effective_rank=r))
train_task(model, task_3_data)
```
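
The `train_task` helper in these snippets is not part of PEFT; it stands in for whatever training loop you already use. A minimal sketch, assuming `task_data` yields tokenized batches (dicts of tensors with `input_ids` and `attention_mask`), could look like this:

```python
def train_task(model, task_data, lr=1e-4, epochs=1):
    # Hypothetical helper: a plain causal-LM training loop over one task's batches.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for batch in task_data:  # assumed: dict of tensors per batch
            loss = model(**batch, labels=batch["input_ids"]).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
```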

### Budget Allocation for Task Sequences

When training on a known sequence of n tasks, one effective strategy is to progressively allocate model capacity to balance learning new tasks while preserving previous knowledge:

- **Task 1**: Use full capacity (train everything)
- **Task 2**: Freeze 1/n of model capacity, train remaining (n-1)/n capacity
- **Task 3**: Freeze 2/n of model capacity, train remaining (n-2)/n capacity
- **Task n**: Freeze (n-1)/n of model capacity, use 1/n capacity for final task

This approach ensures each task gets adequate learning capacity while progressively preserving more knowledge from previous tasks.

```python
# Example: 4-task sequence with progressive budget allocation
n_tasks = 4
base_rank = 32 # Starting rank for full capacity

for task_id in range(n_tasks):
    # Calculate remaining capacity for current task
    freeze_fraction = task_id / n_tasks
    remaining_capacity = 1.0 - freeze_fraction
    current_rank = int(base_rank * remaining_capacity)

    config = OSFConfig(
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        effective_rank=current_rank,
    )

    print(f"Task {task_id + 1}: Using rank {current_rank} "
          f"({remaining_capacity:.1%} of full capacity)")

    # Train on current task
    model = get_peft_model(base_model, config)
    train_task(model, task_data[task_id])
```

### Best Practices

1. **Effective Rank Selection**: Start with `effective_rank=None` (the rank then defaults to 50% of the smaller weight dimension per target module; see the sketch after this list) and adjust based on task complexity
2. **Learning Rate**: Use smaller learning rates (1e-5 to 1e-4) compared to standard fine-tuning
3. **Task Importance**: Use `rank_pattern` to allocate more capacity to critical modules
4. **Model Architecture**: OSF works best with transformer architectures having clear attention and MLP separations
5. **Capacity Planning**: For known task sequences, use progressive budget allocation (1/n, 2/n, ..., (n-1)/n freezing) to balance plasticity and stability
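
To get a feel for what `effective_rank=None` would pick per module, the loop below mirrors the documented default (50% of the smaller weight dimension) without calling into PEFT internals; the module-name suffixes are just an example:

```python
# Rough preview of the auto rank per target module, assuming the documented
# 50%-of-min(weight.shape) default (illustration only, not a PEFT API).
target_suffixes = ("q_proj", "k_proj", "v_proj", "o_proj")
for name, module in base_model.named_modules():
    if name.endswith(target_suffixes) and hasattr(module, "weight"):
        auto_rank = min(module.weight.shape) // 2
        print(f"{name}: weight {tuple(module.weight.shape)} -> auto rank {auto_rank}")
```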

### Memory Considerations

OSF modifies weights in-place and doesn't add parameters, making it memory-efficient:

```python
# Memory usage remains close to base model
print(f"Base model parameters: {base_model.num_parameters():,}")
print(f"OSF model parameters: {osf_model.num_parameters():,}") # Similar count
```
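
To check how much of the wrapped model is actually trainable, the standard PEFT helper works as usual:

```python
osf_model.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...
```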

## Advanced Usage

### Custom Target Modules

For models with non-standard architectures:

```python
config = OSFConfig(
    target_modules=["dense", "intermediate.dense"],  # Custom layer names
    effective_rank=12,
    rank_pattern={"dense": 8, "intermediate.dense": 16},
)
```

### Integration with Other Methods

OSF can be combined with other techniques:

```python
# Use with gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()

# Apply weight decay selectively (regularizes low-rank factors to limit drift/overfitting in continual updates; keep small)
optimizer = torch.optim.AdamW([
    {"params": [p for n, p in model.named_parameters() if "U_low" in n], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if "S_low" in n], "weight_decay": 0.001},
    {"params": [p for n, p in model.named_parameters() if "V_low" in n], "weight_decay": 0.01},
], lr=1e-4)
```

## OSFConfig

[[autodoc]] tuners.osf.config.OSFConfig

## OSFModel

[[autodoc]] tuners.osf.model.OSFModel

## Utility Functions

### Weight Decomposition

[[autodoc]] tuners.osf.utils.decompose_weight_matrix

[[autodoc]] tuners.osf.utils.reconstruct_weight_matrix

### Gradient Projection

[[autodoc]] tuners.osf.utils.project_gradient_to_orthogonal_space
37 changes: 37 additions & 0 deletions examples/orthogonal_subspace_learning/README.md
@@ -0,0 +1,37 @@
# Orthogonal Subspace Learning with Adaptive OSF

## TODO: Runnable Example Needed

This folder is a placeholder for a comprehensive OSF example. As suggested in the review feedback:

> "If you can, provide a runnable example in this folder instead, you can take a look at the EVA example for inspiration. A runnable example can be a good place to showcase the different features. Jupyter notebooks are fine as well."

### Planned Example Features:
- Complete continual learning scenario with multiple tasks
- Demonstration of OSF's catastrophic forgetting prevention
- Configuration examples (target_modules, effective_rank, rank_pattern)
- Performance comparison with baseline methods
  > **Review comment (Collaborator):** I think the performance comparison with baseline methods - at least for single tasks - is best done in the PEFT method comparison (MetaMathQA). Of course, feel free to provide a comparison with methods that support multi-task learning if it fits into the example without too much effort.

- Memory usage analysis

### Current Basic Usage:
For basic usage examples and API documentation, see the [OSF documentation](../../docs/source/package_reference/osf.md).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import OSFConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("gpt2")
config = OSFConfig(target_modules=["c_attn", "c_proj"], effective_rank=8)
model = get_peft_model(model, config)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer("Hello world", return_tensors="pt", padding=True)
loss = model(**inputs, labels=inputs.input_ids).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
4 changes: 4 additions & 0 deletions src/peft/__init__.py
@@ -95,6 +95,8 @@
ShiraModel,
TrainableTokensConfig,
TrainableTokensModel,
OSFConfig,
OSFModel,
VBLoRAConfig,
VBLoRAModel,
VeraConfig,
@@ -193,6 +195,8 @@
"TaskType",
"TrainableTokensConfig",
"TrainableTokensModel",
"OSFConfig",
"OSFModel",
"VBLoRAConfig",
"VBLoRAConfig",
"VBLoRAModel",
3 changes: 3 additions & 0 deletions src/peft/tuners/__init__.py
@@ -36,6 +36,7 @@
from .mixed import MixedModel
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from .oft import OFTConfig, OFTModel
from .osf import OSFConfig, OSFModel
from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
from .poly import PolyConfig, PolyModel
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
@@ -100,6 +101,8 @@
"ShiraModel",
"TrainableTokensConfig",
"TrainableTokensModel",
"OSFConfig",
"OSFModel",
"VBLoRAConfig",
"VBLoRAModel",
"VeraConfig",
14 changes: 14 additions & 0 deletions src/peft/tuners/osf/__init__.py
@@ -0,0 +1,14 @@
from peft.utils import register_peft_method

from .config import OSFConfig
from .layer import OSFLayer, Linear
from .model import OSFModel

__all__ = ["OSFConfig", "OSFModel", "OSFLayer", "Linear"]

register_peft_method(
    name="osf",
    config_cls=OSFConfig,
    model_cls=OSFModel,
    is_mixed_compatible=False,
)
38 changes: 38 additions & 0 deletions src/peft/tuners/osf/config.py
@@ -0,0 +1,38 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Union

from peft.config import PeftConfig
from peft.utils import PeftType


@dataclass
class OSFConfig(PeftConfig):
    """
    Configuration for Orthogonal Subspace Fine-tuning (OSF).

    Args:
        effective_rank (`int`, *optional*):
            The effective rank for OSF decomposition. If None, defaults to 50% of min(weight.shape).
        target_modules (`Union[list[str], str]`, *optional*):
            The names of the modules to apply OSF to. Can be a list of module names or 'all-linear'.
        rank_pattern (`dict[str, int]`, *optional*):
            A dictionary of regex patterns to override effective_rank for specific modules.
    """

    effective_rank: Optional[int] = field(
        default=None,
        metadata={"help": "The effective rank for OSF decomposition. If None, defaults to 50% of min(weight.shape)."},
    )
    target_modules: Optional[Union[list[str], str]] = field(
        default=None,
        metadata={"help": "The names of the modules to apply OSF to. Can be a list of module names or 'all-linear'."},
    )
    rank_pattern: Optional[dict[str, int]] = field(
        default=None,
        metadata={"help": "A dictionary of regex patterns to override effective_rank for specific modules."},
    )

    def __post_init__(self):
        self.peft_type = PeftType.OSF