Skip to content

Commit 03ce84c

Browse files
Initial commit
1 parent a00b966 commit 03ce84c

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

optimum/intel/openvino/quantization.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,29 @@ def __init__(
165165
self.apply_caching = apply_caching
166166
self.inference_result_mock = inference_result_mock
167167
self.tensor_cache = {}
168+
self.stateful = len(request.query_state()) > 0
169+
self._reset_state_called = False
168170

169171
def collect_inputs(self, inputs):
172+
if self.stateful:
173+
if isinstance(inputs, dict) and is_nncf_version(">", "2.19"):
174+
from nncf.definitions import NNCF_DATASET_RESET_STATE_KEY
175+
176+
# To reflect the state resetting during NNCF calibration, we add a special key to the input dict
177+
# A shallow copy is made on purpose: we only need to add a key to the top-level dict
178+
inputs = inputs.copy()
179+
inputs[NNCF_DATASET_RESET_STATE_KEY] = self._reset_state_called
180+
self._reset_state_called = False
181+
170182
if not self.apply_caching or not isinstance(inputs, dict):
171183
self.collected_inputs.append(copy.deepcopy(inputs))
172184
return
173185

174186
copied_inputs = {}
175187
for k, v in inputs.items():
188+
if isinstance(v, bool):
189+
copied_inputs[k] = v
190+
continue
176191
data = v
177192
if isinstance(data, openvino.Tensor):
178193
data = data.data
@@ -221,6 +236,10 @@ def wait(self):
221236
def get_tensor(self, name: str):
222237
return Tensor(self.request.results[name])
223238

239+
def reset_state(self):
240+
self.request.reset_state()
241+
self._reset_state_called = True
242+
224243
def __getattr__(self, attr):
225244
if attr in self.__dict__:
226245
return getattr(self, attr)

tests/openvino/test_quantization.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,10 +2143,12 @@ def _generate_random_audio_data(processor):
21432143
).input_features
21442144
return input_features
21452145

2146-
@parameterized.expand(itertools.product(MODEL_NAME, APPLY_CACHING))
2147-
def test_calibration_data_uniqueness(self, model_name, apply_caching):
2146+
@parameterized.expand(itertools.product(MODEL_NAME, STATEFUL, APPLY_CACHING))
2147+
def test_calibration_data_uniqueness(self, model_name, stateful, apply_caching):
21482148
model_id = MODEL_NAMES[model_name]
2149-
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True, device=OPENVINO_DEVICE)
2149+
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
2150+
model_id, export=True, compile=True, stateful=stateful, device=OPENVINO_DEVICE
2151+
)
21502152
processor = AutoProcessor.from_pretrained(model_id)
21512153

21522154
calibration_data = []
@@ -2158,13 +2160,28 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):
21582160
ov_model.decoder.request = InferRequestWrapper(
21592161
ov_model.decoder.request, calibration_data, apply_caching=apply_caching
21602162
)
2161-
for _ in range(2):
2163+
n_samples = 3
2164+
for _ in range(n_samples):
21622165
input_features = self._generate_random_audio_data(processor)
21632166
ov_model.generate(input_features, max_new_tokens=10, min_new_tokens=10)
21642167

21652168
data_hashes_per_key = defaultdict(list)
21662169
data_id_per_key = defaultdict(set)
21672170

2171+
# Check that reset state flag is present and correctly set in collected inputs
2172+
if stateful and is_nncf_version(">", "2.19"):
2173+
from nncf.definitions import NNCF_DATASET_RESET_STATE_KEY
2174+
2175+
# All inputs should have reset state key
2176+
self.assertTrue(all(NNCF_DATASET_RESET_STATE_KEY in inputs_dict for inputs_dict in calibration_data))
2177+
# The number of times reset state flag is set to True should be equal to (2 * n_samples), because
2178+
# for each sequence generation, the state is reset twice
2179+
self.assertEqual(
2180+
sum(int(inputs_dict[NNCF_DATASET_RESET_STATE_KEY]) for inputs_dict in calibration_data), 2 * n_samples
2181+
)
2182+
# Remove reset state key from inputs to avoid affecting data uniqueness checks
2183+
[input_dict.pop(NNCF_DATASET_RESET_STATE_KEY) for input_dict in calibration_data]
2184+
21682185
for inputs_dict in calibration_data:
21692186
for k, v in inputs_dict.items():
21702187
if k in ["input_ids", "beam_idx"]:
@@ -2174,14 +2191,14 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):
21742191
data_hashes_per_key[k].append(hash(x.tobytes()))
21752192
data_id_per_key[k].add(id(v))
21762193
for k, data_hashes in data_hashes_per_key.items():
2177-
# All hashes can not be equal because calibration dataset contains at least 2 different samples
2194+
# The hashes cannot all be equal, because the calibration dataset contains at least n_samples different samples
21782195
self.assertTrue(any(data_hashes[0] != it for it in data_hashes))
21792196
if apply_caching:
2180-
# With caching, encoder hidden states tensors should be cached, resulting in only 2 tensors stored
2181-
self.assertEqual(len(data_id_per_key["encoder_hidden_states"]), 2)
2197+
# With caching, encoder hidden states tensors should be cached, resulting in only n_samples tensors stored
2198+
self.assertEqual(len(data_id_per_key["encoder_hidden_states"]), n_samples)
21822199
else:
21832200
# Without caching, encoder hidden states tensors will be unique for each collected input
2184-
self.assertGreater(len(data_id_per_key["encoder_hidden_states"]), 2)
2201+
self.assertGreater(len(data_id_per_key["encoder_hidden_states"]), n_samples)
21852202

21862203

21872204
def check_optimization_not_applicable_to_optimized_model(model, quantization_config):

0 commit comments

Comments
 (0)