
Commit 753f15b

remove broadcasting and optimize
1 parent c1e1ebe commit 753f15b

File tree

1 file changed: +0 -62 lines changed


keras/src/backend/openvino/numpy.py

Lines changed: 0 additions & 62 deletions
@@ -1523,14 +1523,6 @@ def ensure_constant(value, default_type=Type.i32):
     row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
     col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))

-    # Target shape for broadcasting
-    target_shape = ov_opset.concat(
-        [ov_opset.unsqueeze(N_const, [0]), ov_opset.unsqueeze(M_const, [0])],
-        axis=0,
-    )
-
-    row_idx = ov_opset.broadcast(row_idx, target_shape)
-    col_idx = ov_opset.broadcast(col_idx, target_shape)
     mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))

     if ov_dtype == Type.boolean:
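
Note on this hunk: the explicit broadcast of row_idx and col_idx is redundant because OpenVINO's binary elementwise ops (less_equal, add) default to numpy-style auto-broadcasting, so comparing a [N, 1] row index against a [1, M] column index already yields the full [N, M] mask. A minimal NumPy sketch of the same shape semantics (a standalone illustration with example values for N, M, k, not the backend code itself):

    import numpy as np

    N, M, k = 4, 5, 0
    row_idx = np.arange(N).reshape(N, 1)  # shape (N, 1)
    col_idx = np.arange(M).reshape(1, M)  # shape (1, M)

    # The comparison broadcasts implicitly to (N, M); no explicit
    # broadcast of either operand is needed.
    mask = col_idx <= row_idx + k
    assert mask.shape == (N, M)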
@@ -1558,7 +1550,6 @@ def get_shape_dims(x):

     x = get_ov_output(x)
     ov_type = x.get_element_type()
-    input_shape = ov_opset.shape_of(x, Type.i32)
     shape = get_shape_dims(x)
     zero_const = ov_opset.constant(0, Type.i32)
     minus2 = ov_opset.constant([-2], Type.i32)
@@ -1571,31 +1562,6 @@ def get_shape_dims(x):

     mask = ov_opset.convert(tri_mask, ov_type)

-    # Broadcast mask to input shape (including batch dims)
-    shape_rank = ov_opset.squeeze(
-        ov_opset.shape_of(input_shape, Type.i32), zero_const
-    )
-    batch_dims = ov_opset.subtract(shape_rank, ov_opset.constant(2, Type.i32))
-    batch_indices = ov_opset.range(
-        zero_const,
-        batch_dims,
-        ov_opset.constant(1, Type.i32),
-        output_type=Type.i32,
-    )
-    batch_shape = ov_opset.gather(input_shape, batch_indices, zero_const)
-
-    M_reshaped = ov_opset.unsqueeze(M, zero_const)
-    N_reshaped = ov_opset.unsqueeze(N, zero_const)
-
-    concat_inputs = [
-        batch_shape.output(0),
-        M_reshaped.output(0),
-        N_reshaped.output(0),
-    ]
-
-    full_mask_shape = ov_opset.concat(concat_inputs, axis=0)
-    mask = ov_opset.broadcast(mask, full_mask_shape)
-
     if ov_type == Type.boolean:
         out = ov_opset.logical_and(x, mask)
     else:
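
Same idea in this triu hunk: the 2-D mask no longer needs to be broadcast to the full input shape, because the final elementwise combine (logical_and in the boolean branch; the non-boolean branch is analogous) aligns trailing dimensions and stretches the mask across any leading batch dims of x. A NumPy sketch of the broadcasting rule the change relies on (the shapes here are illustrative assumptions):

    import numpy as np

    x = np.ones((2, 3, 4, 5))        # two leading batch dims
    mask = np.triu(np.ones((4, 5)))  # plain 2-D mask, no batch dims

    # Trailing dims (4, 5) align, so the mask broadcasts over (2, 3)
    # without materializing a (2, 3, 4, 5) copy of it.
    out = x * mask
    assert out.shape == x.shape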
@@ -1621,7 +1587,6 @@ def get_shape_dims(x):

     x = get_ov_output(x)
     ov_type = x.get_element_type()
-    input_shape = ov_opset.shape_of(x, Type.i32)
     shape = get_shape_dims(x)
     zero_const = ov_opset.constant(0, Type.i32)
     minus2 = ov_opset.constant([-2], Type.i32)
@@ -1631,7 +1596,6 @@ def get_shape_dims(x):

     tri_mask = tri(M, N, k=k - 1, dtype="bool").output

-    # Handle boolean type differently since subtract doesn't work with boolean
     if ov_type == Type.boolean:
         mask = ov_opset.logical_not(tri_mask)
     else:
@@ -1640,32 +1604,6 @@ def get_shape_dims(x):
         )
         mask = ov_opset.subtract(ones, ov_opset.convert(tri_mask, ov_type))

-    # Broadcast mask
-    shape_rank = ov_opset.squeeze(
-        ov_opset.shape_of(input_shape, Type.i32), zero_const
-    )
-    batch_dims = ov_opset.subtract(shape_rank, ov_opset.constant(2, Type.i32))
-    batch_indices = ov_opset.range(
-        zero_const,
-        batch_dims,
-        ov_opset.constant(1, Type.i32),
-        output_type=Type.i32,
-    )
-    batch_shape = ov_opset.gather(input_shape, batch_indices, zero_const)
-
-    # Ensure all tensors are properly shaped before concat
-    M_reshaped = ov_opset.unsqueeze(M, zero_const)
-    N_reshaped = ov_opset.unsqueeze(N, zero_const)
-
-    concat_inputs = [
-        batch_shape.output(0),
-        M_reshaped.output(0),
-        N_reshaped.output(0),
-    ]
-
-    full_mask_shape = ov_opset.concat(concat_inputs, axis=0)
-    mask = ov_opset.broadcast(mask, full_mask_shape)
-
     if ov_type == Type.boolean:
         out = ov_opset.logical_and(x, mask)
     else:
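
The tril hunk mirrors the triu change above: the negated tri mask reaches x through the same implicitly broadcasting elementwise ops (see the NumPy sketch after the triu hunk), which is also what lets both functions drop the now-unused input_shape.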

0 commit comments