Skip to content

Commit 1788872

Browse files
[OpenVINO backend] support categorical
1 parent be9b002 commit 1788872

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,58 @@ CoreOpsDtypeTest::test_convert_to_tensor12
188188
CoreOpsDtypeTest::test_convert_to_tensor14
189189
CoreOpsDtypeTest::test_convert_to_tensor25
190190
CoreOpsDtypeTest::test_convert_to_tensor37
191+
RandomCorrectnessTest::test_beta0
192+
RandomCorrectnessTest::test_beta1
193+
RandomCorrectnessTest::test_beta2
194+
RandomCorrectnessTest::test_binomial0
195+
RandomCorrectnessTest::test_binomial1
196+
RandomCorrectnessTest::test_binomial2
197+
RandomCorrectnessTest::test_dropout
198+
RandomCorrectnessTest::test_dropout_noise_shape
199+
RandomCorrectnessTest::test_gamma0
200+
RandomCorrectnessTest::test_gamma1
201+
RandomCorrectnessTest::test_gamma2
202+
RandomCorrectnessTest::test_randint0
203+
RandomCorrectnessTest::test_randint1
204+
RandomCorrectnessTest::test_randint2
205+
RandomCorrectnessTest::test_randint3
206+
RandomCorrectnessTest::test_randint4
207+
RandomCorrectnessTest::test_shuffle
208+
RandomCorrectnessTest::test_truncated_normal0
209+
RandomCorrectnessTest::test_truncated_normal1
210+
RandomCorrectnessTest::test_truncated_normal2
211+
RandomCorrectnessTest::test_truncated_normal3
212+
RandomCorrectnessTest::test_truncated_normal4
213+
RandomCorrectnessTest::test_truncated_normal5
214+
RandomCorrectnessTest::test_uniform0
215+
RandomCorrectnessTest::test_uniform1
216+
RandomCorrectnessTest::test_uniform2
217+
RandomCorrectnessTest::test_uniform3
218+
RandomCorrectnessTest::test_uniform4
219+
RandomBehaviorTest::test_beta_tf_data_compatibility
220+
RandomDTypeTest::test_beta_bfloat16
221+
RandomDTypeTest::test_beta_float16
222+
RandomDTypeTest::test_beta_float32
223+
RandomDTypeTest::test_beta_float64
224+
RandomDTypeTest::test_binomial_bfloat16
225+
RandomDTypeTest::test_binomial_float16
226+
RandomDTypeTest::test_binomial_float32
227+
RandomDTypeTest::test_binomial_float64
228+
RandomDTypeTest::test_dropout_bfloat16
229+
RandomDTypeTest::test_dropout_float16
230+
RandomDTypeTest::test_dropout_float32
231+
RandomDTypeTest::test_dropout_float64
232+
RandomDTypeTest::test_gamma_bfloat16
233+
RandomDTypeTest::test_gamma_float16
234+
RandomDTypeTest::test_gamma_float32
235+
RandomDTypeTest::test_gamma_float64
236+
RandomDTypeTest::test_normal_bfloat16
237+
RandomDTypeTest::test_randint_int16
238+
RandomDTypeTest::test_randint_int32
239+
RandomDTypeTest::test_randint_int64
240+
RandomDTypeTest::test_randint_int8
241+
RandomDTypeTest::test_randint_uint16
242+
RandomDTypeTest::test_randint_uint32
243+
RandomDTypeTest::test_randint_uint8
244+
RandomDTypeTest::test_truncated_normal_bfloat16
245+
RandomDTypeTest::test_uniform_bfloat16

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ keras/src/ops/math_test.py
3333
keras/src/ops/nn_test.py
3434
keras/src/optimizers
3535
keras/src/quantizers
36-
keras/src/random
36+
keras/src/random/seed_generator_test.py
3737
keras/src/regularizers
3838
keras/src/saving
3939
keras/src/trainers

keras/src/backend/openvino/random.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,36 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
3939

4040

4141
def categorical(logits, num_samples, dtype="int64", seed=None):
42-
raise NotImplementedError(
43-
"`categorical` is not supported with openvino backend"
42+
if isinstance(logits, OpenVINOKerasTensor):
43+
logits = convert_to_numpy(logits)
44+
assert isinstance(logits, np.ndarray), (
45+
"logits must be a numpy array or an OpenVINOKerasTensor, "
46+
"got: {}".format(type(logits))
4447
)
48+
# Compute probabilities
49+
probs = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
50+
probs = probs / np.sum(probs, axis=-1, keepdims=True)
51+
rng = np.random.default_rng(draw_seed(seed).data)
52+
# Batched sampling
53+
if probs.ndim == 1:
54+
samples = rng.choice(
55+
logits.shape[-1],
56+
size=(num_samples,),
57+
p=probs,
58+
replace=True,
59+
)
60+
samples = samples[None, :] # Add batch dim for consistency
61+
else:
62+
samples = np.empty((probs.shape[0], num_samples), dtype=dtype)
63+
for i in range(probs.shape[0]):
64+
samples[i] = rng.choice(
65+
logits.shape[-1],
66+
size=(num_samples,),
67+
p=probs[i],
68+
replace=True,
69+
)
70+
samples = samples.astype(dtype)
71+
return OpenVINOKerasTensor(ov_opset.constant(samples).output(0))
4572

4673

4774
def randint(shape, minval, maxval, dtype="int32", seed=None):

0 commit comments

Comments
 (0)