Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lib/jax_finufft_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ nb::dict Registrations() {
nb::dict dict;

// TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
// dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
// dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
dict["nufft3d2f"] = encapsulate_function(nufft3d2f);

// dict["nufft1d1"] = encapsulate_function(nufft1d1);
// dict["nufft1d2"] = encapsulate_function(nufft1d2);
dict["nufft1d1"] = encapsulate_function(nufft1d1);
dict["nufft1d2"] = encapsulate_function(nufft1d2);
dict["nufft2d1"] = encapsulate_function(nufft2d1);
dict["nufft2d2"] = encapsulate_function(nufft2d2);
dict["nufft3d1"] = encapsulate_function(nufft3d1);
Expand Down
16 changes: 16 additions & 0 deletions lib/kernels.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ void nufft2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t
ThrowIfError(cudaGetLastError());
}

// 1-D, double-precision, type-1 NUFFT: thin dimension/dtype specialization of
// the templated nufft1 driver.
// NOTE(review): the (stream, buffers, opaque, opaque_len) signature matches the
// XLA GPU custom-call convention — confirm against the registration table.
void nufft1d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<1, double>(stream, buffers, opaque, opaque_len);
}

// 1-D, double-precision, type-2 NUFFT: forwards to the templated nufft2 driver
// with the dimension and scalar type fixed at compile time.
void nufft1d2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft2<1, double>(stream, buffers, opaque, opaque_len);
}

// 2-D, double-precision, type-1 NUFFT: dimension/dtype specialization of the
// templated nufft1 driver.
void nufft2d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<2, double>(stream, buffers, opaque, opaque_len);
}
Expand All @@ -103,6 +111,14 @@ void nufft3d2(cudaStream_t stream, void **buffers, const char *opaque, std::size
nufft2<3, double>(stream, buffers, opaque, opaque_len);
}

// 1-D, single-precision, type-1 NUFFT ("f" suffix = float): specialization of
// the templated nufft1 driver.
void nufft1d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<1, float>(stream, buffers, opaque, opaque_len);
}

// 1-D, single-precision, type-2 NUFFT: specialization of the templated nufft2
// driver.
void nufft1d2f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft2<1, float>(stream, buffers, opaque, opaque_len);
}

// 2-D, single-precision, type-1 NUFFT: specialization of the templated nufft1
// driver.
void nufft2d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<2, float>(stream, buffers, opaque, opaque_len);
}
Expand Down
4 changes: 4 additions & 0 deletions lib/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ struct descriptor {
cufinufft_opts opts;
};

// Double-precision NUFFT entry points (type 1 and type 2, dimensions 1-3);
// each forwards to a templated driver defined in kernels.cc.cu.
void nufft1d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft1d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft2d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);

// Single-precision variants ("f" suffix), mirroring the double-precision set
// above.
void nufft1d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft1d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft2d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft2d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,6 @@ combine-as-imports = true

[tool.pytest.ini_options]
testpaths = ["tests"]
# Escalate DeprecationWarnings raised during the test run into errors so they
# fail loudly instead of being silently ignored.
filterwarnings = [
"error::DeprecationWarning",
]
2 changes: 0 additions & 2 deletions src/jax_finufft/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def lowering(

ndim = len(points)
assert 1 <= ndim <= 3
if platform == "cuda" and ndim == 1:
raise ValueError("1-D transforms are not yet supported on the GPU")

source_aval = ctx.avals_in[0]
single = source_aval.dtype == np.complex64
Expand Down
1 change: 1 addition & 0 deletions src/jax_finufft/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def nufft1(output_shape, source, *points, iflag=1, eps=1e-6, opts=None):
output_shape = np.atleast_1d(output_shape).astype(np.int64)
if output_shape.shape != (ndim,):
raise ValueError(f"output_shape must have shape: ({ndim},)")
output_shape = tuple(output_shape)

# Handle broadcasting and reshaping of inputs
index, source, *points = shapes.broadcast_and_flatten_inputs(
Expand Down
25 changes: 0 additions & 25 deletions tests/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ def check_close(a, b, **kwargs):
product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
)
def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
if ndim == 1 and jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(657)

eps = 1e-10 if x64 else 1e-7
Expand Down Expand Up @@ -56,9 +53,6 @@ def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
)
def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
if ndim == 1 and jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(657)

