diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index 01d6c869ff..a47e7837c1 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -200,17 +200,37 @@ from peft import PeftModel
 model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
 ```
 
+#### Optimization
+
 DoRA is optimized (computes faster and takes less memory) for models in the evaluation mode, or when dropout is set to 0. We reuse the
 base result at those times to get the speedup.
 Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py)
-with `CUDA_VISIBLE_DEVICES=0 ZE_AFFINITY_MASK=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
-on a 4090 with gradient accumulation set to 2 and max step to 20 resulted with the following observations:
+with `CUDA_VISIBLE_DEVICES=0 ZE_AFFINITY_MASK=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora` on a 4090 with gradient accumulation set to 2 and max steps set to 20 resulted in the following observations:
 
 |  | Without Optimization | With Optimization |
 | :--: | :--: | :--: |
-| train_runtime | 359.7298 | **279.2676** |
-| train_samples_per_second | 1.779 | **2.292** |
-| train_steps_per_second | 0.056 | **0.072** |
+| train runtime (sec) | 359.7298 | **279.2676** |
+| train samples per second | 1.779 | **2.292** |
+| train steps per second | 0.056 | **0.072** |
+
+Moreover, it is possible to further increase the runtime performance of DoRA at inference time by using the [`DoraCaching`] helper context manager. This requires the model to be in `eval` mode:
+
+```py
+from peft.helpers import DoraCaching
+
+model.eval()
+with DoraCaching():
+    output = model(inputs)
+```
+
+For [`meta-llama/Llama-3.1-8B`](https://huggingface.co/meta-llama/Llama-3.1-8B), the [DoRA caching benchmark script](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora-caching.py) shows that, compared to LoRA:
+
+- DoRA without caching requires 139% more time
+- DoRA without caching requires 4% more memory
+- DoRA with caching requires 17% more time
+- DoRA with caching requires 41% more memory
+
+Caching can thus make inference with DoRA significantly faster, but it also requires significantly more memory. Ideally, if the use case allows it, just merge the DoRA adapter into the base model (e.g. with `merge_and_unload()`) to avoid both the memory and the runtime overhead.
 
 #### Caveats
 
diff --git a/docs/source/package_reference/helpers.md b/docs/source/package_reference/helpers.md
index 83e129d6ea..5cc83e5444 100644
--- a/docs/source/package_reference/helpers.md
+++ b/docs/source/package_reference/helpers.md
@@ -20,3 +20,8 @@ A collection of helper functions for PEFT.
 
 [[autodoc]] helpers.disable_input_dtype_casting
     - all
+
+## Context manager to enable DoRA caching (faster at inference time but requires more memory)
+
+[[autodoc]] helpers.DoraCaching
+    - all
diff --git a/examples/dora_finetuning/dora-caching.py b/examples/dora_finetuning/dora-caching.py
new file mode 100644
index 0000000000..bdcff097ac
--- /dev/null
+++ b/examples/dora_finetuning/dora-caching.py
@@ -0,0 +1,126 @@
+# Copyright 2025-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Small script to measure DoRA caching efficiency +""" + +import argparse +import time +from contextlib import contextmanager + +import torch +from transformers import AutoModelForCausalLM + +from peft import LoraConfig, get_peft_model +from peft.helpers import DoraCaching +from peft.utils import infer_device + + +device = infer_device() +# check for CPU +if device == "cpu": + raise ValueError("This benchmark requires a hardware accelerator, only found CPU") +torch_accelerator_module = getattr(torch, device, torch.cuda) + + +@contextmanager +def timeit(logs): + start = time.perf_counter() + yield + end = time.perf_counter() + dur = end - start + logs["time"].append(dur) + + +def run_benchmark(model, num_runs): + logs = { + "time": [], + } + + mem_start = torch_accelerator_module.max_memory_reserved() + for _ in range(num_runs + 1): + with timeit(logs): + for i in range(3): + x = torch.randint(10, 100, (1, 50)).to(device) + model(x) + mem_end = torch_accelerator_module.max_memory_reserved() + logs["memory"] = (mem_end - mem_start) / 1024**2 + + # remove the first run (warm up) + del logs["time"][0] + return logs + + +def main(model_id, num_runs): + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device) + base_memory = torch_accelerator_module.max_memory_reserved() / 1024**2 + + # LORA + config = LoraConfig(init_lora_weights=False, use_dora=False) + model = get_peft_model(model, config) + model.eval() + torch_accelerator_module.reset_peak_memory_stats() + logs_lora = run_benchmark(model, num_runs) + avg_duration_lora = sum(logs_lora["time"]) / num_runs + max_memory_lora = logs_lora["memory"] + base_memory + + # DORA + del model + torch_accelerator_module.empty_cache() + + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device) + base_memory = torch_accelerator_module.max_memory_reserved() / 1024**2 + config = LoraConfig(init_lora_weights=False, use_dora=True) + model = get_peft_model(model, config) + model.eval() + + # WITHOUT CACHING + torch_accelerator_module.reset_peak_memory_stats() + logs_dora_no_caching = run_benchmark(model, num_runs) + avg_duration_no_caching = sum(logs_dora_no_caching["time"]) / num_runs + max_memory_no_caching = logs_dora_no_caching["memory"] + base_memory + + # WITH CACHING + torch_accelerator_module.reset_peak_memory_stats() + with DoraCaching(): + logs_dora_caching = run_benchmark(model, num_runs) + avg_duration_caching = sum(logs_dora_caching["time"]) / num_runs + max_memory_caching = logs_dora_caching["memory"] + base_memory + + print( + f"Benchmark results for model {model_id} with {num_runs} runs:\n\n" + f"avg time LoRA: {avg_duration_lora:.4f} sec\n" + f"avg time DoRA no caching: {avg_duration_no_caching:.4f} sec\n" + f"avg time DoRA with caching: {avg_duration_caching:.4f} sec\n" + f"\n" + f"memory LoRA: {max_memory_lora:.2f} MB\n" + f"memory DoRA no caching: {max_memory_no_caching:.2f} MB\n" + f"memory DoRA with caching: {max_memory_caching:.2f} MB\n" + f"\n" + f"DoRA time overhead no caching: {(avg_duration_no_caching - avg_duration_lora) / 
avg_duration_lora * 100:.2f}%\n"
+        f"DoRA time overhead with caching: {(avg_duration_caching - avg_duration_lora) / avg_duration_lora * 100:.2f}%\n"
+        f"\n"
+        f"DoRA memory overhead no caching: {(max_memory_no_caching - max_memory_lora) / max_memory_lora * 100:.2f}%\n"
+        f"DoRA memory overhead with caching: {(max_memory_caching - max_memory_lora) / max_memory_lora * 100:.2f}%"
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Benchmark DoRA caching efficiency")
+    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-3.1-8B", help="Model ID to benchmark")
+    parser.add_argument("--num_runs", type=int, default=10, help="Number of runs for the benchmark")
+    args = parser.parse_args()
+
+    main(args.model_id, args.num_runs)
diff --git a/src/peft/helpers.py b/src/peft/helpers.py
index d748c62e69..05b70d8785 100644
--- a/src/peft/helpers.py
+++ b/src/peft/helpers.py
@@ -21,7 +21,7 @@
 from torch import nn
 
 from .peft_model import PeftConfig, PeftModel
-from .tuners.lora import LoraLayer
+from .tuners.lora import LoraLayer, dora
 from .tuners.tuners_utils import BaseTunerLayer
 
 
@@ -249,3 +249,44 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
             continue
         if name in original_values:
             module.cast_input_dtype_enabled = original_values[name]
+
+
+class DoraCaching:
+    """Context manager to enable DoRA caching, which improves the speed of DoRA inference at the expense of memory.
+
+    Even within the caching context, if the model is in training mode, caching is disabled. When the model switches to
+    training mode, the cache will be cleared.
+
+    Example:
+
+    ```py
+    >>> from peft.helpers import DoraCaching
+
+    >>> model.eval()  # put the model in eval mode for caching to work
+
+    >>> with DoraCaching():  # use as a context manager
+    ...     output = model(inputs)
+
+    >>> dora_caching = DoraCaching()
+    >>> dora_caching(enabled=True)  # permanently enable caching
+    >>> output = model(inputs)
+    >>> dora_caching(enabled=False)  # permanently disable caching
+    >>> output = model(inputs)
+    ```
+
+    """
+
+    def __init__(self, enabled: bool = True) -> None:
+        self.enabled = enabled
+        self.prev_value = None
+
+    def __enter__(self):
+        self.prev_value = dora.ENABLE_DORA_CACHING
+        dora.ENABLE_DORA_CACHING = self.enabled
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        dora.ENABLE_DORA_CACHING = self.prev_value
+        self.prev_value = None
+
+    def __call__(self, enabled: bool = True):
+        dora.ENABLE_DORA_CACHING = enabled
diff --git a/src/peft/tuners/lora/dora.py b/src/peft/tuners/lora/dora.py
index f38a7df125..c0bc1387d2 100644
--- a/src/peft/tuners/lora/dora.py
+++ b/src/peft/tuners/lora/dora.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 from copy import deepcopy
+from functools import wraps
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -22,18 +24,80 @@
 from peft.utils.other import transpose
 
 
+ENABLE_DORA_CACHING = False
+"""Whether to enable DoRA caching, which makes it faster at inference but requires more memory"""
+
+
+def cache_decorator(cache_key: str):
+    """Caching decorator for DoRA
+
+    Caching is only enabled if ENABLE_DORA_CACHING is set to True (default: False), when in eval mode, and when the
+    adapter_name is passed (e.g. not during layer initialization).
+ + """ + + def cache_value(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + # if adapter_name is not passed, no caching + adapter_name = kwargs.get("adapter_name") + if (not ENABLE_DORA_CACHING) or self.training or (adapter_name is None): + self._cache_clear() + return func(self, *args, **kwargs) + + cache_key_adapter = f"{cache_key}-{adapter_name}" + output = self._cache_get(cache_key_adapter, None) + if output is not None: + return output + + output = func(self, *args, **kwargs) + self._cache_store(cache_key_adapter, output) + return output + + return wrapper + + return cache_value + + class DoraLinearLayer(nn.Module): def __init__(self, fan_in_fan_out): super().__init__() self.fan_in_fan_out = fan_in_fan_out + self._dora_cache: dict[str, Any] = {} # small ad hoc cache; values are not part of the state_dict + + def _cache_store(self, key: str, value: Any) -> None: + # cache intermediate values, e.g. weight norm of DoRA + self._dora_cache[key] = value - def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + def _cache_get(self, key: str, default: Optional[Any]) -> Optional[Any]: + # retrieve from ad hoc cache + return self._dora_cache.get(key, default) + + def _cache_clear(self) -> None: + self._dora_cache.clear() + + def train(self, mode: bool = True): + if mode: + self._cache_clear() + super().train(mode=mode) + return self + + @cache_decorator("weight-norm") + def get_weight_norm(self, weight, lora_weight, scaling, adapter_name: Optional[str] = None) -> torch.Tensor: # calculate L2 norm of weight matrix, column-wise weight = transpose(weight, self.fan_in_fan_out) weight = weight + scaling * lora_weight weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) return weight_norm + @cache_decorator("lora-weight") + def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None): + # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead, + # calculate the same but using forward. + x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=lora_A.weight.dtype) + lora_weight = lora_B(lora_A(x_eye)).T + return lora_weight + def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None: # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2 dtype_is_fp16 = lora_A.dtype == torch.float16 @@ -57,26 +121,28 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=Fals if dtype_is_fp16: lora_weight = lora_weight.half() - weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling) + weight_norm = self.get_weight_norm( + weight=weight.to(lora_A.device), lora_weight=lora_weight, scaling=scaling + ) if place_on_cpu: weight_norm = weight_norm.to("cpu") self.weight = nn.Parameter(weight_norm, requires_grad=True) - def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None): + def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None, adapter_name="default"): """ For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer output. """ - # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead, - # calculate the same but using forward. 
-        x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
-        lora_weight = lora_B(lora_A(x_eye)).T
+        lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name)
+        lora_weight = lora_weight.to(x.dtype)
         magnitude = self.weight
         weight = dequantize_module_weight(base_layer)
         weight = weight.to(x.dtype)
-        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
+        weight_norm = self.get_weight_norm(
+            weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
+        )
         # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
         # "[...] we suggest treating ||V +∆V ||_c in
         # Eq. (5) as a constant, thereby detaching it from the gradient
@@ -97,7 +163,6 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
             base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
 
         result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
-
         return result_dora
 
     def __repr__(self) -> str:
@@ -106,15 +171,21 @@ def __repr__(self) -> str:
 
 
 class DoraEmbeddingLayer(DoraLinearLayer):
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn):
+    @cache_decorator("lora-weight")
+    def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None):
+        return (lora_A @ lora_B).T
+
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn, adapter_name="default"):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_weight = (lora_A @ lora_B).T
+        lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name)
         magnitude = self.weight
         weight = base_layer.weight
-        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
+        weight_norm = self.get_weight_norm(
+            weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
+        )
         # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
         # "[...] we suggest treating ||V +∆V ||_c in
         # Eq. (5) as a constant, thereby detaching it from the gradient
@@ -132,7 +203,8 @@ def __repr__(self) -> str:
 
 
 class _DoraConvNdLayer(DoraLinearLayer):
-    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
+    @cache_decorator("weight-norm")
+    def get_weight_norm(self, weight, lora_weight, scaling, adapter_name: Optional[str] = None) -> torch.Tensor:
         # calculate L2 norm of weight matrix, column-wise
         weight = weight + scaling * lora_weight
         # the following is needed to have compatibility with the 4/5D weight tensors of Conv2D/3D
@@ -140,17 +212,29 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
         return weight_norm
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
+    @cache_decorator("lora-weight")
+    def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None) -> torch.Tensor:
+        # Flatten the conv weight tensors and compute the equivalent of `lora_B.weight @ lora_A.weight`; the
+        # caller reshapes the result back to the shape of the base weight.
+        r = lora_A.weight.shape[0]
+        lora_weight = torch.mm(lora_B.weight.view([-1, r]), lora_A.weight.view([r, -1]))
+        return lora_weight
+
+    def forward(
+        self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None, adapter_name: str = "default"
+    ) -> torch.Tensor:
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
         weight = base_layer.weight
-        r = lora_A.weight.shape[0]
-        lora_weight = torch.mm(lora_B.weight.view([-1, r]), lora_A.weight.view([r, -1]))
-        lora_weight = lora_weight.reshape(weight.shape)
+        lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name).reshape(
+            weight.shape
+        )
         magnitude = self.weight
-        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
+        weight_norm = self.get_weight_norm(
+            weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
+        )
         # see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
         # "[...] we suggest treating ||V +∆V ||_c in
         # Eq. (5) as a constant, thereby detaching it from the gradient
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 22e13d0801..e807e869ff 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -104,7 +104,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
         self.use_dora: dict[str, bool] = {}  # not actively used anymore after #2443, keep it for BC
         self.lora_bias: dict[str, bool] = {}
         self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
-        self._caches: dict[str, Any] = {}
+        self._caches: dict[str, Any] = {}  # small ad hoc cache; values are not part of the state_dict
         self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
         # flag to enable/disable casting of input to weight dtype during forward call
         self.cast_input_dtype_enabled: bool = True
@@ -479,9 +479,11 @@ def orthogonal_init(self, adapter_name):
         self.lora_B[adapter_name].weight = nn.Parameter(lora_B.contiguous().to(dtype))
 
     def _cache_store(self, key: str, value: Any) -> None:
+        # cache intermediate values, e.g. weight norm of DoRA
        self._caches[key] = value
 
     def _cache_pop(self, key: str) -> Any:
+        # retrieve and remove from ad hoc cache
         value = self._caches.pop(key)
         return value
 
diff --git a/src/peft/tuners/lora/variants.py b/src/peft/tuners/lora/variants.py
index 6b99637390..8f7e95a871 100644
--- a/src/peft/tuners/lora/variants.py
+++ b/src/peft/tuners/lora/variants.py
@@ -126,6 +126,7 @@ def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.
scaling=scaling, base_layer=module.get_base_layer(), base_result=base_result, + adapter_name=active_adapter, ) return result @@ -209,6 +210,7 @@ def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: tor scaling=scaling, base_layer=module.get_base_layer(), embed_fn=module._embed, + adapter_name=active_adapter, ) result = mag_norm_scale * result + dora_result return result @@ -292,6 +294,7 @@ def forward(module: _ConvNd, active_adapter: str, x: torch.Tensor, result: torch scaling=scaling, base_layer=module.get_base_layer(), base_result=base_result, + adapter_name=active_adapter, ) return result diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 501bd146a2..2055a0b4a8 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -20,10 +20,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from peft import LoraConfig, get_peft_model -from peft.helpers import check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale +from peft.helpers import DoraCaching, check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale from peft.tuners.lora.layer import LoraLayer from peft.utils import infer_device +from .testing_utils import hub_online_once + class TestCheckIsPeftModel: def test_valid_hub_model(self): @@ -471,3 +473,127 @@ def test_disable_input_dtype_casting_inactive_after_existing_context(self, model msg = r"expected m.*1 and m.*2 to have the same dtype" with pytest.raises(RuntimeError, match=msg): model(inputs) + + +class TestDoraCaching: + # Check that DoRA caching works (same results with and without caching, cache is filled/cleared). Note that this test + # does not check the actual runtime benefit of caching, because this could be flaky and measuring it reliably and in + # realistic conditions is expensive. Run examples/dora_finetuning/dora-caching.py instead to measure this. 
+    device = infer_device()
+
+    @pytest.fixture(autouse=True)
+    def disable_dora_caching(self):
+        # auto-fixture to ensure that no test accidentally enables DoRA caching permanently
+        DoraCaching()(enabled=False)
+
+    def get_caches(self, model):
+        # utility function to collect all the caches in the model
+        caches = []
+        for module in model.modules():
+            if hasattr(module, "_dora_cache"):
+                caches.append(module._dora_cache)
+        return caches
+
+    def get_output(self, model, inputs):
+        output = model(inputs)
+        if hasattr(output, "logits"):
+            return output.logits
+        return output
+
+    def test_dora_caching_linear(self):
+        # ensure that the results don't change due to caching
+        inputs = torch.arange(10).view(1, -1).to(self.device)
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        config = LoraConfig(init_lora_weights=False, use_dora=True)
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+        self.check_dora_caching(model, config, inputs)
+
+    def test_dora_caching_embedding(self):
+        # ensure that the results don't change due to caching
+        inputs = torch.arange(10).view(1, -1).to(self.device)
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        config = LoraConfig(init_lora_weights=False, use_dora=True, target_modules=["model.embed_tokens"])
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+        self.check_dora_caching(model, config, inputs)
+
+    def test_dora_caching_conv(self):
+        # ensure that the results don't change due to caching
+        # note: don't use something like a small ResNet, because batch norm affects outputs in train mode
+
+        class ModelConv2D(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv0 = nn.Conv2d(3, 5, kernel_size=3, stride=1, padding=1)
+                self.conv1 = nn.Conv2d(5, 5, kernel_size=3, stride=1, padding=1)
+                self.linear = nn.Linear(5 * 3 * 3, 10)
+
+            def forward(self, X):
+                X = self.conv0(X)
+                X = nn.functional.relu(X)
+                X = self.conv1(X)
+                X = nn.functional.relu(X)
+                X = X.view(X.size(0), -1)
+                X = self.linear(X)
+                return X
+
+        inputs = torch.randn(1, 3, 3, 3).to(self.device)
+        config = LoraConfig(init_lora_weights=False, use_dora=True, target_modules=["conv0", "conv1"])
+        model = ModelConv2D().to(self.device)
+        self.check_dora_caching(model, config, inputs)
+
+    def check_dora_caching(self, model, config, inputs):
+        atol, rtol = 1e-6, 1e-6
+
+        # BASE RESULT
+        base_result = self.get_output(model, inputs)
+
+        # DEFAULT: WITHOUT DoRA CACHING
+        model = get_peft_model(model, config)
+        caches = self.get_caches(model)
+        dora_result = self.get_output(model, inputs)
+
+        # sanity check: the results should be different
+        assert not torch.allclose(base_result, dora_result, atol=atol, rtol=rtol)
+        # ensure that there are dora caches but they're all empty
+        assert caches
+        assert not any(cache for cache in caches)
+
+        # ENABLE DORA CACHING
+        model.eval()
+        with DoraCaching():
+            cached_result = self.get_output(model, inputs)
+            # the caches should be populated now
+            assert all(cache for cache in caches)
+            # the results should be the same
+            assert torch.allclose(cached_result, dora_result, atol=atol, rtol=rtol)
+
+        # AFTER EXITING THE CONTEXT
+        cached_result_after_context = self.get_output(model, inputs)
+        assert torch.allclose(cached_result_after_context, dora_result, atol=atol, rtol=rtol)
+        # since we called forward outside of the context, the caches should be cleared
+        assert not any(cache for cache in caches)
+
+        # NO CACHING IN TRAIN MODE
+        model.train()
+        # switching to train mode immediately clears the caches
+        assert not any(cache for cache in caches)
+        with DoraCaching():
+            results_train_mode = self.get_output(model, inputs)
+            # the caches should still be empty
+            assert not any(cache for cache in caches)
+            # results should not change
+            assert torch.allclose(results_train_mode, dora_result, atol=atol, rtol=rtol)
+        # still no caches expected after exiting the context
+        assert not any(cache for cache in caches)
+
+        # PERMANENTLY ENABLE DORA CACHING
+        DoraCaching()(enabled=True)
+        model.eval()
+        # the caches are still empty (they were cleared when switching to train mode and not repopulated since)
+        assert not any(cache for cache in caches)
+        # the results should be the same
+        cached_result_permanent = self.get_output(model, inputs)
+        assert torch.allclose(cached_result_permanent, dora_result, atol=atol, rtol=rtol)
+        DoraCaching()(enabled=False)
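+
+    def test_dora_caching_context_restores_previous_state(self):
+        # Illustrative sketch (an editorial addition, not required by the rest of this patch): it demonstrates the
+        # enter/exit semantics of `DoraCaching` as implemented in `peft.helpers`, namely that the context manager
+        # restores the previous value of the module-level `ENABLE_DORA_CACHING` flag on exit. Only objects
+        # introduced in this diff are used.
+        from peft.tuners.lora import dora
+
+        DoraCaching()(enabled=True)
+        assert dora.ENABLE_DORA_CACHING
+        # entering with enabled=False temporarily disables caching
+        with DoraCaching(enabled=False):
+            assert not dora.ENABLE_DORA_CACHING
+        # on exit, the previously enabled state is restored
+        assert dora.ENABLE_DORA_CACHING
+        DoraCaching()(enabled=False)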