@@ -1654,15 +1654,94 @@ def trace(x, offset=0, axis1=0, axis2=1):
1654
1654
1655
1655
1656
1656
def tri (N , M = None , k = 0 , dtype = None ):
1657
- raise NotImplementedError ("`tri` is not supported with openvino backend" )
1657
+ if M is None :
1658
+ M = N
1659
+ if dtype is None :
1660
+ dtype = "float32"
1661
+
1662
+ ov_dtype = OPENVINO_DTYPES [dtype ]
1663
+
1664
+ def ensure_constant (value , default_type = Type .i32 ):
1665
+ if isinstance (value , (int , float )):
1666
+ return ov_opset .constant (value , default_type )
1667
+ elif hasattr (value , "get_element_type" ):
1668
+ if value .get_element_type () != Type .i32 :
1669
+ value = ov_opset .convert (value , Type .i32 )
1670
+ return ov_opset .squeeze (value , ov_opset .constant ([0 ], Type .i32 ))
1671
+ else :
1672
+ return ov_opset .constant (value , default_type )
1673
+
1674
+ N_const = ensure_constant (N )
1675
+ M_const = ensure_constant (M )
1676
+ k_const = ensure_constant (k )
1677
+
1678
+ # Create row and column indices
1679
+ row_range = ov_opset .range (
1680
+ ov_opset .constant (0 , Type .i32 ),
1681
+ N_const ,
1682
+ ov_opset .constant (1 , Type .i32 ),
1683
+ output_type = Type .i32 ,
1684
+ )
1685
+ col_range = ov_opset .range (
1686
+ ov_opset .constant (0 , Type .i32 ),
1687
+ M_const ,
1688
+ ov_opset .constant (1 , Type .i32 ),
1689
+ output_type = Type .i32 ,
1690
+ )
1691
+
1692
+ # Reshape indices for broadcasting
1693
+ row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1694
+ col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1695
+
1696
+ mask = ov_opset .less_equal (col_idx , ov_opset .add (row_idx , k_const ))
1697
+
1698
+ if ov_dtype == Type .boolean :
1699
+ result = mask
1700
+ else :
1701
+ result = ov_opset .convert (mask , ov_dtype )
1702
+
1703
+ return OpenVINOKerasTensor (result .output (0 ))
1658
1704
1659
1705
1660
1706
def tril (x , k = 0 ):
1661
- raise NotImplementedError ("`tril` is not supported with openvino backend" )
1707
+ x = get_ov_output (x )
1708
+ ov_type = x .get_element_type ()
1709
+ shape = ov_opset .shape_of (x , Type .i32 )
1710
+ zero_const = ov_opset .constant (0 , Type .i32 )
1711
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1712
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1713
+ M = ov_opset .squeeze (ov_opset .gather (shape , minus2 , zero_const ), zero_const )
1714
+ N = ov_opset .squeeze (ov_opset .gather (shape , minus1 , zero_const ), zero_const )
1715
+ tri_mask = tri (M , N , k = k , dtype = "bool" ).output
1716
+ mask = ov_opset .convert (tri_mask , ov_type )
1717
+ if ov_type == Type .boolean :
1718
+ out = ov_opset .logical_and (x , mask )
1719
+ else :
1720
+ out = ov_opset .multiply (x , mask )
1721
+ return OpenVINOKerasTensor (out .output (0 ))
1662
1722
1663
1723
1664
1724
def triu (x , k = 0 ):
1665
- raise NotImplementedError ("`triu` is not supported with openvino backend" )
1725
+ x = get_ov_output (x )
1726
+ ov_type = x .get_element_type ()
1727
+ shape = ov_opset .shape_of (x , Type .i32 )
1728
+ zero_const = ov_opset .constant (0 , Type .i32 )
1729
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1730
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1731
+ M = ov_opset .squeeze (ov_opset .gather (shape , minus2 , zero_const ), zero_const )
1732
+ N = ov_opset .squeeze (ov_opset .gather (shape , minus1 , zero_const ), zero_const )
1733
+ tri_mask = tri (M , N , k = k - 1 , dtype = "bool" ).output
1734
+ if ov_type == Type .boolean :
1735
+ mask = ov_opset .logical_not (tri_mask )
1736
+ else :
1737
+ const_one = ov_opset .constant (1 , ov_type )
1738
+ converted_mask = ov_opset .convert (tri_mask , ov_type )
1739
+ mask = ov_opset .subtract (const_one , converted_mask )
1740
+ if ov_type == Type .boolean :
1741
+ out = ov_opset .logical_and (x , mask )
1742
+ else :
1743
+ out = ov_opset .multiply (x , mask )
1744
+ return OpenVINOKerasTensor (out .output (0 ))
1666
1745
1667
1746
1668
1747
def vdot (x1 , x2 ):
0 commit comments