@@ -1484,6 +1484,8 @@ def trace(x, offset=0, axis1=0, axis2=1):
1484
1484
1485
1485
1486
1486
def tri (N , M = None , k = 0 , dtype = None ):
1487
+ # Create a lower-triangular matrix with ones below and on the k-th diagonal,
1488
+ # zeros elsewhere.
1487
1489
if M is None :
1488
1490
M = N
1489
1491
if dtype is None :
@@ -1495,6 +1497,7 @@ def tri(N, M=None, k=0, dtype=None):
1495
1497
M = ov_opset .constant (M , Type .i32 )
1496
1498
k = ov_opset .constant (k , Type .i32 )
1497
1499
1500
+ # Create row and column indices: [0, 1, ..., N-1] and [0, 1, ..., M-1]
1498
1501
row_range = ov_opset .range (
1499
1502
ov_opset .constant (0 , Type .i32 ),
1500
1503
N ,
@@ -1508,16 +1511,20 @@ def tri(N, M=None, k=0, dtype=None):
1508
1511
output_type = Type .i32 ,
1509
1512
)
1510
1513
1514
+ # Reshape row/col indices to 2D for broadcasting:
1515
+ # row_idx: shape (N, 1), col_idx: shape (1, M)
1511
1516
row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1512
1517
col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1513
1518
1519
+ # Broadcast row_idx and col_idx to (N, M) so we can compare every pair
1514
1520
target_shape = ov_opset .concat (
1515
1521
[ov_opset .unsqueeze (N , [0 ]), ov_opset .unsqueeze (M , [0 ])], axis = 0
1516
1522
)
1517
1523
1518
1524
row_idx = ov_opset .broadcast (row_idx , target_shape )
1519
1525
col_idx = ov_opset .broadcast (col_idx , target_shape )
1520
1526
1527
+ # Create mask: 1 where col_idx <= row_idx + k (i.e., lower triangle), else 0
1521
1528
mask = ov_opset .less_equal (col_idx , ov_opset .add (row_idx , k ))
1522
1529
1523
1530
if ov_dtype == Type .boolean :
@@ -1529,7 +1536,10 @@ def tri(N, M=None, k=0, dtype=None):
1529
1536
1530
1537
1531
1538
def tril (x , k = 0 ):
1539
+ # Applies a lower-triangular mask to the last two dims of x,
1540
+ # keeping elements below/on k-th diagonal.
1532
1541
def get_shape_dims (x ):
1542
+ # get shape as 1D tensor
1533
1543
shape = ov_opset .shape_of (x , Type .i32 )
1534
1544
rank_tensor = ov_opset .shape_of (shape , Type .i32 )
1535
1545
rank_scalar = ov_opset .squeeze (
@@ -1548,6 +1558,7 @@ def get_shape_dims(x):
1548
1558
input_shape = ov_opset .shape_of (x , Type .i32 )
1549
1559
shape = get_shape_dims (x )
1550
1560
1561
+ # Get matrix dimensions (last two dims)
1551
1562
zero_const = ov_opset .constant (0 , Type .i32 )
1552
1563
minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1553
1564
minus1 = ov_opset .constant ([- 1 ], Type .i32 )
@@ -1561,6 +1572,7 @@ def get_shape_dims(x):
1561
1572
ov_opset .constant ([0 ], Type .i32 ),
1562
1573
)
1563
1574
1575
+ # Create row and column indices for the matrix part
1564
1576
row_range = ov_opset .range (
1565
1577
ov_opset .constant (0 , Type .i32 ),
1566
1578
M ,
@@ -1574,6 +1586,7 @@ def get_shape_dims(x):
1574
1586
output_type = Type .i32 ,
1575
1587
)
1576
1588
1589
+ # Reshape for broadcasting to (M, N)
1577
1590
row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1578
1591
col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1579
1592
@@ -1584,10 +1597,13 @@ def get_shape_dims(x):
1584
1597
row_idx = ov_opset .broadcast (row_idx , target_shape )
1585
1598
col_idx = ov_opset .broadcast (col_idx , target_shape )
1586
1599
1600
+ # Mask for lower triangle (col <= row + k)
1587
1601
k_const = ov_opset .constant (k , Type .i32 )
1588
1602
mask = ov_opset .less_equal (col_idx , ov_opset .add (row_idx , k_const ))
1589
1603
mask = ov_opset .convert (mask , ov_type )
1590
1604
1605
+ # --- Batch broadcasting logic ---
1606
+ # Compute the number of batch dimensions (all dims except last two)
1591
1607
shape_rank_tensor = ov_opset .shape_of (input_shape , Type .i32 )
1592
1608
shape_rank = ov_opset .squeeze (
1593
1609
shape_rank_tensor , ov_opset .constant ([0 ], Type .i32 )
@@ -1599,15 +1615,18 @@ def get_shape_dims(x):
1599
1615
batch_dims_count , ov_opset .constant ([0 ], Type .i32 )
1600
1616
)
1601
1617
1618
+ # Create a range for batch dimension indices
1602
1619
batch_indices = ov_opset .range (
1603
1620
start = ov_opset .constant (0 , Type .i32 ),
1604
1621
stop = batch_dims_count ,
1605
1622
step = ov_opset .constant (1 , Type .i32 ),
1606
1623
output_type = Type .i32 ,
1607
1624
)
1608
1625
1626
+ # Gather the batch shape from input_shape using batch_indices
1609
1627
batch_shape = ov_opset .gather (input_shape , batch_indices , axis = 0 )
1610
1628
full_mask_shape = ov_opset .concat ([batch_shape , M_1d , N_1d ], axis = 0 )
1629
+ # Broadcast the mask to the full input shape (including batch)
1611
1630
mask = ov_opset .broadcast (mask , full_mask_shape )
1612
1631
1613
1632
if ov_type == Type .boolean :
@@ -1618,6 +1637,8 @@ def get_shape_dims(x):
1618
1637
1619
1638
1620
1639
def triu (x , k = 0 ):
1640
+ # Applies an upper-triangular mask to the last two dims of x,
1641
+ # keeping elements above/on k-th diagonal.
1621
1642
def get_shape_dims (x ):
1622
1643
shape = ov_opset .shape_of (x , Type .i32 )
1623
1644
rank_tensor = ov_opset .shape_of (shape , Type .i32 )
@@ -1637,6 +1658,7 @@ def get_shape_dims(x):
1637
1658
input_shape = ov_opset .shape_of (x , Type .i32 )
1638
1659
shape = get_shape_dims (x )
1639
1660
1661
+ # Get matrix dimensions (last two dims)
1640
1662
zero_const = ov_opset .constant (0 , Type .i32 )
1641
1663
minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1642
1664
minus1 = ov_opset .constant ([- 1 ], Type .i32 )
@@ -1650,6 +1672,7 @@ def get_shape_dims(x):
1650
1672
ov_opset .constant ([0 ], Type .i32 ),
1651
1673
)
1652
1674
1675
+ # Create row and column indices for the matrix part
1653
1676
row_range = ov_opset .range (
1654
1677
ov_opset .constant (0 , Type .i32 ),
1655
1678
M ,
@@ -1663,6 +1686,7 @@ def get_shape_dims(x):
1663
1686
output_type = Type .i32 ,
1664
1687
)
1665
1688
1689
+ # Reshape for broadcasting to (M, N)
1666
1690
row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1667
1691
col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1668
1692
@@ -1673,10 +1697,13 @@ def get_shape_dims(x):
1673
1697
row_idx = ov_opset .broadcast (row_idx , target_shape )
1674
1698
col_idx = ov_opset .broadcast (col_idx , target_shape )
1675
1699
1700
+ # Mask for upper triangle (col >= row + k)
1676
1701
k_const = ov_opset .constant (k , Type .i32 )
1677
1702
mask = ov_opset .greater_equal (col_idx , ov_opset .add (row_idx , k_const ))
1678
1703
mask = ov_opset .convert (mask , ov_type )
1679
1704
1705
+ # --- Batch broadcasting logic ---
1706
+ # Compute the number of batch dimensions (all dims except last two)
1680
1707
shape_rank_tensor = ov_opset .shape_of (input_shape , Type .i32 )
1681
1708
shape_rank = ov_opset .squeeze (
1682
1709
shape_rank_tensor , ov_opset .constant ([0 ], Type .i32 )
@@ -1688,6 +1715,7 @@ def get_shape_dims(x):
1688
1715
batch_dims_count , ov_opset .constant ([0 ], Type .i32 )
1689
1716
)
1690
1717
1718
+ # Create a range for batch dimension indices
1691
1719
batch_indices = ov_opset .range (
1692
1720
start = ov_opset .constant (0 , Type .i32 ),
1693
1721
stop = batch_dims_count ,
@@ -1697,6 +1725,7 @@ def get_shape_dims(x):
1697
1725
1698
1726
batch_shape = ov_opset .gather (input_shape , batch_indices , axis = 0 )
1699
1727
full_mask_shape = ov_opset .concat ([batch_shape , M_1d , N_1d ], axis = 0 )
1728
+ # Broadcast the mask to the full input shape (including batch)
1700
1729
mask = ov_opset .broadcast (mask , full_mask_shape )
1701
1730
1702
1731
if ov_type == Type .boolean :
0 commit comments