Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ NB_MODULE(jax_finufft_cpu, m) {
int spread_max_sp_size) {
new (self) finufft_opts;
default_opts<double>(self);

self->modeord = int(modeord);
self->debug = debug;
self->spread_debug = spread_debug;
Expand Down
43 changes: 23 additions & 20 deletions lib/jax_finufft_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,30 @@ NB_MODULE(jax_finufft_gpu, m) {
m.def("build_descriptor", &build_descriptor<double>);

nb::class_<cufinufft_opts> opts(m, "CufinufftOpts");
opts.def("__init__", [](cufinufft_opts* self, double upsampfac, int gpu_method, bool gpu_sort,
int gpu_binsizex, int gpu_binsizey, int gpu_binsizez, int gpu_obinsizex,
int gpu_obinsizey, int gpu_obinsizez, int gpu_maxsubprobsize,
bool gpu_kerevalmeth, int gpu_spreadinterponly, int gpu_maxbatchsize) {
new (self) cufinufft_opts;
default_opts<double>(self);
opts.def("__init__",
[](cufinufft_opts* self, bool modeord, int gpu_spreadinterponly, int debug,
int gpu_method, bool gpu_sort, bool gpu_kerevalmeth, double upsampfac,
int gpu_maxsubprobsize, int gpu_obinsizex, int gpu_obinsizey, int gpu_obinsizez,
int gpu_binsizex, int gpu_binsizey, int gpu_binsizez, int gpu_maxbatchsize) {
new (self) cufinufft_opts;
default_opts<double>(self);

self->upsampfac = upsampfac;
self->gpu_method = gpu_method;
self->gpu_sort = int(gpu_sort);
self->gpu_binsizex = gpu_binsizex;
self->gpu_binsizey = gpu_binsizey;
self->gpu_binsizez = gpu_binsizez;
self->gpu_obinsizex = gpu_obinsizex;
self->gpu_obinsizey = gpu_obinsizey;
self->gpu_obinsizez = gpu_obinsizez;
self->gpu_maxsubprobsize = gpu_maxsubprobsize;
self->gpu_kerevalmeth = gpu_kerevalmeth;
self->gpu_spreadinterponly = gpu_spreadinterponly;
self->gpu_maxbatchsize = gpu_maxbatchsize;
});
self->modeord = int(modeord);
self->gpu_spreadinterponly = gpu_spreadinterponly;
self->debug = debug;
self->gpu_method = gpu_method;
self->gpu_sort = int(gpu_sort);
self->gpu_kerevalmeth = gpu_kerevalmeth;
self->upsampfac = upsampfac;
self->gpu_maxsubprobsize = gpu_maxsubprobsize;
self->gpu_obinsizex = gpu_obinsizex;
self->gpu_obinsizey = gpu_obinsizey;
self->gpu_obinsizez = gpu_obinsizez;
self->gpu_binsizex = gpu_binsizex;
self->gpu_binsizey = gpu_binsizey;
self->gpu_binsizez = gpu_binsizez;
self->gpu_maxbatchsize = gpu_maxbatchsize;
});
}

} // namespace
8 changes: 7 additions & 1 deletion src/jax_finufft/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts):
# c_j = sum_k f_k * exp(iflag * i * k * x_j)
# dc_j/dx_j = sum_k iflag * i * k * f_k * exp(iflag * i * k * x_j)

modeord = 0 if opts is None else opts.modeord

source, *points = args
dsource, *dpoints = tangents
output = prim.bind(
Expand Down Expand Up @@ -105,7 +107,11 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts):
n = source.shape[-ndim + dim] if output_shape is None else output_shape[dim]
shape = np.ones(ndim, dtype=int)
shape[dim] = -1
k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1))

k = np.fft.fftfreq(n, 1 / n)
if modeord == 0:
k = np.fft.fftshift(k)

k = k.reshape(shape)
factor = 1j * iflag * k
dx = dx[:, None, :]
Expand Down
54 changes: 38 additions & 16 deletions src/jax_finufft/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,32 @@ class GpuMethod(IntEnum):

@dataclass(frozen=True)
class Opts:
"""FINUFFT optional parameters.

# These correspond to the default cufinufft options
# set in vendor/finufft/src/cuda/cufinufft.cu
References
----------
https://finufft.readthedocs.io/en/latest/opts.html
https://finufft.readthedocs.io/en/latest/c_gpu.html#options-for-gpu-code

"""

# Parameters are chosen to match the default values in
# vendor/finufft/src/finufft_core.cpp:finufft_default_opts
# vendor/finufft/src/cuda/cufinufft.cu

# data handling opts
modeord: bool = False
# TODO: Add warning stating that (cpu) spreadinterponly is not implemented.
# TODO: Add warning stating that gpu_device_id is not implemented.
gpu_spreadinterponly: bool = False

# diagnostic opts
debug: DebugLevel = DebugLevel.Silent
spread_debug: DebugLevel = DebugLevel.Silent
# TODO: Document (publicly) that this differs from default value.
showwarn: bool = False

# algorithm performance opts
nthreads: int = 0
fftw: int = FftwFlags.Estimate
spread_sort: SpreadSort = SpreadSort.Heuristic
Expand All @@ -58,19 +77,20 @@ class Opts:
spread_nthr_atomic: int = -1
spread_max_sp_size: int = 0

gpu_upsampfac: float = 2.0
gpu_method: GpuMethod = 0
gpu_sort: bool = True
gpu_binsizex: int = 0
gpu_binsizey: int = 0
gpu_binsizez: int = 0
gpu_kerevalmeth: bool = True
gpu_upsampfac: float = 2.0
gpu_maxsubprobsize: int = 1024
gpu_obinsizex: int = 0
gpu_obinsizey: int = 0
gpu_obinsizez: int = 0
gpu_maxsubprobsize: int = 1024
gpu_kerevalmeth: bool = True
gpu_spreadinterponly: bool = False
gpu_binsizex: int = 0
gpu_binsizey: int = 0
gpu_binsizez: int = 0
gpu_maxbatchsize: int = 0
# TODO: Add warning stating that gpu_np is not implemented.
# TODO: Add warning stating that gpustream is not implemented.

def to_finufft_opts(self):
compiled_with_omp = jax_finufft_cpu._omp_compile_check()
Expand All @@ -95,18 +115,20 @@ def to_cufinufft_opts(self):
from jax_finufft import jax_finufft_gpu

return jax_finufft_gpu.CufinufftOpts(
self.gpu_upsampfac,
self.modeord,
self.gpu_spreadinterponly,
int(self.debug),
int(self.gpu_method),
self.gpu_sort,
self.gpu_binsizex,
self.gpu_binsizey,
self.gpu_binsizez,
self.gpu_kerevalmeth,
self.gpu_upsampfac,
self.gpu_maxsubprobsize,
self.gpu_obinsizex,
self.gpu_obinsizey,
self.gpu_obinsizez,
self.gpu_maxsubprobsize,
self.gpu_kerevalmeth,
self.gpu_spreadinterponly,
self.gpu_binsizex,
self.gpu_binsizey,
self.gpu_binsizez,
self.gpu_maxbatchsize,
)

Expand Down
Loading