eps = 1e-10 if x64 else 1e-7
Expand Down Expand Up @@ -98,9 +92,6 @@ def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
product([1, 2, 3], [50], [35], [-1, 1]),
)
def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag):
if ndim == 1 and jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(657)

eps = 1e-10
Expand Down Expand Up @@ -133,9 +124,6 @@ def scalar_func(*args):
product([1, 2, 3], [50], [35], [-1, 1]),
)
def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag):
if ndim == 1 and jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(657)

eps = 1e-10
Expand Down Expand Up @@ -168,9 +156,6 @@ def scalar_func(*args):
product([1, 2, 3], [50], [35], [-1, 1]),
)
def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
if ndim == 1 and jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(657)

dtype = np.double
Expand Down Expand Up @@ -219,9 +204,6 @@ def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
product([1, 2, 3], [50], [35], [-1, 1]),
)
def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):
if ndim == 1 and jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(657)

dtype = np.double
Expand Down Expand Up @@ -266,10 +248,6 @@ def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):


def test_multi_transform():
# TODO: is there a 2D or 3D version of this test?
if jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(314)

n_tot, n_tr, n_j, n_k = 4, 10, 100, 12
Expand All @@ -287,9 +265,6 @@ def test_multi_transform():


def test_gh14():
if jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

M = 100
N = 200

Expand Down
2 changes: 1 addition & 1 deletion vendor/finufft
Submodule finufft updated 43 files
+2 −2 .github/workflows/build_cufinufft_wheels.yml
+1 −1 .github/workflows/build_finufft_wheels.yml
+3 −0 .github/workflows/cmake_ci.yml
+15 −0 CHANGELOG
+1 −2 CMakeLists.txt
+1 −2 LICENSE
+0 −490 contrib/legendre_rule_fast.cpp
+0 −10 contrib/legendre_rule_fast.h
+0 −8 contrib/legendre_rule_fast.license
+6 −1 docs/ackn.rst
+1 −1 docs/conf.py
+0 −1 docs/dirs.rst
+3 −3 docs/install.rst
+6 −9 docs/overview.src
+ docs/pics/pois_fft_python.png
+ docs/pics/pois_fhat_python.png
+ docs/pics/pois_nufft_python.png
+ docs/pics/pois_nugrid_python.png
+4 −2 docs/tut.rst
+231 −0 docs/tutorial/peripois2d_python.rst
+0 −6 include/cufinufft/contrib/legendre_rule_fast.h
+7 −5 include/cufinufft/utils.h
+1 −1 include/finufft/finufft_core.h
+7 −4 include/finufft/finufft_utils.hpp
+8 −7 makefile
+1 −1 matlab/Contents.m
+71 −61 perftest/cuda/cuperftest.cu
+7 −6 python/cufinufft/cufinufft/__init__.py
+17 −13 python/cufinufft/cufinufft/_cufinufft.py
+1 −1 python/cufinufft/cufinufft/_plan.py
+22 −10 python/cufinufft/cufinufft/_simple.py
+3 −5 python/cufinufft/tests/test_array_ordering.py
+11 −19 python/cufinufft/tests/test_basic.py
+38 −9 python/cufinufft/tests/test_simple.py
+10 −14 python/cufinufft/tests/utils.py
+1 −1 python/finufft/finufft/__init__.py
+1 −1 src/cuda/CMakeLists.txt
+8 −15 src/cuda/common.cu
+59 −1 src/cuda/utils.cpp
+11 −9 src/finufft_core.cpp
+67 −3 src/finufft_utils.cpp
+26 −6 test/testutils.cpp
+160 −0 tutorial/poisson2dnuquad.py
Loading