@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
141
141
)(scalar_op_fn )
142
142
143
143
144
- @numba_basic .numba_njit
145
- def switch (condition , x , y ):
146
- if condition :
147
- return x
148
- else :
149
- return y
150
-
151
-
152
144
@numba_funcify .register (Switch )
153
145
def numba_funcify_Switch (op , node , ** kwargs ):
154
- return numba_basic .global_numba_func (switch )
146
+ @numba_basic .numba_njit
147
+ def switch (condition , x , y ):
148
+ if condition :
149
+ return x
150
+ else :
151
+ return y
152
+
153
+ return switch
155
154
156
155
157
156
def binary_to_nary_func (inputs : list [Variable ], binary_op_name : str , binary_op : str ):
@@ -197,34 +196,32 @@ def cast(x):
197
196
return cast
198
197
199
198
200
- @numba_basic .numba_njit
201
- def identity (x ):
202
- return x
203
-
204
-
205
199
@numba_funcify .register (Identity )
206
200
@numba_funcify .register (TypeCastingOp )
207
201
def numba_funcify_type_casting (op , ** kwargs ):
208
- return numba_basic .global_numba_func (identity )
209
-
210
-
211
- @numba_basic .numba_njit
212
- def clip (_x , _min , _max ):
213
- x = numba_basic .to_scalar (_x )
214
- _min_scalar = numba_basic .to_scalar (_min )
215
- _max_scalar = numba_basic .to_scalar (_max )
216
-
217
- if x < _min_scalar :
218
- return _min_scalar
219
- elif x > _max_scalar :
220
- return _max_scalar
221
- else :
202
+ @numba_basic .numba_njit
203
+ def identity (x ):
222
204
return x
223
205
206
+ return identity
207
+
224
208
225
209
@numba_funcify .register (Clip )
226
210
def numba_funcify_Clip (op , ** kwargs ):
227
- return numba_basic .global_numba_func (clip )
211
+ @numba_basic .numba_njit
212
+ def clip (x , min_val , max_val ):
213
+ x = numba_basic .to_scalar (x )
214
+ min_scalar = numba_basic .to_scalar (min_val )
215
+ max_scalar = numba_basic .to_scalar (max_val )
216
+
217
+ if x < min_scalar :
218
+ return min_scalar
219
+ elif x > max_scalar :
220
+ return max_scalar
221
+ else :
222
+ return x
223
+
224
+ return clip
228
225
229
226
230
227
@numba_funcify .register (Composite )
@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
239
236
return composite_fn
240
237
241
238
242
- @numba_basic .numba_njit
243
- def second (x , y ):
244
- return y
245
-
246
-
247
239
@numba_funcify .register (Second )
248
240
def numba_funcify_Second (op , node , ** kwargs ):
249
- return numba_basic .global_numba_func (second )
250
-
241
+ @numba_basic .numba_njit
242
+ def second (x , y ):
243
+ return y
251
244
252
- @numba_basic .numba_njit
253
- def reciprocal (x ):
254
- # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
255
- # `x` is an `int`
256
- return 1 / x
245
+ return second
257
246
258
247
259
248
@numba_funcify .register (Reciprocal )
260
249
def numba_funcify_Reciprocal (op , node , ** kwargs ):
261
- return numba_basic .global_numba_func (reciprocal )
262
-
250
+ @numba_basic .numba_njit
251
+ def reciprocal (x ):
252
+ # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
253
+ # `x` is an `int`
254
+ return 1 / x
263
255
264
- @numba_basic .numba_njit
265
- def sigmoid (x ):
266
- return 1 / (1 + np .exp (- x ))
256
+ return reciprocal
267
257
268
258
269
259
@numba_funcify .register (Sigmoid )
270
260
def numba_funcify_Sigmoid (op , node , ** kwargs ):
271
- return numba_basic .global_numba_func (sigmoid )
272
-
261
+ @numba_basic .numba_njit
262
+ def sigmoid (x ):
263
+ return 1 / (1 + np .exp (- x ))
273
264
274
- @numba_basic .numba_njit
275
- def gammaln (x ):
276
- return math .lgamma (x )
265
+ return sigmoid
277
266
278
267
279
268
@numba_funcify .register (GammaLn )
280
269
def numba_funcify_GammaLn (op , node , ** kwargs ):
281
- return numba_basic .global_numba_func (gammaln )
282
-
270
+ @numba_basic .numba_njit
271
+ def gammaln (x ):
272
+ return math .lgamma (x )
283
273
284
- @numba_basic .numba_njit
285
- def logp1mexp (x ):
286
- if x < np .log (0.5 ):
287
- return np .log1p (- np .exp (x ))
288
- else :
289
- return np .log (- np .expm1 (x ))
274
+ return gammaln
290
275
291
276
292
277
@numba_funcify .register (Log1mexp )
293
278
def numba_funcify_Log1mexp (op , node , ** kwargs ):
294
- return numba_basic .global_numba_func (logp1mexp )
295
-
279
+ @numba_basic .numba_njit
280
+ def logp1mexp (x ):
281
+ if x < np .log (0.5 ):
282
+ return np .log1p (- np .exp (x ))
283
+ else :
284
+ return np .log (- np .expm1 (x ))
296
285
297
- @numba_basic .numba_njit
298
- def erf (x ):
299
- return math .erf (x )
286
+ return logp1mexp
300
287
301
288
302
289
@numba_funcify .register (Erf )
303
290
def numba_funcify_Erf (op , ** kwargs ):
304
- return numba_basic .global_numba_func (erf )
305
-
291
+ @numba_basic .numba_njit
292
+ def erf (x ):
293
+ return math .erf (x )
306
294
307
- @numba_basic .numba_njit
308
- def erfc (x ):
309
- return math .erfc (x )
295
+ return erf
310
296
311
297
312
298
@numba_funcify .register (Erfc )
313
299
def numba_funcify_Erfc (op , ** kwargs ):
314
- return numba_basic .global_numba_func (erfc )
300
+ @numba_basic .numba_njit
301
+ def erfc (x ):
302
+ return math .erfc (x )
303
+
304
+ return erfc
315
305
316
306
317
307
@numba_funcify .register (Softplus )
0 commit comments