Skip to content

Commit eba87d3

Browse files
Reworked based on final review comments
1 parent 364d4da commit eba87d3

File tree

10 files changed

+407
-238
lines changed

10 files changed

+407
-238
lines changed

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10+
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
1011
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1112
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1213
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/api/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10+
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
1011
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1112
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1213
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/src/models/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,10 +436,9 @@ def quantize(self, mode, config=None, **kwargs):
436436

437437
if mode == "gptq":
438438
if not isinstance(config, GPTQConfig):
439-
raise TypeError(
440-
"When using 'gptq' mode, you must pass a `config` "
441-
"argument of type "
442-
"`keras.quantizers.gptq_config.GPTQConfig`."
439+
raise ValueError(
440+
"The `config` argument must be of type "
441+
"`keras.quantizers.GPTQConfig`."
443442
)
444443
# The config object's own quantize method drives the process
445444
config.quantize(self)

keras/src/models/model_test.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,12 +1243,12 @@ def test_export_error(self):
12431243

12441244

12451245
# Helper function to generate dummy data for quick testing.
1246-
def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000):
1246+
def dummy_dataset_generator(num_samples, sequence_length, vocab_size=1000):
12471247
"""A generator that yields random numpy arrays for fast,
12481248
self-contained tests."""
12491249
rng = np.random.default_rng(seed=42)
1250-
for _ in range(nsamples):
1251-
yield rng.integers(low=0, high=vocab_size, size=(1, seqlen))
1250+
for _ in range(num_samples):
1251+
yield rng.integers(low=0, high=vocab_size, size=(1, sequence_length))
12521252

12531253

12541254
# Helper function to build a simple transformer model.
@@ -1295,13 +1295,13 @@ def call(self, inputs):
12951295
DATASETS = {
12961296
"string_dataset": [long_text],
12971297
"generator_dataset": lambda: dummy_dataset_generator(
1298-
nsamples=16, seqlen=128
1298+
num_samples=16, sequence_length=128
12991299
),
13001300
}
13011301
CONFIGS = {
13021302
"default": {},
13031303
"per_channel": {"group_size": -1},
1304-
"act_order": {"act_order": True},
1304+
"act_order": {"activation_order": True},
13051305
"symmetric": {"symmetric": True},
13061306
}
13071307

@@ -1315,9 +1315,9 @@ def _get_simple_model():
13151315
# --- Error Scenarios ---
13161316
(
13171317
"gptq",
1318-
{"wbits": 4}, # Invalid config (dict, not GPTQConfig)
1319-
TypeError,
1320-
"must pass a `config` argument of type",
1318+
{"weight_bits": 4}, # Invalid config (dict, not GPTQConfig)
1319+
ValueError,
1320+
"The `config` argument must be of type",
13211321
"gptq_with_invalid_config",
13221322
),
13231323
(
@@ -1360,12 +1360,12 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
13601360
base_config = {
13611361
"dataset": dataset,
13621362
"tokenizer": mock_tokenizer,
1363-
"wbits": W_BITS,
1364-
"nsamples": NUM_SAMPLES,
1365-
"seqlen": SEQUENCE_LENGTH,
1363+
"weight_bits": W_BITS,
1364+
"num_samples": NUM_SAMPLES,
1365+
"sequence_length": SEQUENCE_LENGTH,
13661366
"group_size": 32,
13671367
"symmetric": False,
1368-
"act_order": False,
1368+
"activation_order": False,
13691369
}
13701370

13711371
target_layer = model.layers[2].ffn.layers[0]

0 commit comments

Comments (0)