2121import pytensor .scalar .sharedvar
2222from pytensor import compile , config , printing
2323from pytensor import scalar as ps
24+ from pytensor .compile .builders import OpFromGraph
2425from pytensor .gradient import DisconnectedType , grad_undefined
2526from pytensor .graph import RewriteDatabaseQuery
2627from pytensor .graph .basic import Apply , Constant , Variable , equal_computations
@@ -1334,6 +1335,25 @@ def infer_shape(self, fgraph, node, in_shapes):
def grad(self, inp, grads):
    """
    Build the gradient list for the Op's first three inputs.

    Parameters
    ----------
    inp
        Inputs of the node; positions 0, 1 and 2 are read.
    grads
        Output gradients. Not used by this method.

    Returns
    -------
    list
        One ``grad_undefined(self, i, inp[i])`` result for each
        ``i`` in ``0, 1, 2``, in that order.
    """
    undefined = []
    for position in range(3):
        undefined.append(grad_undefined(self, position, inp[position]))
    return undefined
13361337
@staticmethod
def is_offset_zero(node) -> bool:
    """
    Check whether an Eye node was built with a diagonal offset of zero.

    Parameters
    ----------
    node
        Eye node to inspect. The offset is read from its last input.

    Returns
    -------
    is_offset_zero: bool
        True only when the last input is a ``Constant`` whose value
        equals zero (``k = 0``). False otherwise, including when the
        offset is not a ``Constant``.
    """
    k = node.inputs[-1]
    # A non-constant offset has no value to compare at this point.
    if not isinstance(k, Constant):
        return False
    return k.data.item() == 0
1356+
13371357
13381358def eye (n , m = None , k = 0 , dtype = None ):
13391359 """Return a 2-D array with ones on the diagonal and zeros elsewhere.
@@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
37493769 return diagonal (a , offset = offset , axis1 = axis1 , axis2 = axis2 ).sum (- 1 )
37503770
37513771
class AllocDiag(OpFromGraph):
    """
    Wrapper Op for alloc_diag graphs

    ``alloc_diag`` constructs an instance from the inputs and outputs of
    the graph it has built and applies it to the diagonal input. The
    ``axis1``, ``axis2`` and ``offset`` values used to build that graph
    are kept as attributes of the Op.
    """

    # ``offset`` is stored on the instance together with the two axes, so
    # it is listed here as well.
    # NOTE(review): assumes equality/hash of this Op are derived from
    # ``__props__``; confirm against the Op metaclass and ``OpFromGraph``.
    __props__ = ("offset", "axis1", "axis2")

    def __init__(self, *args, axis1, axis2, offset, **kwargs):
        """
        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``OpFromGraph.__init__``, which is
            always called with ``strict=True``.
        axis1
            Stored as ``self.axis1``.
        axis2
            Stored as ``self.axis2``.
        offset
            Stored as ``self.offset``; this is the value checked by
            `is_offset_zero`.
        """
        # The attributes are assigned before the parent initializer runs,
        # so they are already present while it executes.
        self.axis1 = axis1
        self.axis2 = axis2
        self.offset = offset

        super().__init__(*args, **kwargs, strict=True)

    @staticmethod
    def is_offset_zero(node) -> bool:
        """
        Test if an AllocDiag Op has a diagonal offset of zero

        Parameters
        ----------
        node
            AllocDiag node to test

        Returns
        -------
        is_offset_zero: bool
            True if the offset is zero (``k = 0``).
        """

        return node.op.offset == 0
38553803
38563804
38573805def alloc_diag (diag , offset = 0 , axis1 = 0 , axis2 = 1 ):
@@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38623810 from pytensor .tensor import set_subtensor
38633811
38643812 diag = as_tensor_variable (diag )
3813+
38653814 axis1 , axis2 = normalize_axis_tuple ((axis1 , axis2 ), ndim = diag .type .ndim + 1 )
38663815 if axis1 > axis2 :
38673816 axis1 , axis2 = axis2 , axis1
@@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38883837 axes = axes [:axis2 ] + [last_idx + 2 ] + axes [axis2 :]
38893838 result = result .transpose (axes )
38903839
3891- return result
3840+ return AllocDiag (
3841+ inputs = [diag ], outputs = [result ], axis1 = axis1 , axis2 = axis2 , offset = offset
3842+ )(diag )
38923843
38933844
38943845def diag (v , k = 0 ):
0 commit comments