diff --git a/lib/jax_finufft_gpu.cc b/lib/jax_finufft_gpu.cc
index 56da6f8..3d472e5 100644
--- a/lib/jax_finufft_gpu.cc
+++ b/lib/jax_finufft_gpu.cc
@@ -23,15 +23,15 @@ nb::dict Registrations() {
   nb::dict dict;
 
   // TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
-  // dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
-  // dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
+  dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
+  dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
   dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
   dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
   dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
   dict["nufft3d2f"] = encapsulate_function(nufft3d2f);
-  // dict["nufft1d1"] = encapsulate_function(nufft1d1);
-  // dict["nufft1d2"] = encapsulate_function(nufft1d2);
+  dict["nufft1d1"] = encapsulate_function(nufft1d1);
+  dict["nufft1d2"] = encapsulate_function(nufft1d2);
   dict["nufft2d1"] = encapsulate_function(nufft2d1);
   dict["nufft2d2"] = encapsulate_function(nufft2d2);
   dict["nufft3d1"] = encapsulate_function(nufft3d1);
   dict["nufft3d2"] = encapsulate_function(nufft3d2);
diff --git a/lib/kernels.cc.cu b/lib/kernels.cc.cu
index ca6f834..a5db1a5 100644
--- a/lib/kernels.cc.cu
+++ b/lib/kernels.cc.cu
@@ -87,6 +87,14 @@ void nufft2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t
   ThrowIfError(cudaGetLastError());
 }
 
+void nufft1d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
+  nufft1<1, double>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft1d2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
+  nufft2<1, double>(stream, buffers, opaque, opaque_len);
+}
+
 void nufft2d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
   nufft1<2, double>(stream, buffers, opaque, opaque_len);
 }
@@ -103,6 +111,14 @@ void nufft3d2(cudaStream_t stream, void **buffers, const char *opaque, std::size
   nufft2<3, double>(stream, buffers, opaque, opaque_len);
 }
 
+void nufft1d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
+  nufft1<1, float>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft1d2f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
+  nufft2<1, float>(stream, buffers, opaque, opaque_len);
+}
+
 void nufft2d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
   nufft1<2, float>(stream, buffers, opaque, opaque_len);
 }
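The wrappers added above are thin monomorphizations of the templated nufft1/nufft2 drivers, so the 1-D case reuses the same code path as 2-D and 3-D. On the Python side, the names returned by Registrations() are wired up through JAX's standard custom-call hook. A minimal sketch of that pattern, assuming the compiled extension is importable as jax_finufft_gpu (the exact import path and platform tag are assumptions, not code from this diff):

    # Hedged sketch: standard XLA custom-call registration for the GPU
    # targets collected in Registrations(). The module path is hypothetical.
    from jax.lib import xla_client

    from jax_finufft import jax_finufft_gpu

    for name, capsule in jax_finufft_gpu.Registrations().items():
        # With the two 1-D entries uncommented above, "nufft1d1",
        # "nufft1d2", and their single-precision "f" variants are
        # registered here alongside the existing 2-D and 3-D targets.
        xla_client.register_custom_call_target(name, capsule, platform="gpu")
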
diff --git a/lib/kernels.h b/lib/kernels.h
index 7aaf9fd..2252b33 100644
--- a/lib/kernels.h
+++ b/lib/kernels.h
@@ -19,11 +19,15 @@ struct descriptor {
   cufinufft_opts opts;
 };
 
+void nufft1d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft1d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft2d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft3d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft3d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 
+void nufft1d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft1d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft2d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft2d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
 void nufft3d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
diff --git a/pyproject.toml b/pyproject.toml
index 23ccecf..640a92d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,3 +72,6 @@ combine-as-imports = true
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
+filterwarnings = [
+    "error::DeprecationWarning",
+]
diff --git a/src/jax_finufft/lowering.py b/src/jax_finufft/lowering.py
index f1a847b..8829519 100644
--- a/src/jax_finufft/lowering.py
+++ b/src/jax_finufft/lowering.py
@@ -42,8 +42,6 @@ def lowering(
 
     ndim = len(points)
     assert 1 <= ndim <= 3
-    if platform == "cuda" and ndim == 1:
-        raise ValueError("1-D transforms are not yet supported on the GPU")
 
     source_aval = ctx.avals_in[0]
     single = source_aval.dtype == np.complex64
diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py
index 9c90f31..4721c65 100644
--- a/src/jax_finufft/ops.py
+++ b/src/jax_finufft/ops.py
@@ -24,6 +24,7 @@ def nufft1(output_shape, source, *points, iflag=1, eps=1e-6, opts=None):
     output_shape = np.atleast_1d(output_shape).astype(np.int64)
     if output_shape.shape != (ndim,):
         raise ValueError(f"output_shape must have shape: ({ndim},)")
+    output_shape = tuple(output_shape)
 
     # Handle broadcasting and reshaping of inputs
     index, source, *points = shapes.broadcast_and_flatten_inputs(
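The new output_shape = tuple(output_shape) line is subtle: after validation, output_shape would otherwise remain a NumPy int64 array. One plausible motivation, consistent with the new filterwarnings entry that promotes DeprecationWarnings to test failures, is that a tuple of ints is hashable and unpacks cleanly wherever a plain shape is expected, while an ndarray is unhashable and can trip deprecated implicit conversions downstream. A toy illustration (not code from this diff):

    import numpy as np

    output_shape = np.atleast_1d((50, 75)).astype(np.int64)
    # hash(output_shape)  # TypeError: unhashable type: 'numpy.ndarray'

    output_shape = tuple(output_shape)  # (50, 75)
    hash(output_shape)                  # fine: a tuple of ints is hashable
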
diff --git a/tests/ops_test.py b/tests/ops_test.py
index c91b403..da31551 100644
--- a/tests/ops_test.py
+++ b/tests/ops_test.py
@@ -21,9 +21,6 @@ def check_close(a, b, **kwargs):
     product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
 )
 def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
-    if ndim == 1 and jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(657)
 
     eps = 1e-10 if x64 else 1e-7
@@ -56,9 +53,6 @@ def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
     product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
 )
 def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
-    if ndim == 1 and jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(657)
 
     eps = 1e-10 if x64 else 1e-7
@@ -98,9 +92,6 @@ def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
     product([1, 2, 3], [50], [35], [-1, 1]),
 )
 def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag):
-    if ndim == 1 and jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(657)
 
     eps = 1e-10
@@ -133,9 +124,6 @@ def scalar_func(*args):
     product([1, 2, 3], [50], [35], [-1, 1]),
 )
 def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag):
-    if ndim == 1 and jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(657)
 
     eps = 1e-10
@@ -168,9 +156,6 @@ def scalar_func(*args):
     product([1, 2, 3], [50], [35], [-1, 1]),
 )
 def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
-    if ndim == 1 and jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(657)
 
     dtype = np.double
@@ -219,9 +204,6 @@ def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
     product([1, 2, 3], [50], [35], [-1, 1]),
 )
 def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):
-    if ndim == 1 and jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(657)
 
     dtype = np.double
@@ -266,10 +248,6 @@ def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):
 
 
 def test_multi_transform():
-    # TODO: is there a 2D or 3D version of this test?
-    if jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     random = np.random.default_rng(314)
 
     n_tot, n_tr, n_j, n_k = 4, 10, 100, 12
@@ -287,9 +265,6 @@ def test_multi_transform():
 
 
 def test_gh14():
-    if jax.default_backend() != "cpu":
-        pytest.skip("1D transforms not implemented on GPU")
-
     M = 100
     N = 200
 
diff --git a/vendor/finufft b/vendor/finufft
index cbda179..e7144a5 160000
--- a/vendor/finufft
+++ b/vendor/finufft
@@ -1 +1 @@
-Subproject commit cbda17905ce0b52590b7fa2fbd73eb7f1845217e
+Subproject commit e7144a5c08cbaf3e3b344a4fdd92bc3c7e468ff2
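
Taken together, the diff uncomments the 1-D GPU registrations, adds the corresponding cuFINUFFT wrapper functions and prototypes, deletes the lowering-time guard, and removes every backend-conditional skip from the tests, so the ndim == 1 parametrizations now run on the GPU as well. A usage sketch modeled on the test setup (seed and sizes borrowed from the tests above; this snippet is illustrative, not part of the diff):

    import numpy as np
    import jax.numpy as jnp

    from jax_finufft import nufft1, nufft2

    random = np.random.default_rng(657)
    num_nonuniform, num_uniform = 50, 75

    # Nonuniform points in [-pi, pi) with complex source strengths
    x = jnp.asarray(random.uniform(-np.pi, np.pi, num_nonuniform))
    c = jnp.asarray(
        random.normal(size=num_nonuniform)
        + 1j * random.normal(size=num_nonuniform)
    )

    # Type 1 (nonuniform -> uniform): on a CUDA backend this now lowers to
    # the newly registered nufft1d1 / nufft1d1f kernels instead of raising.
    f = nufft1(num_uniform, c, x, eps=1e-7, iflag=1)

    # Type 2 (uniform -> nonuniform) exercises the matching 1-D kernel.
    c2 = nufft2(f, x, eps=1e-7, iflag=-1)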