Skip to content

Commit d6cada9

Browse files
committed
Adding options to switch kernel at runtime
1 parent b4b8a10 commit d6cada9

File tree

14 files changed

+337
-34
lines changed

14 files changed

+337
-34
lines changed

include/finufft.fh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ c DEPRECATED: spread_kerpad kept for ABI compatibility, ignored by library
2020
real*8 upsampfac
2121
integer spread_thread, maxbatchsize, spread_nthr_atomic
2222
integer spread_max_sp_size
23+
integer spread_kernel
2324
integer fftw_lock_fun, fftw_unlock_fun, fftw_lock_data
2425

2526
end type

include/finufft/spreadinterp.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ namespace finufft {
2121
namespace spreadinterp {
2222

2323
template<typename T>
24-
FINUFFT_EXPORT_TEST int setup_spreader(finufft_spread_opts &opts, T eps, double upsampfac,
25-
int kerevalmeth, int debug, int showwarn,
26-
int spreadinterponly, int dim);
24+
FINUFFT_EXPORT_TEST int setup_spreader(
25+
finufft_spread_opts &opts, T eps, double upsampfac, int kerevalmeth, int debug,
26+
int showwarn, int spreadinterponly, int dim, int kernel_type = 0);
2727

2828
int spreadcheck(UBIGINT N1, UBIGINT N2, UBIGINT N3, const finufft_spread_opts &opts);
2929
template<typename T>

include/finufft_common/kernel.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <vector>
77

88
#include <finufft_common/constants.h>
9+
#include <finufft_common/utils.h>
910

1011
namespace finufft::kernel {
1112

@@ -56,14 +57,25 @@ template<class T, class F> std::vector<T> fit_monomials(F &&f, int n, T a, T b)
5657
return c;
5758
}
5859

59-
template<typename T> T evaluate_kernel(T x, T beta, T c) {
60-
/* ES ("exp sqrt") kernel evaluation at single real argument:
61-
phi(x) = exp(beta.(sqrt(1 - (2x/n_s)^2) - 1)), for |x| < nspread/2
62-
related to an asymptotic approximation to the Kaiser--Bessel, itself an
63-
approximation to prolate spheroidal wavefunction (PSWF) of order 0.
64-
This is the "reference implementation", used by eg finufft/onedim_* 2/17/17.
65-
Rescaled so max is 1, Barnett 7/21/24
60+
template<typename T> T evaluate_kernel(T x, T beta, T c, int kernel_type = 0) {
61+
/* Kernel evaluation at single real argument.
62+
kernel_type == 0 : ES ("exp sqrt") kernel (default)
63+
phi_ES(x) = exp(beta*(sqrt(1 - c*x^2) - 1))
64+
kernel_type == 1 : Kaiser--Bessel (KB) kernel
65+
phi_KB(x) = I_0(beta*sqrt(1 - c*x^2)) / I_0(beta)
66+
Note: `std::cyl_bessel_i` from <cmath> is used for I_0.
67+
Rescaled so max is 1.
6668
*/
69+
if (kernel_type == 1) {
70+
// Kaiser--Bessel (normalized by I0(beta)). Use std::cyl_bessel_i from <cmath>.
71+
const T inner = std::sqrt(T(1) - c * x * x);
72+
const T arg = beta * inner;
73+
const double i0_arg = ::finufft::common::cyl_bessel_i(0, static_cast<double>(arg));
74+
const double i0_beta = ::finufft::common::cyl_bessel_i(0, static_cast<double>(beta));
75+
return static_cast<T>(i0_arg / i0_beta);
76+
}
77+
78+
// default to ES
6779
return std::exp(beta * (std::sqrt(T(1) - c * x * x) - T(1)));
6880
}
6981

include/finufft_common/utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ namespace common {
1313
FINUFFT_EXPORT_TEST void gaussquad(int n, double *xgl, double *wgl);
1414
std::tuple<double, double> leg_eval(int n, double x);
1515

16+
// Series implementation of the modified Bessel function of the first kind I_nu(x)
17+
double cyl_bessel_i(double nu, double x) noexcept;
18+
1619
// helper to generate the integer sequence in range [Start, End]
1720
template<int Offset, typename Seq> struct offset_seq;
1821

include/finufft_mod.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module finufft_mod
2121
real(kind=C_DOUBLE) :: upsampfac
2222
integer(kind=C_INT) :: spread_thread, maxbatchsize
2323
integer(kind=C_INT) :: spread_nthr_atomic, spread_max_sp_size
24+
integer(kind=C_INT) :: spread_kernel
2425
integer(kind=C_SIZE_T) :: fftw_lock_fun, fftw_unlock_fun, fftw_lock_data
2526
! really, last should be type(C_PTR) :: etc, but fails to print nicely
2627

include/finufft_opts.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ typedef struct finufft_opts { // defaults see finufft_core.cpp:finufft_default_o
3232
int spread_nthr_atomic; // if >=0, threads above which spreader OMP critical goes
3333
// atomic
3434
int spread_max_sp_size; // if >0, overrides spreader (dir=1) max subproblem size
35+
int spread_kernel; // (dev only) 0:DEFAULT, (do not change), there is no guarantee
36+
// what non-zero values do and behaviour can change anytime
3537
// sphinx tag (don't remove): @opts_end
3638

3739
// User can provide their own FFTW planner lock functions for thread safety

include/finufft_spread_opts.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ typedef struct finufft_spread_opts {
2626
double ES_beta;
2727
double ES_halfwidth;
2828
double ES_c;
29+
// Kernel selector: 0 = ES (default), 1 = Kaiser--Bessel (KB)
30+
int kernel_type; // default 0
2931
} finufft_spread_opts;
3032

3133
#endif // FINUFFT_SPREAD_OPTS_H

matlab/finufft.mw

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ $ }
103103
$ else if (strcmp(fname[ifield],"spread_max_sp_size") == 0) {
104104
$ oc->spread_max_sp_size = (int)round(*mxGetPr(mxGetFieldByNumber(om,idx,ifield)));
105105
$ }
106+
$ else if (strcmp(fname[ifield],"spread_kernel") == 0) {
107+
$ oc->spread_kernel = (int)round(*mxGetPr(mxGetFieldByNumber(om,idx,ifield)));
108+
$ }
106109
$ else if (strcmp(fname[ifield],"spreadinterponly") == 0) {
107110
$ oc->spreadinterponly = (int)round(*mxGetPr(mxGetFieldByNumber(om,idx,ifield)));
108111
$ }

python/finufft/finufft/_finufft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class FinufftOpts(ctypes.Structure):
8686
('maxbatchsize', c_int),
8787
('spread_nthr_atomic', c_int),
8888
('spread_max_sp_size', c_int),
89+
('spread_kernel', c_int),
8990
('fftw_lock_fun', c_void_p),
9091
('fftw_unlock_fun', c_void_p),
9192
('fftw_lock_data', c_void_p)]

src/finufft_core.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ static int setup_spreader_for_nufft(finufft_spread_opts &spopts, T eps,
108108
// this calls spreadinterp.cpp...
109109
int ier = setup_spreader(spopts, eps, opts.upsampfac, opts.spread_kerevalmeth,
110110
opts.spread_debug, opts.showwarn, opts.spreadinterponly,
111-
dim);
111+
dim, opts.spread_kernel);
112112
// override various spread opts from their defaults...
113113
spopts.debug = opts.spread_debug;
114114
spopts.sort = opts.spread_sort; // could make dim or CPU choices here?
@@ -532,6 +532,7 @@ template<typename TF> void FINUFFT_PLAN_T<TF>::precompute_horner_coeffs() {
532532
// Precompute kernel parameters once
533533
const TF beta = TF(this->spopts.ES_beta);
534534
const TF c_param = TF(this->spopts.ES_c);
535+
const int kernel_type = this->spopts.kernel_type;
535536

536537
nc = MIN_NC;
537538

@@ -549,9 +550,9 @@ template<typename TF> void FINUFFT_PLAN_T<TF>::precompute_horner_coeffs() {
549550
// original: 0.5 * (x - nspread + 2*j + 1)
550551
const TF shift = TF(2 * j + 1 - nspread);
551552

552-
const auto kernel = [shift, beta, c_param](TF x) -> TF {
553+
const auto kernel = [shift, beta, c_param, kernel_type](TF x) -> TF {
553554
const TF t = TF(0.5) * (x + shift);
554-
return evaluate_kernel(t, beta, c_param);
555+
return evaluate_kernel(t, beta, c_param, kernel_type);
555556
};
556557

557558
const auto coeffs = fit_monomials(kernel, static_cast<int>(max_degree), a, b);
@@ -715,6 +716,7 @@ void finufft_default_opts_t(finufft_opts *o)
715716
o->maxbatchsize = 0;
716717
o->spread_nthr_atomic = -1;
717718
o->spread_max_sp_size = 0;
719+
o->spread_kernel = 0;
718720
o->fftw_lock_fun = nullptr;
719721
o->fftw_unlock_fun = nullptr;
720722
o->fftw_lock_data = nullptr;

0 commit comments

Comments
 (0)