Skip to content

Commit 18f8341

Browse files
committed
fix TestApplyQuantization test
Signed-off-by: Kyle Sayers <[email protected]>
1 parent d5676dc commit 18f8341

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from parameterized import parameterized
66

77
from llmcompressor.modifiers.obcq import SparseGPTModifier
8-
from llmcompressor.modifiers.quantization.calibration import calibrate_input_hook
98
from llmcompressor.modifiers.quantization.gptq import GPTQModifier
109
from tests.llmcompressor.modifiers.conf import (
1110
LifecyleTestingHarness,
@@ -61,7 +60,7 @@ def test_successful_layerwise_recipe(self):
6160

6261

6362
@pytest.mark.unit
64-
class TestCreateDefaultQuantModifier(unittest.TestCase):
63+
class TestApplyQuantization(unittest.TestCase):
6564
def setUp(self):
6665
setup_modifier_factory()
6766

@@ -77,8 +76,10 @@ def test_create_default_quant_modifier(self):
7776
assert hasattr(module, "quantization_scheme")
7877
assert hasattr(module, "input_observer")
7978
assert hasattr(module, "weight_observer")
80-
assert module._forward_pre_hooks[0] is calibrate_input_hook
81-
assert module._forward_hooks[0] is modifier.calibrate_module
79+
pre_hooks = list(module._forward_pre_hooks.values())
80+
post_hooks = list(module._forward_hooks.values())
81+
assert pre_hooks[0].__name__ == "calibrate_input_hook"
82+
assert post_hooks[0].__name__ == "calibrate_module"
8283

8384

8485
class TestSetQuantInGPTQ(unittest.TestCase):

0 commit comments

Comments
 (0)