Skip to content

Commit 46fa854

Browse files
[OpenVINO backend] support categorical
1 parent be9b002 commit 46fa854

File tree

3 files changed

+113
-3
lines changed

3 files changed

+113
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,58 @@ CoreOpsDtypeTest::test_convert_to_tensor12
188188
CoreOpsDtypeTest::test_convert_to_tensor14
189189
CoreOpsDtypeTest::test_convert_to_tensor25
190190
CoreOpsDtypeTest::test_convert_to_tensor37
191+
RandomCorrectnessTest::test_beta0
192+
RandomCorrectnessTest::test_beta1
193+
RandomCorrectnessTest::test_beta2
194+
RandomCorrectnessTest::test_binomial0
195+
RandomCorrectnessTest::test_binomial1
196+
RandomCorrectnessTest::test_binomial2
197+
RandomCorrectnessTest::test_dropout
198+
RandomCorrectnessTest::test_dropout_noise_shape
199+
RandomCorrectnessTest::test_gamma0
200+
RandomCorrectnessTest::test_gamma1
201+
RandomCorrectnessTest::test_gamma2
202+
RandomCorrectnessTest::test_randint0
203+
RandomCorrectnessTest::test_randint1
204+
RandomCorrectnessTest::test_randint2
205+
RandomCorrectnessTest::test_randint3
206+
RandomCorrectnessTest::test_randint4
207+
RandomCorrectnessTest::test_shuffle
208+
RandomCorrectnessTest::test_truncated_normal0
209+
RandomCorrectnessTest::test_truncated_normal1
210+
RandomCorrectnessTest::test_truncated_normal2
211+
RandomCorrectnessTest::test_truncated_normal3
212+
RandomCorrectnessTest::test_truncated_normal4
213+
RandomCorrectnessTest::test_truncated_normal5
214+
RandomCorrectnessTest::test_uniform0
215+
RandomCorrectnessTest::test_uniform1
216+
RandomCorrectnessTest::test_uniform2
217+
RandomCorrectnessTest::test_uniform3
218+
RandomCorrectnessTest::test_uniform4
219+
RandomBehaviorTest::test_beta_tf_data_compatibility
220+
RandomDTypeTest::test_beta_bfloat16
221+
RandomDTypeTest::test_beta_float16
222+
RandomDTypeTest::test_beta_float32
223+
RandomDTypeTest::test_beta_float64
224+
RandomDTypeTest::test_binomial_bfloat16
225+
RandomDTypeTest::test_binomial_float16
226+
RandomDTypeTest::test_binomial_float32
227+
RandomDTypeTest::test_binomial_float64
228+
RandomDTypeTest::test_dropout_bfloat16
229+
RandomDTypeTest::test_dropout_float16
230+
RandomDTypeTest::test_dropout_float32
231+
RandomDTypeTest::test_dropout_float64
232+
RandomDTypeTest::test_gamma_bfloat16
233+
RandomDTypeTest::test_gamma_float16
234+
RandomDTypeTest::test_gamma_float32
235+
RandomDTypeTest::test_gamma_float64
236+
RandomDTypeTest::test_normal_bfloat16
237+
RandomDTypeTest::test_randint_int16
238+
RandomDTypeTest::test_randint_int32
239+
RandomDTypeTest::test_randint_int64
240+
RandomDTypeTest::test_randint_int8
241+
RandomDTypeTest::test_randint_uint16
242+
RandomDTypeTest::test_randint_uint32
243+
RandomDTypeTest::test_randint_uint8
244+
RandomDTypeTest::test_truncated_normal_bfloat16
245+
RandomDTypeTest::test_uniform_bfloat16

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ keras/src/ops/math_test.py
3333
keras/src/ops/nn_test.py
3434
keras/src/optimizers
3535
keras/src/quantizers
36-
keras/src/random
36+
keras/src/random/seed_generator_test.py
3737
keras/src/regularizers
3838
keras/src/saving
3939
keras/src/trainers

