@@ -1492,7 +1492,90 @@ def tril(x, k=0):
1492
1492
1493
1493
1494
1494
def triu (x , k = 0 ):
1495
- raise NotImplementedError ("`triu` is not supported with openvino backend" )
1495
+ def get_shape_dims (x ):
1496
+ shape = ov_opset .shape_of (x , Type .i32 )
1497
+ rank_tensor = ov_opset .shape_of (shape , Type .i32 )
1498
+ rank_scalar = ov_opset .squeeze (
1499
+ rank_tensor , ov_opset .constant ([0 ], Type .i32 )
1500
+ )
1501
+ indices = ov_opset .range (
1502
+ ov_opset .constant (0 , Type .i32 ),
1503
+ rank_scalar ,
1504
+ ov_opset .constant (1 , Type .i32 ),
1505
+ output_type = Type .i32 ,
1506
+ )
1507
+ return ov_opset .gather (shape , indices , axis = 0 )
1508
+
1509
+ x = get_ov_output (x )
1510
+ ov_type = x .get_element_type ()
1511
+ input_shape = ov_opset .shape_of (x , Type .i32 )
1512
+ shape = get_shape_dims (x )
1513
+
1514
+ zero_const = ov_opset .constant (0 , Type .i32 )
1515
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1516
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1517
+
1518
+ M = ov_opset .squeeze (
1519
+ ov_opset .gather (shape , minus2 , zero_const ),
1520
+ ov_opset .constant ([0 ], Type .i32 ),
1521
+ )
1522
+ N = ov_opset .squeeze (
1523
+ ov_opset .gather (shape , minus1 , zero_const ),
1524
+ ov_opset .constant ([0 ], Type .i32 ),
1525
+ )
1526
+
1527
+ row_range = ov_opset .range (
1528
+ ov_opset .constant (0 , Type .i32 ),
1529
+ M ,
1530
+ ov_opset .constant (1 , Type .i32 ),
1531
+ output_type = Type .i32 ,
1532
+ )
1533
+ col_range = ov_opset .range (
1534
+ ov_opset .constant (0 , Type .i32 ),
1535
+ N ,
1536
+ ov_opset .constant (1 , Type .i32 ),
1537
+ output_type = Type .i32 ,
1538
+ )
1539
+
1540
+ row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1541
+ col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1542
+
1543
+ M_1d = ov_opset .unsqueeze (M , ov_opset .constant ([0 ], Type .i32 ))
1544
+ N_1d = ov_opset .unsqueeze (N , ov_opset .constant ([0 ], Type .i32 ))
1545
+ target_shape = ov_opset .concat ([M_1d , N_1d ], axis = 0 )
1546
+
1547
+ row_idx = ov_opset .broadcast (row_idx , target_shape )
1548
+ col_idx = ov_opset .broadcast (col_idx , target_shape )
1549
+
1550
+ k_const = ov_opset .constant (k , Type .i32 )
1551
+ mask = ov_opset .greater_equal (col_idx , ov_opset .add (row_idx , k_const ))
1552
+ mask = ov_opset .convert (mask , ov_type )
1553
+
1554
+ shape_rank_tensor = ov_opset .shape_of (input_shape , Type .i32 )
1555
+ shape_rank = ov_opset .squeeze (
1556
+ shape_rank_tensor , ov_opset .constant ([0 ], Type .i32 )
1557
+ )
1558
+ batch_dims_count = ov_opset .subtract (
1559
+ shape_rank , ov_opset .constant (2 , Type .i32 )
1560
+ )
1561
+ batch_dims_count = ov_opset .squeeze (
1562
+ batch_dims_count , ov_opset .constant ([0 ], Type .i32 )
1563
+ )
1564
+
1565
+ batch_indices = ov_opset .range (
1566
+ start = ov_opset .constant (0 , Type .i32 ),
1567
+ stop = batch_dims_count ,
1568
+ step = ov_opset .constant (1 , Type .i32 ),
1569
+ output_type = Type .i32 ,
1570
+ )
1571
+
1572
+ batch_shape = ov_opset .gather (input_shape , batch_indices , axis = 0 )
1573
+ full_mask_shape = ov_opset .concat ([batch_shape , M_1d , N_1d ], axis = 0 )
1574
+ mask = ov_opset .broadcast (mask , full_mask_shape )
1575
+ if ov_type == Type .boolean :
1576
+ mask = ov_opset .convert (mask , Type .f32 )
1577
+ x = ov_opset .convert (x , Type .f32 )
1578
+ return OpenVINOKerasTensor (ov_opset .multiply (x , mask ).output (0 ))
1496
1579
1497
1580
1498
1581
def vdot (x1 , x2 ):
0 commit comments