Skip to content

Commit b4522d2

Browse files
authored
Remove uses of numba_basic.global_numba_func
1 parent 21218d7 commit b4522d2

File tree

2 files changed

+69
-81
lines changed

2 files changed

+69
-81
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,24 +402,22 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
402402
return deepcopyop
403403

404404

405-
@numba_njit
406-
def makeslice(*x):
407-
return slice(*x)
408-
409-
410405
@numba_funcify.register(MakeSlice)
411406
def numba_funcify_MakeSlice(op, **kwargs):
412-
return global_numba_func(makeslice)
413-
407+
@numba_njit
408+
def makeslice(*x):
409+
return slice(*x)
414410

415-
@numba_njit
416-
def shape(x):
417-
return np.asarray(np.shape(x))
411+
return makeslice
418412

419413

420414
@numba_funcify.register(Shape)
421415
def numba_funcify_Shape(op, **kwargs):
422-
return global_numba_func(shape)
416+
@numba_njit
417+
def shape(x):
418+
return np.asarray(np.shape(x))
419+
420+
return shape
423421

424422

425423
@numba_funcify.register(Shape_i)

pytensor/link/numba/dispatch/scalar.py

Lines changed: 60 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
141141
)(scalar_op_fn)
142142

143143

144-
@numba_basic.numba_njit
145-
def switch(condition, x, y):
146-
if condition:
147-
return x
148-
else:
149-
return y
150-
151-
152144
@numba_funcify.register(Switch)
153145
def numba_funcify_Switch(op, node, **kwargs):
154-
return numba_basic.global_numba_func(switch)
146+
@numba_basic.numba_njit
147+
def switch(condition, x, y):
148+
if condition:
149+
return x
150+
else:
151+
return y
152+
153+
return switch
155154

156155

157156
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
@@ -197,34 +196,32 @@ def cast(x):
197196
return cast
198197

199198

200-
@numba_basic.numba_njit
201-
def identity(x):
202-
return x
203-
204-
205199
@numba_funcify.register(Identity)
206200
@numba_funcify.register(TypeCastingOp)
207201
def numba_funcify_type_casting(op, **kwargs):
208-
return numba_basic.global_numba_func(identity)
209-
210-
211-
@numba_basic.numba_njit
212-
def clip(_x, _min, _max):
213-
x = numba_basic.to_scalar(_x)
214-
_min_scalar = numba_basic.to_scalar(_min)
215-
_max_scalar = numba_basic.to_scalar(_max)
216-
217-
if x < _min_scalar:
218-
return _min_scalar
219-
elif x > _max_scalar:
220-
return _max_scalar
221-
else:
202+
@numba_basic.numba_njit
203+
def identity(x):
222204
return x
223205

206+
return identity
207+
224208

225209
@numba_funcify.register(Clip)
226210
def numba_funcify_Clip(op, **kwargs):
227-
return numba_basic.global_numba_func(clip)
211+
@numba_basic.numba_njit
212+
def clip(x, min_val, max_val):
213+
x = numba_basic.to_scalar(x)
214+
min_scalar = numba_basic.to_scalar(min_val)
215+
max_scalar = numba_basic.to_scalar(max_val)
216+
217+
if x < min_scalar:
218+
return min_scalar
219+
elif x > max_scalar:
220+
return max_scalar
221+
else:
222+
return x
223+
224+
return clip
228225

229226

230227
@numba_funcify.register(Composite)
@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
239236
return composite_fn
240237

241238

242-
@numba_basic.numba_njit
243-
def second(x, y):
244-
return y
245-
246-
247239
@numba_funcify.register(Second)
248240
def numba_funcify_Second(op, node, **kwargs):
249-
return numba_basic.global_numba_func(second)
250-
241+
@numba_basic.numba_njit
242+
def second(x, y):
243+
return y
251244

252-
@numba_basic.numba_njit
253-
def reciprocal(x):
254-
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
255-
# `x` is an `int`
256-
return 1 / x
245+
return second
257246

258247

259248
@numba_funcify.register(Reciprocal)
260249
def numba_funcify_Reciprocal(op, node, **kwargs):
261-
return numba_basic.global_numba_func(reciprocal)
262-
250+
@numba_basic.numba_njit
251+
def reciprocal(x):
252+
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
253+
# `x` is an `int`
254+
return 1 / x
263255

264-
@numba_basic.numba_njit
265-
def sigmoid(x):
266-
return 1 / (1 + np.exp(-x))
256+
return reciprocal
267257

268258

269259
@numba_funcify.register(Sigmoid)
270260
def numba_funcify_Sigmoid(op, node, **kwargs):
271-
return numba_basic.global_numba_func(sigmoid)
272-
261+
@numba_basic.numba_njit
262+
def sigmoid(x):
263+
return 1 / (1 + np.exp(-x))
273264

274-
@numba_basic.numba_njit
275-
def gammaln(x):
276-
return math.lgamma(x)
265+
return sigmoid
277266

278267

279268
@numba_funcify.register(GammaLn)
280269
def numba_funcify_GammaLn(op, node, **kwargs):
281-
return numba_basic.global_numba_func(gammaln)
282-
270+
@numba_basic.numba_njit
271+
def gammaln(x):
272+
return math.lgamma(x)
283273

284-
@numba_basic.numba_njit
285-
def logp1mexp(x):
286-
if x < np.log(0.5):
287-
return np.log1p(-np.exp(x))
288-
else:
289-
return np.log(-np.expm1(x))
274+
return gammaln
290275

291276

292277
@numba_funcify.register(Log1mexp)
293278
def numba_funcify_Log1mexp(op, node, **kwargs):
294-
return numba_basic.global_numba_func(logp1mexp)
295-
279+
@numba_basic.numba_njit
280+
def logp1mexp(x):
281+
if x < np.log(0.5):
282+
return np.log1p(-np.exp(x))
283+
else:
284+
return np.log(-np.expm1(x))
296285

297-
@numba_basic.numba_njit
298-
def erf(x):
299-
return math.erf(x)
286+
return logp1mexp
300287

301288

302289
@numba_funcify.register(Erf)
303290
def numba_funcify_Erf(op, **kwargs):
304-
return numba_basic.global_numba_func(erf)
305-
291+
@numba_basic.numba_njit
292+
def erf(x):
293+
return math.erf(x)
306294

307-
@numba_basic.numba_njit
308-
def erfc(x):
309-
return math.erfc(x)
295+
return erf
310296

311297

312298
@numba_funcify.register(Erfc)
313299
def numba_funcify_Erfc(op, **kwargs):
314-
return numba_basic.global_numba_func(erfc)
300+
@numba_basic.numba_njit
301+
def erfc(x):
302+
return math.erfc(x)
303+
304+
return erfc
315305

316306

317307
@numba_funcify.register(Softplus)

0 commit comments

Comments
 (0)