 
 import pytest
 import torch
 
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import (
+    AutoModelForCausalLM,
+    MllamaForConditionalGeneration,
+    PretrainedConfig,
+    PreTrainedModel,
+)
 
 from llmcompressor.utils import (
     ALL_TOKEN,
     DisableQuantization,
     calibration_forward_context,
     convert_to_bool,
+    disable_cache,
     flatten_iterable,
     getattr_chain,
     interpolate,
     patch_attr,
     validate_str_iterable,
 )
+from llmcompressor.utils.dev import skip_weights_download
+from tests.testing_utils import requires_gpu
 
 
 @pytest.mark.unit
@@ -173,3 +181,25 @@ def test_patch_attr():
     assert obj.attribute == "patched"
     obj.attribute = "modified"
     assert not hasattr(obj, "attribute")
+
+
+@requires_gpu
+@pytest.mark.unit
+@pytest.mark.parametrize(
+    "model_cls,model_stub",
+    [
+        (MllamaForConditionalGeneration, "meta-llama/Llama-3.2-11B-Vision-Instruct"),
+        (AutoModelForCausalLM, "nm-testing/llama2.c-stories15M"),
+    ],
+)
+def test_disable_cache(model_cls, model_stub):
+    with skip_weights_download(model_cls):
+        model = model_cls.from_pretrained(model_stub, device_map="cuda")
+    inputs = {key: value.to(model.device) for key, value in model.dummy_inputs.items()}
+
+    with disable_cache(model):
+        output = model(**inputs)
+        assert output.past_key_values is None
+
+    output = model(**inputs)
+    assert output.past_key_values is not None
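
For context, the new test drives `llmcompressor.utils.disable_cache` together with `skip_weights_download`. A minimal sketch of what such a context manager could look like, assuming it simply toggles `use_cache` on the model config for the duration of the block (the actual llmcompressor helper may differ, e.g. in how it handles multimodal configs):

```python
# Hedged sketch only -- not the actual llmcompressor implementation.
# Assumption: disabling the KV cache amounts to forcing `use_cache=False`
# on the model config, so forward() returns past_key_values=None.
from contextlib import contextmanager

from transformers import PreTrainedModel


@contextmanager
def disable_cache_sketch(model: PreTrainedModel):
    # Multimodal wrappers often keep `use_cache` on a nested text config.
    config = getattr(model.config, "text_config", model.config)
    original = getattr(config, "use_cache", True)
    config.use_cache = False
    try:
        yield
    finally:
        # Restore the previous setting even if the forward pass raised.
        config.use_cache = original
```

With a sketch like this, `model(**inputs)` inside the block returns `past_key_values is None`, and a call after the block returns a populated cache again, matching the assertions in `test_disable_cache`.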