@@ -1484,15 +1484,226 @@ def trace(x, offset=0, axis1=0, axis2=1):
1484
1484
1485
1485
1486
1486
def tri(N, M=None, k=0, dtype=None):
    """Build an `N` x `M` matrix of ones on and below the `k`-th diagonal.

    Args:
        N: Number of rows.
        M: Number of columns. Falls back to `N` when `None`.
        k: Diagonal offset. Entry `(i, j)` is set when `j <= i + k`.
        dtype: Key into `OPENVINO_DTYPES` selecting the result type.
            Falls back to `"float32"` when `None`.

    Returns:
        An `OpenVINOKerasTensor` wrapping the mask, converted to the
        requested type (left as the raw comparison result for boolean).
    """
    num_cols = N if M is None else M
    ov_dtype = OPENVINO_DTYPES["float32" if dtype is None else dtype]

    rows_node = ov_opset.constant(N, Type.i32)
    cols_node = ov_opset.constant(num_cols, Type.i32)
    offset_node = ov_opset.constant(k, Type.i32)

    def _arange(stop):
        # i32 sequence 0, 1, ..., stop - 1 as a graph node.
        return ov_opset.range(
            ov_opset.constant(0, Type.i32),
            stop,
            ov_opset.constant(1, Type.i32),
            output_type=Type.i32,
        )

    # Row indices as a column vector, column indices as a row vector.
    row_vec = ov_opset.unsqueeze(
        _arange(rows_node), ov_opset.constant([1], Type.i32)
    )
    col_vec = ov_opset.unsqueeze(
        _arange(cols_node), ov_opset.constant([0], Type.i32)
    )

    grid_shape = ov_opset.concat(
        [
            ov_opset.unsqueeze(rows_node, [0]),
            ov_opset.unsqueeze(cols_node, [0]),
        ],
        axis=0,
    )
    row_grid = ov_opset.broadcast(row_vec, grid_shape)
    col_grid = ov_opset.broadcast(col_vec, grid_shape)

    lower = ov_opset.less_equal(
        col_grid, ov_opset.add(row_grid, offset_node)
    )

    # The comparison already yields booleans; only convert for other types.
    if ov_dtype == Type.boolean:
        return OpenVINOKerasTensor(lower.output(0))
    return OpenVINOKerasTensor(ov_opset.convert(lower, ov_dtype).output(0))
1488
1529
1489
1530
1490
1531
def tril(x, k=0):
    """Return `x` with entries above the `k`-th diagonal set to zero.

    The triangle is taken over the last two axes of `x`; any leading axes
    are treated as batch dimensions.

    Args:
        x: Input tensor, converted with `get_ov_output`. The last two
            entries of its shape are read, so it needs at least two axes.
        k: Diagonal offset. Entry `(i, j)` of each matrix is kept when
            `j <= i + k`.

    Returns:
        An `OpenVINOKerasTensor` with the same element type as `x`.
    """
    x = get_ov_output(x)
    ov_type = x.get_element_type()
    shape = ov_opset.shape_of(x, Type.i32)

    zero_i32 = ov_opset.constant(0, Type.i32)
    one_i32 = ov_opset.constant(1, Type.i32)
    first_axis = ov_opset.constant([0], Type.i32)

    # Sizes of the last two axes, as i32 scalars.
    num_rows = ov_opset.squeeze(
        ov_opset.gather(shape, ov_opset.constant([-2], Type.i32), zero_i32),
        first_axis,
    )
    num_cols = ov_opset.squeeze(
        ov_opset.gather(shape, ov_opset.constant([-1], Type.i32), zero_i32),
        first_axis,
    )

    # Row indices as a column vector and column indices as a row vector;
    # the comparison broadcasts them to a (rows, cols) boolean mask.
    row_idx = ov_opset.unsqueeze(
        ov_opset.range(zero_i32, num_rows, one_i32, output_type=Type.i32),
        ov_opset.constant([1], Type.i32),
    )
    col_idx = ov_opset.unsqueeze(
        ov_opset.range(zero_i32, num_cols, one_i32, output_type=Type.i32),
        first_axis,
    )
    keep = ov_opset.less_equal(
        col_idx, ov_opset.add(row_idx, ov_opset.constant(k, Type.i32))
    )

    # Select instead of multiplying by a 0/1 mask: `nan * 0` and `inf * 0`
    # are NaN, so multiplication would leak non-finite values from the
    # discarded triangle. Select also broadcasts the 2-D mask over any
    # leading batch axes, and handles boolean inputs without a special case.
    zero = ov_opset.convert(zero_i32, ov_type)
    out = ov_opset.select(keep, x, zero)
    return OpenVINOKerasTensor(out.output(0))
1492
1618
1493
1619
1494
1620
def triu(x, k=0):
    """Return `x` with entries below the `k`-th diagonal set to zero.

    The triangle is taken over the last two axes of `x`; any leading axes
    are treated as batch dimensions.

    Args:
        x: Input tensor, converted with `get_ov_output`. The last two
            entries of its shape are read, so it needs at least two axes.
        k: Diagonal offset. Entry `(i, j)` of each matrix is kept when
            `j >= i + k`.

    Returns:
        An `OpenVINOKerasTensor` with the same element type as `x`.
    """
    x = get_ov_output(x)
    ov_type = x.get_element_type()
    shape = ov_opset.shape_of(x, Type.i32)

    zero_i32 = ov_opset.constant(0, Type.i32)
    one_i32 = ov_opset.constant(1, Type.i32)
    first_axis = ov_opset.constant([0], Type.i32)

    # Sizes of the last two axes, as i32 scalars.
    num_rows = ov_opset.squeeze(
        ov_opset.gather(shape, ov_opset.constant([-2], Type.i32), zero_i32),
        first_axis,
    )
    num_cols = ov_opset.squeeze(
        ov_opset.gather(shape, ov_opset.constant([-1], Type.i32), zero_i32),
        first_axis,
    )

    # Row indices as a column vector and column indices as a row vector;
    # the comparison broadcasts them to a (rows, cols) boolean mask.
    row_idx = ov_opset.unsqueeze(
        ov_opset.range(zero_i32, num_rows, one_i32, output_type=Type.i32),
        ov_opset.constant([1], Type.i32),
    )
    col_idx = ov_opset.unsqueeze(
        ov_opset.range(zero_i32, num_cols, one_i32, output_type=Type.i32),
        first_axis,
    )
    keep = ov_opset.greater_equal(
        col_idx, ov_opset.add(row_idx, ov_opset.constant(k, Type.i32))
    )

    # Select instead of multiplying by a 0/1 mask: `nan * 0` and `inf * 0`
    # are NaN, so multiplication would leak non-finite values from the
    # discarded triangle. Select also broadcasts the 2-D mask over any
    # leading batch axes, and handles boolean inputs without a special case.
    zero = ov_opset.convert(zero_i32, ov_type)
    out = ov_opset.select(keep, x, zero)
    return OpenVINOKerasTensor(out.output(0))
1496
1707
1497
1708
1498
1709
def vdot (x1 , x2 ):
0 commit comments