@@ -1494,15 +1494,94 @@ def trace(x, offset=0, axis1=0, axis2=1):
1494
1494
1495
1495
1496
1496
def tri (N , M = None , k = 0 , dtype = None ):
1497
- raise NotImplementedError ("`tri` is not supported with openvino backend" )
1497
+ if M is None :
1498
+ M = N
1499
+ if dtype is None :
1500
+ dtype = "float32"
1501
+
1502
+ ov_dtype = OPENVINO_DTYPES [dtype ]
1503
+
1504
+ def ensure_constant (value , default_type = Type .i32 ):
1505
+ if isinstance (value , (int , float )):
1506
+ return ov_opset .constant (value , default_type )
1507
+ elif hasattr (value , "get_element_type" ):
1508
+ if value .get_element_type () != Type .i32 :
1509
+ value = ov_opset .convert (value , Type .i32 )
1510
+ return ov_opset .squeeze (value , ov_opset .constant ([0 ], Type .i32 ))
1511
+ else :
1512
+ return ov_opset .constant (value , default_type )
1513
+
1514
+ N_const = ensure_constant (N )
1515
+ M_const = ensure_constant (M )
1516
+ k_const = ensure_constant (k )
1517
+
1518
+ # Create row and column indices
1519
+ row_range = ov_opset .range (
1520
+ ov_opset .constant (0 , Type .i32 ),
1521
+ N_const ,
1522
+ ov_opset .constant (1 , Type .i32 ),
1523
+ output_type = Type .i32 ,
1524
+ )
1525
+ col_range = ov_opset .range (
1526
+ ov_opset .constant (0 , Type .i32 ),
1527
+ M_const ,
1528
+ ov_opset .constant (1 , Type .i32 ),
1529
+ output_type = Type .i32 ,
1530
+ )
1531
+
1532
+ # Reshape indices for broadcasting
1533
+ row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1534
+ col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1535
+
1536
+ mask = ov_opset .less_equal (col_idx , ov_opset .add (row_idx , k_const ))
1537
+
1538
+ if ov_dtype == Type .boolean :
1539
+ result = mask
1540
+ else :
1541
+ result = ov_opset .convert (mask , ov_dtype )
1542
+
1543
+ return OpenVINOKerasTensor (result .output (0 ))
1498
1544
1499
1545
1500
1546
def tril (x , k = 0 ):
1501
- raise NotImplementedError ("`tril` is not supported with openvino backend" )
1547
+ x = get_ov_output (x )
1548
+ ov_type = x .get_element_type ()
1549
+ shape = ov_opset .shape_of (x , Type .i32 )
1550
+ zero_const = ov_opset .constant (0 , Type .i32 )
1551
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1552
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1553
+ M = ov_opset .squeeze (ov_opset .gather (shape , minus2 , zero_const ), zero_const )
1554
+ N = ov_opset .squeeze (ov_opset .gather (shape , minus1 , zero_const ), zero_const )
1555
+ tri_mask = tri (M , N , k = k , dtype = "bool" ).output
1556
+ mask = ov_opset .convert (tri_mask , ov_type )
1557
+ if ov_type == Type .boolean :
1558
+ out = ov_opset .logical_and (x , mask )
1559
+ else :
1560
+ out = ov_opset .multiply (x , mask )
1561
+ return OpenVINOKerasTensor (out .output (0 ))
1502
1562
1503
1563
1504
1564
def triu (x , k = 0 ):
1505
- raise NotImplementedError ("`triu` is not supported with openvino backend" )
1565
+ x = get_ov_output (x )
1566
+ ov_type = x .get_element_type ()
1567
+ shape = ov_opset .shape_of (x , Type .i32 )
1568
+ zero_const = ov_opset .constant (0 , Type .i32 )
1569
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1570
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1571
+ M = ov_opset .squeeze (ov_opset .gather (shape , minus2 , zero_const ), zero_const )
1572
+ N = ov_opset .squeeze (ov_opset .gather (shape , minus1 , zero_const ), zero_const )
1573
+ tri_mask = tri (M , N , k = k - 1 , dtype = "bool" ).output
1574
+ if ov_type == Type .boolean :
1575
+ mask = ov_opset .logical_not (tri_mask )
1576
+ else :
1577
+ const_one = ov_opset .constant (1 , ov_type )
1578
+ converted_mask = ov_opset .convert (tri_mask , ov_type )
1579
+ mask = ov_opset .subtract (const_one , converted_mask )
1580
+ if ov_type == Type .boolean :
1581
+ out = ov_opset .logical_and (x , mask )
1582
+ else :
1583
+ out = ov_opset .multiply (x , mask )
1584
+ return OpenVINOKerasTensor (out .output (0 ))
1506
1585
1507
1586
1508
1587
def vdot (x1 , x2 ):
0 commit comments