Skip to content

Commit 7da7c2c

Browse files
add comments for clarification
1 parent 96408fa commit 7da7c2c

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

keras/src/backend/openvino/numpy.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,6 +1484,8 @@ def trace(x, offset=0, axis1=0, axis2=1):
14841484

14851485

14861486
def tri(N, M=None, k=0, dtype=None):
1487+
# Create a lower-triangular matrix with ones below and on the k-th diagonal,
1488+
# zeros elsewhere.
14871489
if M is None:
14881490
M = N
14891491
if dtype is None:
@@ -1495,6 +1497,7 @@ def tri(N, M=None, k=0, dtype=None):
14951497
M = ov_opset.constant(M, Type.i32)
14961498
k = ov_opset.constant(k, Type.i32)
14971499

1500+
# Create row and column indices: [0, 1, ..., N-1] and [0, 1, ..., M-1]
14981501
row_range = ov_opset.range(
14991502
ov_opset.constant(0, Type.i32),
15001503
N,
@@ -1508,16 +1511,20 @@ def tri(N, M=None, k=0, dtype=None):
15081511
output_type=Type.i32,
15091512
)
15101513

1514+
# Reshape row/col indices to 2D for broadcasting:
1515+
# row_idx: shape (N, 1), col_idx: shape (1, M)
15111516
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
15121517
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
15131518

1519+
# Broadcast row_idx and col_idx to (N, M) so we can compare every pair
15141520
target_shape = ov_opset.concat(
15151521
[ov_opset.unsqueeze(N, [0]), ov_opset.unsqueeze(M, [0])], axis=0
15161522
)
15171523

15181524
row_idx = ov_opset.broadcast(row_idx, target_shape)
15191525
col_idx = ov_opset.broadcast(col_idx, target_shape)
15201526

1527+
# Create mask: 1 where col_idx <= row_idx + k (i.e., lower triangle), else 0
15211528
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k))
15221529

15231530
if ov_dtype == Type.boolean:
@@ -1529,7 +1536,10 @@ def tri(N, M=None, k=0, dtype=None):
15291536

15301537

15311538
def tril(x, k=0):
1539+
# Applies a lower-triangular mask to the last two dims of x,
1540+
# keeping elements below/on k-th diagonal.
15321541
def get_shape_dims(x):
1542+
# get shape as 1D tensor
15331543
shape = ov_opset.shape_of(x, Type.i32)
15341544
rank_tensor = ov_opset.shape_of(shape, Type.i32)
15351545
rank_scalar = ov_opset.squeeze(
@@ -1548,6 +1558,7 @@ def get_shape_dims(x):
15481558
input_shape = ov_opset.shape_of(x, Type.i32)
15491559
shape = get_shape_dims(x)
15501560

1561+
# Get matrix dimensions (last two dims)
15511562
zero_const = ov_opset.constant(0, Type.i32)
15521563
minus2 = ov_opset.constant([-2], Type.i32)
15531564
minus1 = ov_opset.constant([-1], Type.i32)
@@ -1561,6 +1572,7 @@ def get_shape_dims(x):
15611572
ov_opset.constant([0], Type.i32),
15621573
)
15631574

1575+
# Create row and column indices for the matrix part
15641576
row_range = ov_opset.range(
15651577
ov_opset.constant(0, Type.i32),
15661578
M,
@@ -1574,6 +1586,7 @@ def get_shape_dims(x):
15741586
output_type=Type.i32,
15751587
)
15761588

1589+
# Reshape for broadcasting to (M, N)
15771590
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
15781591
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
15791592

@@ -1584,10 +1597,13 @@ def get_shape_dims(x):
15841597
row_idx = ov_opset.broadcast(row_idx, target_shape)
15851598
col_idx = ov_opset.broadcast(col_idx, target_shape)
15861599

1600+
# Mask for lower triangle (col <= row + k)
15871601
k_const = ov_opset.constant(k, Type.i32)
15881602
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))
15891603
mask = ov_opset.convert(mask, ov_type)
15901604

1605+
# --- Batch broadcasting logic ---
1606+
# Compute the number of batch dimensions (all dims except last two)
15911607
shape_rank_tensor = ov_opset.shape_of(input_shape, Type.i32)
15921608
shape_rank = ov_opset.squeeze(
15931609
shape_rank_tensor, ov_opset.constant([0], Type.i32)
@@ -1599,15 +1615,18 @@ def get_shape_dims(x):
15991615
batch_dims_count, ov_opset.constant([0], Type.i32)
16001616
)
16011617

