Commit 153c4dc

add kv cache disable tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 3fdeee8

File tree

1 file changed (+31, -1 lines)

tests/llmcompressor/utils/test_helpers.py

Lines changed: 31 additions & 1 deletion
@@ -2,19 +2,27 @@
 
 import pytest
 import torch
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import (
+    AutoModelForCausalLM,
+    MllamaForConditionalGeneration,
+    PretrainedConfig,
+    PreTrainedModel,
+)
 
 from llmcompressor.utils import (
     ALL_TOKEN,
     DisableQuantization,
     calibration_forward_context,
     convert_to_bool,
+    disable_cache,
     flatten_iterable,
     getattr_chain,
     interpolate,
     patch_attr,
     validate_str_iterable,
 )
+from llmcompressor.utils.dev import skip_weights_download
+from tests.testing_utils import requires_gpu
 
 
 @pytest.mark.unit
@@ -173,3 +181,25 @@ def test_patch_attr():
     assert obj.attribute == "patched"
     obj.attribute = "modified"
     assert not hasattr(obj, "attribute")
+
+
+@requires_gpu
+@pytest.mark.unit
+@pytest.mark.parametrize(
+    "model_cls,model_stub",
+    [
+        (MllamaForConditionalGeneration, "meta-llama/Llama-3.2-11B-Vision-Instruct"),
+        (AutoModelForCausalLM, "nm-testing/llama2.c-stories15M"),
+    ],
+)
+def test_disable_cache(model_cls, model_stub):
+    with skip_weights_download(model_cls):
+        model = model_cls.from_pretrained(model_stub, device_map="cuda")
+    inputs = {key: value.to(model.device) for key, value in model.dummy_inputs.items()}
+
+    with disable_cache(model):
+        output = model(**inputs)
+        assert output.past_key_values is None
+
+    output = model(**inputs)
+    assert output.past_key_values is not None
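
Note: the new test checks disable_cache purely through its observable effect, asserting that forward passes inside the context return output.past_key_values = None while passes outside it return a populated KV cache. For readers unfamiliar with the helper, below is a minimal sketch of how such a context manager could work. This is an illustration only, not the actual llmcompressor implementation; in particular, resolving a nested text_config (as in Mllama-style multimodal configs) is an assumption of the sketch.

import contextlib

@contextlib.contextmanager
def disable_cache_sketch(model):
    # Resolve the config that owns `use_cache`; multimodal models such as
    # Mllama may nest it under `text_config` (an assumption of this sketch).
    config = getattr(model.config, "text_config", None) or model.config
    original = getattr(config, "use_cache", None)
    config.use_cache = False  # forward() now skips building past_key_values
    try:
        yield
    finally:
        config.use_cache = original  # restore the previous setting on exit

With use_cache forced to False, model(**inputs) returns past_key_values as None, matching the assertion inside the with block; restoring the flag on exit lets the second forward pass populate the cache again, which is what the final assertion verifies.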
