@@ -1523,14 +1523,6 @@ def ensure_constant(value, default_type=Type.i32):
1523
1523
row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1524
1524
col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1525
1525
1526
- # Target shape for broadcasting
1527
- target_shape = ov_opset .concat (
1528
- [ov_opset .unsqueeze (N_const , [0 ]), ov_opset .unsqueeze (M_const , [0 ])],
1529
- axis = 0 ,
1530
- )
1531
-
1532
- row_idx = ov_opset .broadcast (row_idx , target_shape )
1533
- col_idx = ov_opset .broadcast (col_idx , target_shape )
1534
1526
mask = ov_opset .less_equal (col_idx , ov_opset .add (row_idx , k_const ))
1535
1527
1536
1528
if ov_dtype == Type .boolean :
@@ -1558,7 +1550,6 @@ def get_shape_dims(x):
1558
1550
1559
1551
x = get_ov_output (x )
1560
1552
ov_type = x .get_element_type ()
1561
- input_shape = ov_opset .shape_of (x , Type .i32 )
1562
1553
shape = get_shape_dims (x )
1563
1554
zero_const = ov_opset .constant (0 , Type .i32 )
1564
1555
minus2 = ov_opset .constant ([- 2 ], Type .i32 )
@@ -1571,31 +1562,6 @@ def get_shape_dims(x):
1571
1562
1572
1563
mask = ov_opset .convert (tri_mask , ov_type )
1573
1564
1574
- # Broadcast mask to input shape (including batch dims)
1575
- shape_rank = ov_opset .squeeze (
1576
- ov_opset .shape_of (input_shape , Type .i32 ), zero_const
1577
- )
1578
- batch_dims = ov_opset .subtract (shape_rank , ov_opset .constant (2 , Type .i32 ))
1579
- batch_indices = ov_opset .range (
1580
- zero_const ,
1581
- batch_dims ,
1582
- ov_opset .constant (1 , Type .i32 ),
1583
- output_type = Type .i32 ,
1584
- )
1585
- batch_shape = ov_opset .gather (input_shape , batch_indices , zero_const )
1586
-
1587
- M_reshaped = ov_opset .unsqueeze (M , zero_const )
1588
- N_reshaped = ov_opset .unsqueeze (N , zero_const )
1589
-
1590
- concat_inputs = [
1591
- batch_shape .output (0 ),
1592
- M_reshaped .output (0 ),
1593
- N_reshaped .output (0 ),
1594
- ]
1595
-
1596
- full_mask_shape = ov_opset .concat (concat_inputs , axis = 0 )
1597
- mask = ov_opset .broadcast (mask , full_mask_shape )
1598
-
1599
1565
if ov_type == Type .boolean :
1600
1566
out = ov_opset .logical_and (x , mask )
1601
1567
else :
@@ -1621,7 +1587,6 @@ def get_shape_dims(x):
1621
1587
1622
1588
x = get_ov_output (x )
1623
1589
ov_type = x .get_element_type ()
1624
- input_shape = ov_opset .shape_of (x , Type .i32 )
1625
1590
shape = get_shape_dims (x )
1626
1591
zero_const = ov_opset .constant (0 , Type .i32 )
1627
1592
minus2 = ov_opset .constant ([- 2 ], Type .i32 )
@@ -1631,7 +1596,6 @@ def get_shape_dims(x):
1631
1596
1632
1597
tri_mask = tri (M , N , k = k - 1 , dtype = "bool" ).output
1633
1598
1634
- # Handle boolean type differently since subtract doesn't work with boolean
1635
1599
if ov_type == Type .boolean :
1636
1600
mask = ov_opset .logical_not (tri_mask )
1637
1601
else :
@@ -1640,32 +1604,6 @@ def get_shape_dims(x):
1640
1604
)
1641
1605
mask = ov_opset .subtract (ones , ov_opset .convert (tri_mask , ov_type ))
1642
1606
1643
- # Broadcast mask
1644
- shape_rank = ov_opset .squeeze (
1645
- ov_opset .shape_of (input_shape , Type .i32 ), zero_const
1646
- )
1647
- batch_dims = ov_opset .subtract (shape_rank , ov_opset .constant (2 , Type .i32 ))
1648
- batch_indices = ov_opset .range (
1649
- zero_const ,
1650
- batch_dims ,
1651
- ov_opset .constant (1 , Type .i32 ),
1652
- output_type = Type .i32 ,
1653
- )
1654
- batch_shape = ov_opset .gather (input_shape , batch_indices , zero_const )
1655
-
1656
- # Ensure all tensors are properly shaped before concat
1657
- M_reshaped = ov_opset .unsqueeze (M , zero_const )
1658
- N_reshaped = ov_opset .unsqueeze (N , zero_const )
1659
-
1660
- concat_inputs = [
1661
- batch_shape .output (0 ),
1662
- M_reshaped .output (0 ),
1663
- N_reshaped .output (0 ),
1664
- ]
1665
-
1666
- full_mask_shape = ov_opset .concat (concat_inputs , axis = 0 )
1667
- mask = ov_opset .broadcast (mask , full_mask_shape )
1668
-
1669
1607
if ov_type == Type .boolean :
1670
1608
out = ov_opset .logical_and (x , mask )
1671
1609
else :
0 commit comments