@@ -89,26 +89,48 @@ def log_jac_det(self, value, *inputs):
8989
9090
9191class Ordered (Transform ):
92+ """
93+ Transforms a vector of values into a vector of ordered values.
94+
95+ Parameters
96+ ----------
97+ positive: If True, all values are positive. This has better geometry than just chaining with a log transform.
98+ ascending: If True, the values are in ascending order (default). If False, the values are in descending order.
99+ """
100+
92101 name = "ordered"
93102
94- def __init__ (self , ndim_supp = None ):
103+ def __init__ (self , ndim_supp = None , positive = False , ascending = True ):
95104 if ndim_supp is not None :
96105 warnings .warn ("ndim_supp argument is deprecated and has no effect" , FutureWarning )
106+ self .positive = positive
107+ self .ascending = ascending
97108
98109 def backward (self , value , * inputs ):
99- x = pt .zeros (value .shape )
100- x = pt .set_subtensor (x [..., 0 ], value [..., 0 ])
101- x = pt .set_subtensor (x [..., 1 :], pt .exp (value [..., 1 :]))
102- return pt .cumsum (x , axis = - 1 )
110+ if self .positive : # Transform both initial value and deltas to be positive
111+ x = pt .exp (value )
112+ else : # Transform only deltas to be positive
113+ x = pt .empty (value .shape )
114+ x = pt .set_subtensor (x [..., 0 ], value [..., 0 ])
115+ x = pt .set_subtensor (x [..., 1 :], pt .exp (value [..., 1 :]))
116+ x = pt .cumsum (x , axis = - 1 ) # Add deltas cumulatively to initial value
117+ if not self .ascending :
118+ x = x [..., ::- 1 ]
119+ return x
103120
104121 def forward (self , value , * inputs ):
105- y = pt .zeros (value .shape )
106- y = pt .set_subtensor (y [..., 0 ], value [..., 0 ])
122+ if not self .ascending :
123+ value = value [..., ::- 1 ]
124+ y = pt .empty (value .shape )
125+ y = pt .set_subtensor (y [..., 0 ], pt .log (value [..., 0 ]) if self .positive else value [..., 0 ])
107126 y = pt .set_subtensor (y [..., 1 :], pt .log (value [..., 1 :] - value [..., :- 1 ]))
108127 return y
109128
110129 def log_jac_det (self , value , * inputs ):
111- return pt .sum (value [..., 1 :], axis = - 1 )
130+ if self .positive :
131+ return pt .sum (value , axis = - 1 )
132+ else :
133+ return pt .sum (value [..., 1 :], axis = - 1 )
112134
113135
114136class SumTo1 (Transform ):
0 commit comments