keras/src/backend/openvino/random.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.src.backend.openvino.core import OPENVINO_DTYPES
77
from keras.src.backend.openvino.core import OpenVINOKerasTensor
88
from keras.src.backend.openvino.core import convert_to_numpy
9+
from keras.src.backend.openvino.core import get_ov_output
910
from keras.src.random.seed_generator import SeedGenerator
1011
from keras.src.random.seed_generator import draw_seed
1112
from keras.src.random.seed_generator import make_default_seed
@@ -39,9 +40,63 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
3940

4041

4142
def categorical(logits, num_samples, dtype="int64", seed=None):
42-
raise NotImplementedError(
43-
"`categorical` is not supported with openvino backend"
43+
def get_shape_dims(x):
44+
shape = ov_opset.shape_of(x, Type.i32)
45+
rank_tensor = ov_opset.shape_of(shape, Type.i32)
46+
rank_scalar = ov_opset.squeeze(
47+
rank_tensor, ov_opset.constant([0], Type.i32)
48+
)
49+
indices = ov_opset.range(
50+
ov_opset.constant(0, Type.i32),
51+
rank_scalar,
52+
ov_opset.constant(1, Type.i32),
53+
output_type=Type.i32,
54+
)
55+
return ov_opset.gather(shape, indices, axis=0)
56+
57+
dtype = dtype or "int64"
58+
ov_dtype = OPENVINO_DTYPES[dtype]
59+
logits = get_ov_output(logits)
60+
probs = ov_opset.softmax(logits, axis=-1)
61+
cumsum_probs = ov_opset.cumsum(probs, ov_opset.constant(-1, dtype="int32"))
62+
shape = get_shape_dims(logits)
63+
rank_tensor = ov_opset.shape_of(shape, Type.i32)
64+
rank = ov_opset.squeeze(rank_tensor, ov_opset.constant([0], dtype=Type.i32))
65+
rank_minus_1 = ov_opset.subtract(rank, ov_opset.constant(1, dtype=Type.i32))
66+
indices = ov_opset.range(
67+
ov_opset.constant(0, dtype=Type.i32),
68+
rank_minus_1,
69+
ov_opset.constant(1, dtype=Type.i32),
70+
output_type=Type.i32,
71+
)
72+
batch_shape = ov_opset.gather(shape, indices, axis=0)
73+
final_shape = ov_opset.concat(
74+
[batch_shape, ov_opset.constant([num_samples], dtype=Type.i32)], axis=0
75+
)
76+
seed_tensor = draw_seed(seed)
77+
if isinstance(seed_tensor, OpenVINOKerasTensor):
78+
seed1, seed2 = convert_to_numpy(seed_tensor)
79+
else:
80+
seed1, seed2 = seed_tensor.data
81+
rand = ov_opset.random_uniform(
82+
final_shape,
83+
ov_opset.constant(0.0, dtype=probs.get_element_type()),
84+
ov_opset.constant(1.0, dtype=probs.get_element_type()),
85+
probs.get_element_type(),
86+
seed1,
87+
seed2,
88+
)
89+
rand = ov_opset.unsqueeze(rand, [-1])
90+
cumsum_probs = ov_opset.unsqueeze(
91+
cumsum_probs, ov_opset.constant([1], dtype=Type.i32)
92+
)
93+
greater = ov_opset.greater(rand, cumsum_probs)
94+
samples = ov_opset.reduce_sum(
95+
ov_opset.convert(greater, Type.i32),
96+
ov_opset.constant([-1], dtype=Type.i32),
4497
)
98+
samples = ov_opset.convert(samples, ov_dtype)
99+
return OpenVINOKerasTensor(samples.output(0))
45100

46101

47102
def randint(shape, minval, maxval, dtype="int32", seed=None):

0 commit comments

Comments
 (0)