@@ -378,53 +378,42 @@ def L_op(self, inputs, outputs, grads):
378378
379379 def c_support_code (self , ** kwargs ):
380380 return """
381- // For GPU support
382- #ifdef WITHIN_KERNEL
383- #define DEVICE WITHIN_KERNEL
384- #else
385- #define DEVICE
386- #endif
387-
388- #ifndef ga_double
389- #define ga_double double
390- #endif
391-
392381 #ifndef _PSIFUNCDEFINED
393382 #define _PSIFUNCDEFINED
394- DEVICE double _psi(ga_double x) {
383+ double _psi(double x) {
395384
396- /*taken from
397- Bernardo, J. M. (1976). Algorithm AS 103:
398- Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
399- http://www.uv.es/~bernardo/1976AppStatist.pdf */
385+ /*taken from
386+ Bernardo, J. M. (1976). Algorithm AS 103:
387+ Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
388+ http://www.uv.es/~bernardo/1976AppStatist.pdf */
400389
401- ga_double y, R, psi_ = 0;
402- ga_double S = 1.0e-5;
403- ga_double C = 8.5;
404- ga_double S3 = 8.333333333e-2;
405- ga_double S4 = 8.333333333e-3;
406- ga_double S5 = 3.968253968e-3;
407- ga_double D1 = -0.5772156649;
390+ double y, R, psi_ = 0;
391+ double S = 1.0e-5;
392+ double C = 8.5;
393+ double S3 = 8.333333333e-2;
394+ double S4 = 8.333333333e-3;
395+ double S5 = 3.968253968e-3;
396+ double D1 = -0.5772156649;
408397
409- y = x;
398+ y = x;
410399
411- if (y <= 0.0)
412- return psi_;
400+ if (y <= 0.0)
401+ return psi_;
413402
414- if (y <= S)
415- return D1 - 1.0/y;
403+ if (y <= S)
404+ return D1 - 1.0/y;
416405
417- while (y < C) {
418- psi_ = psi_ - 1.0 / y;
419- y = y + 1;
420- }
406+ while (y < C) {
407+ psi_ = psi_ - 1.0 / y;
408+ y = y + 1;
409+ }
421410
422- R = 1.0 / y;
423- psi_ = psi_ + log(y) - .5 * R ;
424- R= R*R;
425- psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
411+ R = 1.0 / y;
412+ psi_ = psi_ + log(y) - .5 * R ;
413+ R= R*R;
414+ psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
426415
427- return psi_;
416+ return psi_;
428417 }
429418 #endif
430419 """
@@ -433,10 +422,13 @@ def c_code(self, node, name, inp, out, sub):
433422 (x ,) = inp
434423 (z ,) = out
435424 if node .inputs [0 ].type in float_types :
436- return f""" { z } =
437- _psi({ x } );"" "
425+ dtype = "npy_" + node . outputs [ 0 ]. dtype
426+ return f" { z } = ( { dtype } ) _psi({ x } );"
438427 raise NotImplementedError ("only floating point is implemented" )
439428
429+ def c_code_cache_version (self ):
430+ return (1 ,)
431+
440432
441433psi = Psi (upgrade_to_float , name = "psi" )
442434
0 commit comments