1618+
# Create a range for batch dimension indices
16021619
batch_indices = ov_opset.range(
16031620
start=ov_opset.constant(0, Type.i32),
16041621
stop=batch_dims_count,
16051622
step=ov_opset.constant(1, Type.i32),
16061623
output_type=Type.i32,
16071624
)
16081625

1626+
# Gather the batch shape from input_shape using batch_indices
16091627
batch_shape = ov_opset.gather(input_shape, batch_indices, axis=0)
16101628
full_mask_shape = ov_opset.concat([batch_shape, M_1d, N_1d], axis=0)
1629+
# Broadcast the mask to the full input shape (including batch)
16111630
mask = ov_opset.broadcast(mask, full_mask_shape)
16121631

16131632
if ov_type == Type.boolean:
@@ -1618,6 +1637,8 @@ def get_shape_dims(x):
16181637

16191638

16201639
def triu(x, k=0):
1640+
# Applies an upper-triangular mask to the last two dims of x,
1641+
# keeping elements above/on k-th diagonal.
16211642
def get_shape_dims(x):
16221643
shape = ov_opset.shape_of(x, Type.i32)
16231644
rank_tensor = ov_opset.shape_of(shape, Type.i32)
@@ -1637,6 +1658,7 @@ def get_shape_dims(x):
16371658
input_shape = ov_opset.shape_of(x, Type.i32)
16381659
shape = get_shape_dims(x)
16391660

1661+
# Get matrix dimensions (last two dims)
16401662
zero_const = ov_opset.constant(0, Type.i32)
16411663
minus2 = ov_opset.constant([-2], Type.i32)
16421664
minus1 = ov_opset.constant([-1], Type.i32)
@@ -1650,6 +1672,7 @@ def get_shape_dims(x):
16501672
ov_opset.constant([0], Type.i32),
16511673
)
16521674

1675+
# Create row and column indices for the matrix part
16531676
row_range = ov_opset.range(
16541677
ov_opset.constant(0, Type.i32),
16551678
M,
@@ -1663,6 +1686,7 @@ def get_shape_dims(x):
16631686
output_type=Type.i32,
16641687
)
16651688

1689+
# Reshape for broadcasting to (M, N)
16661690
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
16671691
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
16681692

@@ -1673,10 +1697,13 @@ def get_shape_dims(x):
16731697
row_idx = ov_opset.broadcast(row_idx, target_shape)
16741698
col_idx = ov_opset.broadcast(col_idx, target_shape)
16751699

1700+
# Mask for upper triangle (col >= row + k)
16761701
k_const = ov_opset.constant(k, Type.i32)
16771702
mask = ov_opset.greater_equal(col_idx, ov_opset.add(row_idx, k_const))
16781703
mask = ov_opset.convert(mask, ov_type)
16791704

1705+
# --- Batch broadcasting logic ---
1706+
# Compute the number of batch dimensions (all dims except last two)
16801707
shape_rank_tensor = ov_opset.shape_of(input_shape, Type.i32)
16811708
shape_rank = ov_opset.squeeze(
16821709
shape_rank_tensor, ov_opset.constant([0], Type.i32)
@@ -1688,6 +1715,7 @@ def get_shape_dims(x):
16881715
batch_dims_count, ov_opset.constant([0], Type.i32)
16891716
)
16901717

1718+
# Create a range for batch dimension indices
16911719
batch_indices = ov_opset.range(
16921720
start=ov_opset.constant(0, Type.i32),
16931721
stop=batch_dims_count,
@@ -1697,6 +1725,7 @@ def get_shape_dims(x):
16971725

16981726
batch_shape = ov_opset.gather(input_shape, batch_indices, axis=0)
16991727
full_mask_shape = ov_opset.concat([batch_shape, M_1d, N_1d], axis=0)
1728+
# Broadcast the mask to the full input shape (including batch)
17001729
mask = ov_opset.broadcast(mask, full_mask_shape)
17011730

17021731
if ov_type == Type.boolean:

0 commit comments

Comments
 (0)