From 8a4e7488a8dbe7adfefbec52b072d1bec64fc1e4 Mon Sep 17 00:00:00 2001
From: Dan F-M
Date: Mon, 8 Nov 2021 15:03:51 -0500
Subject: [PATCH 01/12] starting to add optional cuda support

---
 .gitmodules                            |  3 +++
 CMakeLists.txt                         | 21 ++++++++++++++++-----
 FindFFTW.cmake => cmake/FindFFTW.cmake |  0
 vendor/cufinufft                       |  1 +
 4 files changed, 20 insertions(+), 5 deletions(-)
 rename FindFFTW.cmake => cmake/FindFFTW.cmake (100%)
 create mode 160000 vendor/cufinufft

diff --git a/.gitmodules b/.gitmodules
index b11d194..fbf5ea2 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "finufft"]
 	path = vendor/finufft
 	url = https://github.com/flatironinstitute/finufft
+[submodule "vendor/cufinufft"]
+	path = vendor/cufinufft
+	url = https://github.com/flatironinstitute/cufinufft
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a32dfce..c649728 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,9 +1,10 @@
 cmake_minimum_required(VERSION 3.12)
 project(jax_finufft LANGUAGES C CXX)
-message(STATUS "Using CMake version: " ${CMAKE_VERSION})
-set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR})
+# Add the /cmake directory to the module path so that we can find FFTW
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake)
 
+# Handle Python settings passed from scikit-build
 if(SKBUILD)
   set(Python_EXECUTABLE "${PYTHON_EXECUTABLE}")
   set(Python_INCLUDE_DIR "${PYTHON_INCLUDE_DIR}")
@@ -15,20 +16,20 @@ if(SKBUILD)
   list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
 endif()
 
+# Find and link pybind11 and fftw
 find_package(pybind11 CONFIG REQUIRED)
 find_package(FFTW REQUIRED COMPONENTS FLOAT_LIB DOUBLE_LIB)
 link_libraries(${FFTW_FLOAT_LIB} ${FFTW_DOUBLE_LIB})
 
+# Work out compiler flags
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 add_compile_options(-Wall -O3 -funroll-loops)
-
 set(FINUFFT_INCLUDE_DIRS
     ${CMAKE_CURRENT_LIST_DIR}/lib
     ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
     ${FFTW_INCLUDE_DIRS})
-message(STATUS "FINUFFT include dirs: " "${FINUFFT_INCLUDE_DIRS}")
-
+# Build single and double precision versions of the FINUFFT library
 add_library(finufft STATIC
     ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/spreadinterp.cpp
     ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils.cpp
@@ -44,6 +45,7 @@ add_library(finufft_32 STATIC
 target_compile_definitions(finufft_32 PUBLIC SINGLE)
 target_include_directories(finufft_32 PRIVATE ${FINUFFT_INCLUDE_DIRS})
 
+# Build the XLA bindings to those libraries
 pybind11_add_module(jax_finufft
   ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft.cc
   ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils_precindep.cpp
@@ -51,3 +53,12 @@ pybind11_add_module(jax_finufft
 target_link_libraries(jax_finufft PRIVATE finufft finufft_32)
 target_include_directories(jax_finufft PRIVATE ${FINUFFT_INCLUDE_DIRS})
 install(TARGETS jax_finufft DESTINATION .)
+
+include(CheckLanguage)
+check_language(CUDA)
+if (CMAKE_CUDA_COMPILER)
+  enable_language(CUDA)
+  set(CUFINUFFT_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+else()
+  message(STATUS "No CUDA compiler found; GPU support will be disabled")
+endif()
diff --git a/FindFFTW.cmake b/cmake/FindFFTW.cmake
similarity index 100%
rename from FindFFTW.cmake
rename to cmake/FindFFTW.cmake
diff --git a/vendor/cufinufft b/vendor/cufinufft
new file mode 160000
index 0000000..21053ea
--- /dev/null
+++ b/vendor/cufinufft
@@ -0,0 +1 @@
+Subproject commit 21053eaf96a98f23a4ede2e2d57bba58a0792d98

From 2e827c0c3e1206490afed9f0cf26d7b4587be90a Mon Sep 17 00:00:00 2001
From: Dan F-M
Date: Mon, 8 Nov 2021 15:04:26 -0500
Subject: [PATCH 02/12] include dirs for cuda

---
 CMakeLists.txt | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c649728..6770b9a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -58,7 +58,10 @@ include(CheckLanguage)
 check_language(CUDA)
 if (CMAKE_CUDA_COMPILER)
   enable_language(CUDA)
-  set(CUFINUFFT_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+  set(CUFINUFFT_INCLUDE_DIRS
+      ${CMAKE_CURRENT_LIST_DIR}/lib
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
+      ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
 else()
   message(STATUS "No CUDA compiler found; GPU support will be disabled")
 endif()

From 94c5bbf07d7067fefe0af7af358cfe010ff589b5 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Tue, 9 Nov 2021 15:20:59 -0500
Subject: [PATCH 03/12] getting cufinufft to compile

---
 CMakeLists.txt | 53 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 51 insertions(+), 2 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6770b9a..6a1f522 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,6 +14,13 @@ if(SKBUILD)
     OUTPUT_VARIABLE _tmp_dir
     OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
   list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
+else()
+  find_package(Python COMPONENTS Interpreter Development REQUIRED)
+  execute_process(
+    COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
+    OUTPUT_VARIABLE _tmp_dir
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
+  list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
 endif()
 
 # Find and link pybind11 and fftw
@@ -23,7 +30,7 @@ link_libraries(${FFTW_FLOAT_LIB} ${FFTW_DOUBLE_LIB})
 
 # Work out compiler flags
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
-add_compile_options(-Wall -O3 -funroll-loops)
+add_compile_options(-Wall -Wno-unknown-pragmas -O3 -funroll-loops)
 set(FINUFFT_INCLUDE_DIRS
     ${CMAKE_CURRENT_LIST_DIR}/lib
     ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
@@ -58,10 +65,52 @@ include(CheckLanguage)
 check_language(CUDA)
 if (CMAKE_CUDA_COMPILER)
   enable_language(CUDA)
+  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    set(CMAKE_CUDA_ARCHITECTURES "52;60;61;70;75")
+  endif()
+
   set(CUFINUFFT_INCLUDE_DIRS
       ${CMAKE_CURRENT_LIST_DIR}/lib
-      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/include
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/cuda_samples
       ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+
+  set(CUFINUFFT_SOURCES
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/spreadinterp2d.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/spread2d_wrapper.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/spread2d_wrapper_paul.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/interp2d_wrapper.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/2d/cufinufft2d.cu
+
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/spreadinterp3d.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/spread3d_wrapper.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/interp3d_wrapper.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/3d/cufinufft3d.cu
+
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/memtransfer_wrapper.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/deconvolve_wrapper.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/cufinufft.cu
+
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/dirft2d.cpp
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/common.cpp
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/spreadinterp.cpp
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/utils_fp.cpp)
+
+  add_library(cufinufft STATIC
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/precision_independent.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/src/profile.cu
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/legendre_rule_fast.c
+      ${CMAKE_CURRENT_LIST_DIR}/vendor/cufinufft/contrib/utils.cpp)
+  target_include_directories(cufinufft PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
+
+  add_library(cufinufft_64 STATIC ${CUFINUFFT_SOURCES})
+  target_include_directories(cufinufft_64 PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
+
+  add_library(cufinufft_32 STATIC ${CUFINUFFT_SOURCES})
+  target_compile_definitions(cufinufft_32 PUBLIC SINGLE)
+  target_include_directories(cufinufft_32 PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
+
 else()
   message(STATUS "No CUDA compiler found; GPU support will be disabled")
 endif()

From 974b0e94cc3fe79f1f0106343f259440d88f87e8 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Wed, 10 Nov 2021 12:43:17 -0500
Subject: [PATCH 04/12] adding first pass at gpu kernels

---
 CMakeLists.txt                             | 18 ++--
 lib/{jax_finufft.cc => jax_finufft_cpu.cc} |  2 +-
 lib/jax_finufft_gpu.cc                     | 36 ++++++++
 lib/kernels.cc.cu                          | 95 ++++++++++++++++++++++
 lib/kernels.h                              | 24 ++++++
 src/jax_finufft/ops.py                     |  6 +-
 6 files changed, 172 insertions(+), 9 deletions(-)
 rename lib/{jax_finufft.cc => jax_finufft_cpu.cc} (98%)
 create mode 100644 lib/jax_finufft_gpu.cc
 create mode 100644 lib/kernels.cc.cu
 create mode 100644 lib/kernels.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6a1f522..75a221b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -53,13 +53,13 @@ target_compile_definitions(finufft_32 PUBLIC SINGLE)
 target_include_directories(finufft_32 PRIVATE ${FINUFFT_INCLUDE_DIRS})
 
 # Build the XLA bindings to those libraries
-pybind11_add_module(jax_finufft
-  ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft.cc
+pybind11_add_module(jax_finufft_cpu
+  ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc
   ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils_precindep.cpp
   ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.c)
-target_link_libraries(jax_finufft PRIVATE finufft finufft_32)
-target_include_directories(jax_finufft PRIVATE ${FINUFFT_INCLUDE_DIRS})
-install(TARGETS jax_finufft DESTINATION .)
+target_link_libraries(jax_finufft_cpu PRIVATE finufft finufft_32)
+target_include_directories(jax_finufft_cpu PRIVATE ${FINUFFT_INCLUDE_DIRS})
+install(TARGETS jax_finufft_cpu DESTINATION .)
 
 include(CheckLanguage)
 check_language(CUDA)
@@ -111,6 +111,14 @@ if (CMAKE_CUDA_COMPILER)
   target_compile_definitions(cufinufft_32 PUBLIC SINGLE)
   target_include_directories(cufinufft_32 PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
 
+  pybind11_add_module(jax_finufft_gpu
+    ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
+    ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
+  target_link_libraries(jax_finufft_gpu PRIVATE cufinufft cufinufft_32 cufinufft_64)
+  target_include_directories(jax_finufft_gpu PRIVATE ${FINUFFT_INCLUDE_DIRS})
+  install(TARGETS jax_finufft_gpu DESTINATION .)
+
+
 else()
   message(STATUS "No CUDA compiler found; GPU support will be disabled")
 endif()
diff --git a/lib/jax_finufft.cc b/lib/jax_finufft_cpu.cc
similarity index 98%
rename from lib/jax_finufft.cc
rename to lib/jax_finufft_cpu.cc
index fb847dd..3c44f08 100644
--- a/lib/jax_finufft.cc
+++ b/lib/jax_finufft_cpu.cc
@@ -85,7 +85,7 @@ pybind11::dict Registrations() {
   return dict;
 }
 
-PYBIND11_MODULE(jax_finufft, m) {
+PYBIND11_MODULE(jax_finufft_cpu, m) {
   m.def("registrations", &Registrations);
   m.def("build_descriptorf", &build_descriptor<float>);
   m.def("build_descriptor", &build_descriptor<double>);
diff --git a/lib/jax_finufft_gpu.cc b/lib/jax_finufft_gpu.cc
new file mode 100644
index 0000000..31ecad0
--- /dev/null
+++ b/lib/jax_finufft_gpu.cc
@@ -0,0 +1,36 @@
+// This file defines the Python interface to the XLA custom call implemented on the GPU.
+// It is exposed as a standard pybind11 module defining "capsule" objects containing our
+// method. For simplicity, we export a separate capsule for each supported dtype.
+
+#include "pybind11_kernel_helpers.h"
+#include "kernels.h"
+
+using namespace jax_finufft;
+
+namespace {
+
+pybind11::dict Registrations() {
+  pybind11::dict dict;
+
+  // dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
+  // dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
+  dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
+  dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
+  dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
+  dict["nufft3d2f"] = encapsulate_function(nufft3d2f);
+
+  // dict["nufft1d1"] = encapsulate_function(nufft1d1);
+  // dict["nufft1d2"] = encapsulate_function(nufft1d2);
+  dict["nufft2d1"] = encapsulate_function(nufft2d1);
+  dict["nufft2d2"] = encapsulate_function(nufft2d2);
+  dict["nufft3d1"] = encapsulate_function(nufft3d1);
+  dict["nufft3d2"] = encapsulate_function(nufft3d2);
+
+  return dict;
+}
+
+PYBIND11_MODULE(jax_finufft_gpu, m) {
+  m.def("registrations", &Registrations);
+}
+
+}  // namespace
diff --git a/lib/kernels.cc.cu b/lib/kernels.cc.cu
new file mode 100644
index 0000000..0f76b58
--- /dev/null
+++ b/lib/kernels.cc.cu
@@ -0,0 +1,95 @@
+#include "jax_finufft.h"
+#include "kernels.h"
+#include "kernel_helpers.h"
+
+namespace jax_finufft {
+
+void ThrowIfError(cudaError_t error) {
+  if (error != cudaSuccess) {
+    throw std::runtime_error(cudaGetErrorString(error));
+  }
+}
+
+template <int ndim, typename T>
+void nufft1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(opaque, opaque_len);
+
+  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[0]);
+  T *x = reinterpret_cast<T *>(buffers[1]);
+  T *y = NULL;
+  T *z = NULL;
+  int out_dim = 2;
+  if (ndim > 1) {
+    y = reinterpret_cast<T *>(buffers[2]);
+    out_dim = 3;
+  }
+  if (ndim > 2) {
+    z = reinterpret_cast<T *>(buffers[3]);
+    out_dim = 4;
+  }
+  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);
+
+  // Call cuFINUFFT here...
+  // run_nufft(1, in[0], x, y, z, c, F);
+
+  ThrowIfError(cudaGetLastError());
+}
+
+template <int ndim, typename T>
+void nufft2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(opaque, opaque_len);
+
+  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[0]);
+  T *x = reinterpret_cast<T *>(buffers[1]);
+  T *y = NULL;
+  T *z = NULL;
+  int out_dim = 2;
+  if (ndim > 1) {
+    y = reinterpret_cast<T *>(buffers[2]);
+    out_dim = 3;
+  }
+  if (ndim > 2) {
+    z = reinterpret_cast<T *>(buffers[3]);
+    out_dim = 4;
+  }
+  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);
+
+  // Call cuFINUFFT here...
+  // run_nufft(1, in[0], x, y, z, c, F);
+
+  ThrowIfError(cudaGetLastError());
+}
+
+void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft1<2, double>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft2d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft2<2, double>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft3d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft1<3, double>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft3d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft2<3, double>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft2d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft1<2, float>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft2d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft2<2, float>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft3d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft1<3, float>(stream, buffers, opaque, opaque_len);
+}
+
+void nufft3d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
+  nufft2<3, float>(stream, buffers, opaque, opaque_len);
+}
+
+}
\ No newline at end of file
diff --git a/lib/kernels.h b/lib/kernels.h
new file mode 100644
index 0000000..4318e49
--- /dev/null
+++ b/lib/kernels.h
@@ -0,0 +1,24 @@
+#ifndef _JAX_FINUFFT_KERNELS_H_
+#define _JAX_FINUFFT_KERNELS_H_
+
+#include <cuda_runtime_api.h>
+
+#include <cstddef>
+#include <cstdint>
+
+
+namespace jax_finufft {
+
+void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft2d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft3d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft3d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+
+void nufft2d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft2d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft3d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+void nufft3d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
+
+}  // namespace jax_finufft
+
+#endif
\ No newline at end of file
diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py
index 87dc832..75834a1 100644
--- a/src/jax_finufft/ops.py
+++ b/src/jax_finufft/ops.py
@@ -9,9 +9,9 @@ from jax.interpreters import ad, batching, xla
 from jax.lib import xla_client
 
-from . import jax_finufft
+from . import jax_finufft_cpu
 
-for _name, _value in jax_finufft.registrations().items():
+for _name, _value in jax_finufft_cpu.registrations().items():
     xla_client.register_cpu_custom_call_target(_name, _value)
 
 xops = xla_client.ops
@@ -147,7 +147,7 @@ def translation_rule(
     # Dispatch to the right op
     suffix = "f" if source_dtype == np.csingle else ""
     op_name = f"nufft{ndim}d{type_}{suffix}".encode("ascii")
-    desc = getattr(jax_finufft, f"build_descriptor{suffix}")(
+    desc = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
         eps, iflag, n_tot, n_transf, n_j, *n_k_full
     )

From 9ffc0061606a3e2eed39fb8bea1e473ca46c638c Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Wed, 10 Nov 2021 12:46:57 -0500
Subject: [PATCH 05/12] order of parameters

---
 lib/kernels.cc.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/kernels.cc.cu b/lib/kernels.cc.cu
index 0f76b58..ea9370a 100644
--- a/lib/kernels.cc.cu
+++ b/lib/kernels.cc.cu
@@ -39,7 +39,7 @@ template <int ndim, typename T>
 void nufft2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
   const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(opaque, opaque_len);
 
-  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[0]);
+  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[0]);
   T *x = reinterpret_cast<T *>(buffers[1]);
   T *y = NULL;
   T *z = NULL;
@@ -52,7 +52,7 @@ void nufft2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t
     z = reinterpret_cast<T *>(buffers[3]);
     out_dim = 4;
   }
-  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);
+  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);
 
   // Call cuFINUFFT here...
   // run_nufft(1, in[0], x, y, z, c, F);

From 9473643b07810fdd61bb54ee9d221f76bb5b40f8 Mon Sep 17 00:00:00 2001
From: Lehman Garrison
Date: Fri, 12 Nov 2021 09:08:42 -0500
Subject: [PATCH 06/12] Minor refactoring to support GPU

---
 .gitignore                               |   2 +
 CMakeLists.txt                           |   2 +-
 lib/jax_finufft_common.h                 |  21 ++++
 lib/jax_finufft_cpu.cc                   |   2 +
 lib/{jax_finufft.h => jax_finufft_cpu.h} |  10 --
 lib/jax_finufft_gpu.cc                   |   1 +
 lib/jax_finufft_gpu.h                    | 129 +++++++++++++++++++++++
 lib/kernel_helpers.h                     |   2 +
 lib/kernels.cc.cu                        |   3 +-
 lib/pybind11_kernel_helpers.h            |   1 -
 10 files changed, 160 insertions(+), 13 deletions(-)
 create mode 100644 lib/jax_finufft_common.h
 rename lib/{jax_finufft.h => jax_finufft_cpu.h} (95%)
 create mode 100644 lib/jax_finufft_gpu.h

diff --git a/.gitignore b/.gitignore
index 32539c6..471afc1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,5 @@ build
 _skbuild
 dist
 MANIFEST
+__pycache__
+*.egg-info
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 75a221b..f3bf28e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -115,7 +115,7 @@ if (CMAKE_CUDA_COMPILER)
     ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
     ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
   target_link_libraries(jax_finufft_gpu PRIVATE cufinufft cufinufft_32 cufinufft_64)
-  target_include_directories(jax_finufft_gpu PRIVATE ${FINUFFT_INCLUDE_DIRS})
+  target_include_directories(jax_finufft_gpu PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
   install(TARGETS jax_finufft_gpu DESTINATION .)
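A note on the descriptor plumbing that the next file introduces: the NufftDescriptor struct below is built on the Python side (build_descriptor) and recovered inside the kernels through unpack_descriptor, traveling across the boundary as XLA's "opaque" byte string. The following is an illustrative sketch only, not part of the patch; it assumes the helpers in lib/kernel_helpers.h follow the usual pack-bytes-and-check-size pattern, and the names pack_descriptor/unpack_descriptor are used here for illustration:

// Illustrative sketch, not part of the patch: how a trivially copyable
// descriptor can round-trip through XLA's opaque custom-call parameter.
#include <cstddef>
#include <stdexcept>
#include <string>

template <typename T>
std::string pack_descriptor(const T& descriptor) {
  // Serialize the POD descriptor to raw bytes; the Python side hands these
  // bytes to the custom call as an opaque string.
  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
}

template <typename T>
const T* unpack_descriptor(const char* opaque, std::size_t opaque_len) {
  // Recover the descriptor in the kernel, guarding against size mismatches.
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return reinterpret_cast<const T*>(opaque);
}

Because the struct crosses a language boundary as raw bytes, its layout must match exactly on both sides, which is why a single shared header is the natural home for it.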
diff --git a/lib/jax_finufft_common.h b/lib/jax_finufft_common.h
new file mode 100644
index 0000000..c8eee89
--- /dev/null
+++ b/lib/jax_finufft_common.h
@@ -0,0 +1,21 @@
+#ifndef _JAX_FINUFFT_COMMON_H_
+#define _JAX_FINUFFT_COMMON_H_
+
+// This descriptor is common to both the jax_finufft_cpu and jax_finufft_gpu modules
+// We will use the jax_finufft namespace for both
+
+namespace jax_finufft {
+
+template <typename T>
+struct NufftDescriptor {
+  T eps;
+  int iflag;
+  int64_t n_tot;
+  int n_transf;
+  int64_t n_j;
+  int64_t n_k[3];
+};
+
+}
+
+#endif
diff --git a/lib/jax_finufft_cpu.cc b/lib/jax_finufft_cpu.cc
index 3c44f08..8c538ef 100644
--- a/lib/jax_finufft_cpu.cc
+++ b/lib/jax_finufft_cpu.cc
@@ -4,6 +4,8 @@
 
 #include "pybind11_kernel_helpers.h"
 
+#include "jax_finufft_cpu.h"
+
 using namespace jax_finufft;
 
 namespace {
diff --git a/lib/jax_finufft.h b/lib/jax_finufft_cpu.h
similarity index 95%
rename from lib/jax_finufft.h
rename to lib/jax_finufft_cpu.h
index 0456706..2157dbc 100644
--- a/lib/jax_finufft.h
+++ b/lib/jax_finufft_cpu.h
@@ -7,16 +7,6 @@
 
 namespace jax_finufft {
 
-template <typename T>
-struct NufftDescriptor {
-  T eps;
-  int iflag;
-  int64_t n_tot;
-  int n_transf;
-  int64_t n_j;
-  int64_t n_k[3];
-};
-
 template <typename T>
 struct plan_type;
 
diff --git a/lib/jax_finufft_gpu.cc b/lib/jax_finufft_gpu.cc
index 31ecad0..8547430 100644
--- a/lib/jax_finufft_gpu.cc
+++ b/lib/jax_finufft_gpu.cc
@@ -3,6 +3,7 @@
 // method. For simplicity, we export a separate capsule for each supported dtype.
 
 #include "pybind11_kernel_helpers.h"
+#include "jax_finufft_gpu.h"
 #include "kernels.h"
 
 using namespace jax_finufft;
diff --git a/lib/jax_finufft_gpu.h b/lib/jax_finufft_gpu.h
new file mode 100644
index 0000000..178034f
--- /dev/null
+++ b/lib/jax_finufft_gpu.h
@@ -0,0 +1,129 @@
+#ifndef _JAX_FINUFFT_H_
+#define _JAX_FINUFFT_H_
+
+#include <complex>
+
+#include "cufinufft.h"
+
+namespace jax_finufft {
+
+template <typename T>
+struct plan_type;
+
+template <>
+struct plan_type<double> {
+  typedef cufinufft_plan type;
+};
+/*
+template <>
+struct plan_type<float> {
+  typedef finufftf_plan type;
+};
+
+template <typename T>
+void default_opts(nufft_opts* opts);
+
+template <>
+void default_opts<float>(nufft_opts* opts) {
+  finufftf_default_opts(opts);
+}
+
+template <>
+void default_opts<double>(nufft_opts* opts) {
+  finufft_default_opts(opts);
+}
+
+template <typename T>
+int makeplan(int type, int dim, int64_t* nmodes, int iflag, int ntr, T eps,
+             typename plan_type<T>::type* plan, nufft_opts* opts);
+
+template <>
+int makeplan<float>(int type, int dim, int64_t* nmodes, int iflag, int ntr, float eps,
+                    typename plan_type<float>::type* plan, nufft_opts* opts) {
+  return finufftf_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
+}
+
+template <>
+int makeplan<double>(int type, int dim, int64_t* nmodes, int iflag, int ntr, double eps,
+                     typename plan_type<double>::type* plan, nufft_opts* opts) {
+  return finufft_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
+}
+
+template <typename T>
+int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
+           T* u);
+
+template <>
+int setpts<float>(typename plan_type<float>::type plan, int64_t M, float* x, float* y, float* z,
+                  int64_t N, float* s, float* t, float* u) {
+  return finufftf_setpts(plan, M, x, y, z, N, s, t, u);
+}
+
+template <>
+int setpts<double>(typename plan_type<double>::type plan, int64_t M, double* x, double* y,
+                   double* z, int64_t N, double* s, double* t, double* u) {
+  return finufft_setpts(plan, M, x, y, z, N, s, t, u);
+}
+
+template <typename T>
+int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);
+
+template <>
+int execute<float>(typename plan_type<float>::type plan, std::complex<float>* c,
+                   std::complex<float>* f) {
+  return finufftf_execute(plan, c, f);
+}
+
+template <>
+int execute<double>(typename plan_type<double>::type plan, std::complex<double>* c,
+                    std::complex<double>* f) {
+  return finufft_execute(plan, c, f);
+}
+
+template <typename T>
+void destroy(typename plan_type<T>::type plan);
+
+template <>
+void destroy<float>(typename plan_type<float>::type plan) {
+  finufftf_destroy(plan);
+}
+
+template <>
+void destroy<double>(typename plan_type<double>::type plan) {
+  finufft_destroy(plan);
+}
+
+template <int ndim, typename T>
+T* y_index(T* y, int64_t index) {
+  return &(y[index]);
+}
+
+template <>
+double* y_index<1, double>(double* y, int64_t index) {
+  return NULL;
+}
+
+template <>
+float* y_index<1, float>(float* y, int64_t index) {
+  return NULL;
+}
+
+template <int ndim, typename T>
+T* z_index(T* z, int64_t index) {
+  return NULL;
+}
+
+template <>
+double* z_index<3, double>(double* z, int64_t index) {
+  return &(z[index]);
+}
+
+template <>
+float* z_index<3, float>(float* z, int64_t index) {
+  return &(z[index]);
+}
+*/
+
+}  // namespace jax_finufft
+
+#endif
diff --git a/lib/kernel_helpers.h b/lib/kernel_helpers.h
index 766a4bf..c0daac2 100644
--- a/lib/kernel_helpers.h
+++ b/lib/kernel_helpers.h
@@ -11,6 +11,8 @@
 #include <string>
 #include <type_traits>
 
+#include "jax_finufft_common.h"
+
 namespace jax_finufft {
 
 // https://en.cppreference.com/w/cpp/numeric/bit_cast
diff --git a/lib/kernels.cc.cu b/lib/kernels.cc.cu
index ea9370a..4062cd5 100644
--- a/lib/kernels.cc.cu
+++ b/lib/kernels.cc.cu
@@ -1,4 +1,4 @@
-#include "jax_finufft.h"
+#include "jax_finufft_gpu.h"
 #include "kernels.h"
 #include "kernel_helpers.h"
 
@@ -10,6 +10,7 @@ void ThrowIfError(cudaError_t error) {
   }
 }
 
+
 template <int ndim, typename T>
 void nufft1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
   const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(opaque, opaque_len);
diff --git a/lib/pybind11_kernel_helpers.h b/lib/pybind11_kernel_helpers.h
index d227482..11147cd 100644
--- a/lib/pybind11_kernel_helpers.h
+++ b/lib/pybind11_kernel_helpers.h
@@ -9,7 +9,6 @@
 
 #include <pybind11/pybind11.h>
 
-#include "jax_finufft.h"
 #include "kernel_helpers.h"
 
 namespace jax_finufft {

From 1b12da0bb44f02a516f59d666de24c4b0fd302c3 Mon Sep 17 00:00:00 2001
From: Lehman Garrison
Date: Fri, 12 Nov 2021 11:38:04 -0500
Subject: [PATCH 07/12] Maybe sort-of calling all the right functions?
---
 lib/jax_finufft_gpu.h | 51 +++++++++++++++++++++++--------------------
 lib/kernels.cc.cu     | 33 ++++++++++++++++++++++++++--
 2 files changed, 58 insertions(+), 26 deletions(-)

diff --git a/lib/jax_finufft_gpu.h b/lib/jax_finufft_gpu.h
index 178034f..8858b05 100644
--- a/lib/jax_finufft_gpu.h
+++ b/lib/jax_finufft_gpu.h
@@ -1,5 +1,5 @@
-#ifndef _JAX_FINUFFT_H_
-#define _JAX_FINUFFT_H_
+#ifndef _JAX_FINUFFT_GPU_H_
+#define _JAX_FINUFFT_GPU_H_
 
 #include <complex>
 
@@ -14,39 +14,39 @@ template <>
 struct plan_type<double> {
   typedef cufinufft_plan type;
 };
-/*
+
 template <>
 struct plan_type<float> {
-  typedef finufftf_plan type;
+  typedef cufinufftf_plan type;
 };
 
 template <typename T>
-void default_opts(nufft_opts* opts);
+void default_opts(int type, int dim, cufinufft_opts* opts);
 
 template <>
-void default_opts<float>(nufft_opts* opts) {
-  finufftf_default_opts(opts);
+void default_opts<float>(int type, int dim, cufinufft_opts* opts) {
+  cufinufftf_default_opts(type, dim, opts);
 }
 
 template <>
-void default_opts<double>(nufft_opts* opts) {
-  finufft_default_opts(opts);
+void default_opts<double>(int type, int dim, cufinufft_opts* opts) {
+  cufinufft_default_opts(type, dim, opts);
 }
 
 template <typename T>
-int makeplan(int type, int dim, int64_t* nmodes, int iflag, int ntr, T eps,
-             typename plan_type<T>::type* plan, nufft_opts* opts);
+int makeplan(int type, int dim, int* nmodes, int iflag, int ntr, T eps, int batch,
+             typename plan_type<T>::type* plan, cufinufft_opts* opts);
 
 template <>
-int makeplan<float>(int type, int dim, int64_t* nmodes, int iflag, int ntr, float eps,
-                    typename plan_type<float>::type* plan, nufft_opts* opts) {
-  return finufftf_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
+int makeplan<float>(int type, int dim, int* nmodes, int iflag, int ntr, float eps, int batch,
+                    typename plan_type<float>::type* plan, cufinufft_opts* opts) {
+  return cufinufftf_makeplan(type, dim, nmodes, iflag, ntr, eps, batch, plan, opts);
 }
 
 template <>
-int makeplan<double>(int type, int dim, int64_t* nmodes, int iflag, int ntr, double eps,
-                     typename plan_type<double>::type* plan, nufft_opts* opts) {
-  return finufft_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
+int makeplan<double>(int type, int dim, int* nmodes, int iflag, int ntr, double eps, int batch,
+                     typename plan_type<double>::type* plan, cufinufft_opts* opts) {
+  return cufinufft_makeplan(type, dim, nmodes, iflag, ntr, eps, batch, plan, opts);
 }
 
 template <typename T>
@@ -56,13 +56,13 @@ int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t
 template <>
 int setpts<float>(typename plan_type<float>::type plan, int64_t M, float* x, float* y, float* z,
                   int64_t N, float* s, float* t, float* u) {
-  return finufftf_setpts(plan, M, x, y, z, N, s, t, u);
+  return cufinufftf_setpts(M, x, y, z, N, s, t, u, plan);
 }
 
 template <>
 int setpts<double>(typename plan_type<double>::type plan, int64_t M, double* x, double* y,
                    double* z, int64_t N, double* s, double* t, double* u) {
-  return finufft_setpts(plan, M, x, y, z, N, s, t, u);
+  return cufinufft_setpts(M, x, y, z, N, s, t, u, plan);
 }
 
 template <typename T>
@@ -71,13 +71,17 @@ int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T
 template <>
 int execute<float>(typename plan_type<float>::type plan, std::complex<float>* c,
                    std::complex<float>* f) {
-  return finufftf_execute(plan, c, f);
+  cuFloatComplex* _c = reinterpret_cast<cuFloatComplex*>(c);
+  cuFloatComplex* _f = reinterpret_cast<cuFloatComplex*>(f);
+  return cufinufftf_execute(_c, _f, plan);
 }
 
 template <>
 int execute<double>(typename plan_type<double>::type plan, std::complex<double>* c,
                     std::complex<double>* f) {
-  return finufft_execute(plan, c, f);
+  cuDoubleComplex* _c = reinterpret_cast<cuDoubleComplex*>(c);
+  cuDoubleComplex* _f = reinterpret_cast<cuDoubleComplex*>(f);
+  return cufinufft_execute(_c, _f, plan);
 }
 
 template <typename T>
@@ -85,12 +89,12 @@ void destroy(typename plan_type<T>::type plan);
 
 template <>
 void destroy<float>(typename plan_type<float>::type plan) {
-  finufftf_destroy(plan);
+  cufinufftf_destroy(plan);
 }
 
 template <>
 void destroy<double>(typename plan_type<double>::type plan) {
-  finufft_destroy(plan);
+  cufinufft_destroy(plan);
 }
 
 template <int ndim, typename T>
@@ -122,7 +126,6 @@ template <>
 float* z_index<3, float>(float* z, int64_t index) {
   return &(z[index]);
 }
-*/
 
 }  // namespace jax_finufft
 
diff --git a/lib/kernels.cc.cu b/lib/kernels.cc.cu
index 4062cd5..1eb28b1 100644
--- a/lib/kernels.cc.cu
+++ b/lib/kernels.cc.cu
@@ -9,6 +9,35 @@ void ThrowIfError(cudaError_t error) {
     throw std::runtime_error(cudaGetErrorString(error));
   }
 }
+
+template <int ndim, typename T>
+void run_nufft(int type, const NufftDescriptor<T>* descriptor, T *x, T *y, T *z, std::complex<T> *c, std::complex<T> *F) {
+  int64_t n_k = 1;
+  for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];
+
+  // TODO: okay to stack-allocate this?
+  int nmodes32[ndim];
+  for (int d = 0; d < ndim; ++d) nmodes32[d] = static_cast<int>(descriptor->n_k[d]);
+
+  // TODO: does this need to be part of NufftDescriptor? It's GPU-specific.
+  int maxbatchsize = 0;  // auto
+  cufinufft_opts *opts = new cufinufft_opts;
+  typename plan_type<T>::type plan;
+  default_opts<T>(type, ndim, opts);
+  makeplan<T>(type, ndim, nmodes32, descriptor->iflag,
+              descriptor->n_transf, descriptor->eps, maxbatchsize, &plan, opts);
+  for (int64_t index = 0; index < descriptor->n_tot; ++index) {
+    int64_t j = index * descriptor->n_j * descriptor->n_transf;
+    int64_t k = index * n_k * descriptor->n_transf;
+
+    setpts<T>(plan, descriptor->n_j, &(x[j]), y_index<ndim, T>(y, j), z_index<ndim, T>(z, j), 0,
+              NULL, NULL, NULL);
+
+    execute<T>(plan, &c[j], &F[k]);
+  }
+  destroy<T>(plan);
+  delete opts;
+}
 
 
 template <int ndim, typename T>
@@ -31,7 +60,7 @@ void nufft1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t
   std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);
 
   // Call cuFINUFFT here...
-  // run_nufft(1, in[0], x, y, z, c, F);
+  run_nufft<ndim, T>(1, descriptor, x, y, z, c, F);
 
   ThrowIfError(cudaGetLastError());
 }
@@ -56,7 +85,7 @@ void nufft2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t
   std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);
 
   // Call cuFINUFFT here...
-  // run_nufft(1, in[0], x, y, z, c, F);
+  run_nufft<ndim, T>(2, descriptor, x, y, z, c, F);
 
   ThrowIfError(cudaGetLastError());
 }

From fe79e974b0227b62ce5eebe1b78b78528b0a21d4 Mon Sep 17 00:00:00 2001
From: Lehman Garrison
Date: Mon, 15 Nov 2021 14:09:45 -0500
Subject: [PATCH 08/12] Add FindCUDAToolkit to cmake to bring in cufft

---
 CMakeLists.txt        | 3 +++
 tests/gpu_ops_test.py | 3 +++
 2 files changed, 6 insertions(+)
 create mode 100644 tests/gpu_ops_test.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f3bf28e..f4b3fd3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -69,6 +69,8 @@ if (CMAKE_CUDA_COMPILER)
   if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
     set(CMAKE_CUDA_ARCHITECTURES "52;60;61;70;75")
   endif()
+
+  find_package(CUDAToolkit)
 
   set(CUFINUFFT_INCLUDE_DIRS
       ${CMAKE_CURRENT_LIST_DIR}/lib
@@ -115,6 +117,7 @@ if (CMAKE_CUDA_COMPILER)
     ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
     ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
   target_link_libraries(jax_finufft_gpu PRIVATE cufinufft cufinufft_32 cufinufft_64)
+  target_link_libraries(jax_finufft_gpu PRIVATE ${CUDA_cufft_LIBRARY} ${CUDA_nvToolsExt_LIBRARY})
   target_include_directories(jax_finufft_gpu PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
   install(TARGETS jax_finufft_gpu DESTINATION .)
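A note on the cuFFT linking added above: ${CUDA_cufft_LIBRARY} and ${CUDA_nvToolsExt_LIBRARY} are old FindCUDA-style variables, so this relies on them being populated at configure time. With CMake 3.17 or newer, find_package(CUDAToolkit) also defines imported targets that carry their own include directories and usage requirements. A hedged sketch of the equivalent, not part of the patch, assuming those imported targets are available:

# Sketch only, not part of the patch: link cuFFT and NVTX through the
# imported targets defined by find_package(CUDAToolkit) in CMake >= 3.17.
find_package(CUDAToolkit REQUIRED)
target_link_libraries(jax_finufft_gpu PRIVATE CUDA::cufft CUDA::nvToolsExt)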
diff --git a/tests/gpu_ops_test.py b/tests/gpu_ops_test.py
new file mode 100644
index 0000000..e02ba6b
--- /dev/null
+++ b/tests/gpu_ops_test.py
@@ -0,0 +1,3 @@
+def test_import():
+    from jax_finufft import jax_finufft_gpu
+    #print(vars(jax_finufft_gpu))

From 1295a1cadd63c5eff5111c86e2e486c7e89e59d5 Mon Sep 17 00:00:00 2001
From: Lehman Garrison
Date: Mon, 15 Nov 2021 16:29:41 -0500
Subject: [PATCH 09/12] Trying to hook up Jax CUDA ops

---
 .github/workflows/tests.yml |   2 +-
 lib/jax_finufft_gpu.cc      |   1 +
 src/jax_finufft/__init__.py |   3 +-
 src/jax_finufft/gpu_ops.py  | 326 ++++++++++++++++++++++++++++++++++++
 tests/gpu_ops_test.py       | 204 +++++++++++++++++++++-
 5 files changed, 531 insertions(+), 5 deletions(-)
 create mode 100644 src/jax_finufft/gpu_ops.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 480a99a..6d503d7 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -40,4 +40,4 @@ jobs:
           python -m pip install .[test]
 
       - name: Run tests
-        run: python -m pytest -v tests
+        run: python -m pytest -v tests --ignore='tests/gpu_ops_test.py'
diff --git a/lib/jax_finufft_gpu.cc b/lib/jax_finufft_gpu.cc
index 8547430..56cd704 100644
--- a/lib/jax_finufft_gpu.cc
+++ b/lib/jax_finufft_gpu.cc
@@ -13,6 +13,7 @@ namespace {
 pybind11::dict Registrations() {
   pybind11::dict dict;
 
+  // TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
   // dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
   // dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
   dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
diff --git a/src/jax_finufft/__init__.py b/src/jax_finufft/__init__.py
index 931e67b..edec140 100644
--- a/src/jax_finufft/__init__.py
+++ b/src/jax_finufft/__init__.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["__version__", "nufft1", "nufft2"]
+__all__ = ["__version__", "nufft1", "nufft2", "cunufft1", "cunufft2"]
 
 from .jax_finufft_version import version as __version__
 from .ops import nufft1, nufft2
+from .gpu_ops import cunufft1, cunufft2
diff --git a/src/jax_finufft/gpu_ops.py b/src/jax_finufft/gpu_ops.py
new file mode 100644
index 0000000..70ba8bc
--- /dev/null
+++ b/src/jax_finufft/gpu_ops.py
@@ -0,0 +1,326 @@
+__all__ = ["cunufft1", "cunufft2"]
+
+from functools import partial, reduce
+
+import numpy as np
+from jax import core, dtypes, jit
+from jax import numpy as jnp
+from jax.abstract_arrays import ShapedArray
+from jax.interpreters import ad, batching, xla
+from jax.lib import xla_client
+
+from .jax_finufft_cpu import build_descriptor, build_descriptorf
+from . import jax_finufft_gpu
+
+for _name, _value in jax_finufft_gpu.registrations().items():
+    xla_client.register_custom_call_target(_name, _value, platform='CUDA')
+
+xops = xla_client.ops
+
+
+@partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps"))
+def cunufft1(output_shape, source, *points, iflag=1, eps=1e-6):
+    iflag = int(iflag)
+    eps = float(eps)
+    ndim = len(points)
+    if not 1 <= ndim <= 3:
+        raise ValueError("Only 1-, 2-, and 3-dimensions are supported")
+
+    # Support passing a scalar output_shape
+    output_shape = np.atleast_1d(output_shape).astype(np.int64)
+    if len(output_shape) != ndim:
+        raise ValueError(f"output_shape must have shape: ({ndim},)")
+
+    # Handle broadcasting
+    expected_output_shape = source.shape[:-1] + tuple(output_shape)
+    source, points = pad_shapes(1, source, *points)
+    if points[0].shape[-1] != source.shape[-1]:
+        raise ValueError("The final dimension of 'source' must match 'points'")
+
+    return jnp.reshape(
+        nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
+        expected_output_shape,
+    )
+
+
+@partial(jit, static_argnames=("iflag", "eps"))
+def cunufft2(source, *points, iflag=-1, eps=1e-6):
+    iflag = int(iflag)
+    eps = float(eps)
+    ndim = len(points)
+    if not 1 <= ndim <= 3:
+        raise ValueError("Only 1-, 2-, and 3-dimensions are supported")
+
+    # Handle broadcasting
+    expected_output_shape = source.shape[:-ndim]
+    source, points = pad_shapes(ndim, source, *points)
+    expected_output_shape = expected_output_shape + (points[0].shape[-1],)
+
+    return jnp.reshape(
+        nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
+        expected_output_shape,
+    )
+
+
+def get_output_shape(type_, source_shape, *points_shape, output_shape):
+    if type_ == 1:
+        ndim = len(points_shape)
+        assert len(output_shape) == ndim
+        assert len(points_shape[0]) >= 2
+        assert all(
+            x[-1] == source_shape[-1] and x[:-1] == source_shape[:-2]
+            for x in points_shape
+        )
+        return tuple(source_shape[:-1]) + tuple(output_shape)
+
+    elif type_ == 2:
+        ndim = len(points_shape)
+        assert len(points_shape[0]) >= 2
+        assert all(x[:-1] == source_shape[: -ndim - 1] for x in points_shape)
+        return tuple(source_shape[:-ndim]) + (points_shape[0][-1],)
+
+    raise ValueError(f"Unsupported transformation type: {type_}")
+
+
+def abstract_eval(type_, source, *points, output_shape, **_):
+    ndim = len(points)
+    assert 1 <= ndim <= 3
+
+    source_dtype = dtypes.canonicalize_dtype(source.dtype)
+    points_dtype = [dtypes.canonicalize_dtype(x.dtype) for x in points]
+
+    # Check supported and consistent dtypes
+    single = source_dtype == np.csingle and all(x == np.single for x in points_dtype)
+    double = source_dtype == np.cdouble and all(x == np.double for x in points_dtype)
+    assert single or double
+
+    return ShapedArray(
+        get_output_shape(
+            type_, source.shape, *(x.shape for x in points), output_shape=output_shape
+        ),
+        source_dtype,
+    )
+
+
+def translation_rule(
+    type_, ctx, avals_in, avals_out, source, *points, output_shape, iflag, eps
+):
+    ndim = len(points)
+    assert 1 <= ndim <= 3
+
+    c = ctx.builder
+    source_shape_info = c.get_shape(source)
+    points_shape_info = list(map(c.get_shape, points))
+
+    # Check supported and consistent dtypes
+    source_dtype = source_shape_info.element_type()
+    single = source_dtype == np.csingle and all(
+        x.element_type() == np.single for x in points_shape_info
+    )
+    double = source_dtype == np.cdouble and all(
+        x.element_type() == np.double for x in points_shape_info
+    )
+    assert single or double
+
+    # Check shapes
+    source_shape = source_shape_info.dimensions()
+    points_shape = tuple(x.dimensions() for x in points_shape_info)
+    full_output_shape = get_output_shape(
+        type_, source_shape, *points_shape, output_shape=output_shape
+    )
+
+    # Work out the other dimensions of the problem
+    n_j = np.array(points_shape[0][-1]).astype(np.int64)
+    if type_ == 1:
+        n_tot = np.prod(source_shape[:-2]).astype(np.int64)
+        n_transf = np.array(source_shape[-2]).astype(np.int32)
+        n_k = np.array(full_output_shape[-ndim:], dtype=np.int64)
+    else:
+        n_tot = np.prod(source_shape[: -ndim - 1]).astype(np.int64)
+        n_transf = np.array(source_shape[-ndim - 1]).astype(np.int32)
+        n_k = np.array(source_shape[-ndim:], dtype=np.int64)
+
+    # The backend expects the output shape in Fortran order so we'll just
+    # fake it here, by sending in n_k and x in the reverse order.
+    n_k_full = np.zeros(3, dtype=np.int64)
+    n_k_full[:ndim] = n_k[::-1]
+
+    # Dispatch to the right op
+    suffix = "f" if source_dtype == np.csingle else ""
+    op_name = f"nufft{ndim}d{type_}{suffix}".encode("ascii")
+    desc = globals()[f"build_descriptor{suffix}"](
+        eps, iflag, n_tot, n_transf, n_j, *n_k_full
+    )
+
+    return [
+        xops.CustomCallWithLayout(
+            c,
+            op_name,
+            # The inputs:
+            operands=(
+                xops.ConstantLiteral(c, np.frombuffer(desc, dtype=np.uint8)),
+                source,
+                *points[::-1],  # Reverse order because backend uses Fortran order
+            ),
+            # The input shapes:
+            operand_shapes_with_layout=(
+                xla_client.Shape.array_shape(np.dtype(np.uint8), (len(desc),), (0,)),
+                xla_client.Shape.array_shape(
+                    source_dtype,
+                    source_shape,
+                    tuple(range(len(source_shape) - 1, -1, -1)),
+                ),
+            )
+            + tuple(
+                xla_client.Shape.array_shape(
+                    x.element_type(),
+                    x.dimensions(),
+                    tuple(range(len(x.dimensions()) - 1, -1, -1)),
+                )
+                for x in points_shape_info[::-1]  # Reverse order, again
+            ),
+            # The output shapes:
+            shape_with_layout=xla_client.Shape.array_shape(
+                source_dtype,
+                full_output_shape,
+                tuple(range(len(full_output_shape) - 1, -1, -1)),
+            ),
+        )
+    ]
+
+
+def jvp(type_, prim, args, tangents, *, output_shape, iflag, eps):
+    # Type 1:
+    # f_k = sum_j c_j * exp(iflag * i * k * x_j)
+    # df_k/dx_j = iflag * i * k * c_j * exp(iflag * i * k * x_j)
+
+    # Type 2:
+    # c_j = sum_k f_k * exp(iflag * i * k * x_j)
+    # dc_j/dx_j = sum_k iflag * i * k * f_k * exp(iflag * i * k * x_j)
+
+    source, *points = args
+    dsource, *dpoints = tangents
+    output = prim.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps)
+
+    # The JVP op can be written as a single transform of the same type with
+    output_tangents = []
+    ndim = len(points)
+    scales = []
+    arguments = []
+    if type(dsource) is not ad.Zero:
+        if type_ == 1:
+            scales.append(jnp.ones_like(output))
+            arguments.append(dsource)
+        else:
+            output_tangents.append(
+                prim.bind(
+                    dsource, *points, output_shape=output_shape, iflag=iflag, eps=eps
+                )
+            )
+
+    for dim, dx in enumerate(dpoints):
+        if type(dx) is ad.Zero:
+            continue
+
+        n = output_shape[dim] if type_ == 1 else source.shape[-ndim + dim]
+        shape = np.ones(ndim, dtype=int)
+        shape[dim] = -1
+        k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1))
+        k = k.reshape(shape)
+        factor = 1j * iflag * k
+
+        if type_ == 1:
+            scales.append(factor)
+            arguments.append(dx * source)
+        else:
+            scales.append(dx)
+            arguments.append(factor * source)
+
+    if len(scales):
+        axis = -2 if type_ == 1 else -ndim - 1
+        output_tangent = prim.bind(
+            jnp.concatenate(arguments, axis=axis),
+            *points,
+            output_shape=output_shape,
+            iflag=iflag,
+            eps=eps,
+        )
+
+        axis = -2 if type_ == 2 else -ndim - 1
+        output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
+
+        expand_shape = (
+            output.shape[: axis + 1] + (len(scales),) + output.shape[axis + 1 :]
+        )
+        output_tangents.append(
+            jnp.sum(jnp.reshape(output_tangent, expand_shape), axis=axis)
+        )
+
+    return output, reduce(ad.add_tangents, output_tangents, ad.Zero.from_value(output))
+
+
+def transpose(type_, doutput, source, *points, output_shape, eps, iflag):
+    assert ad.is_undefined_primal(source)
+    assert not any(map(ad.is_undefined_primal, points))
+    assert type(doutput) is not ad.Zero
+
+    if type_ == 1:
+        result = nufft2(doutput, *points, eps=eps, iflag=iflag)
+    else:
+        ndim = len(points)
+        result = nufft1(
+            source.aval.shape[-ndim:], doutput, *points, eps=eps, iflag=iflag
+        )
+
+    return (result,) + tuple(None for _ in range(len(points)))
+
+
+def batch(type_, prim, args, axes, **kwargs):
+    ndim = len(args) - 1
+    if type_ == 1:
+        mx = args[0].ndim - 2
+    else:
+        mx = args[0].ndim - ndim - 1
+    assert all(a < mx for a in axes)
+    assert all(a == axes[0] for a in axes[1:])
+    return prim.bind(*args, **kwargs), axes[0]
+
+
+def pad_shapes(output_dim, source, *points):
+    points = jnp.broadcast_arrays(*points)
+    if points[0].ndim == 0 or source.ndim == 0:
+        raise ValueError(
+            "0-dimensional arrays are not supported; are you vmap-ing somewhere "
+            "where you don't want to?"
+        )
+
+    if points[0].ndim == source.ndim - output_dim + 1:
+        new_shape = source.shape[:-output_dim] + (1,) + source.shape[-output_dim:]
+        source = jnp.reshape(source, new_shape)
+    if points[0].ndim != source.ndim - output_dim:
+        raise ValueError(
+            f"'source' must have {output_dim} more dimension than 'points'"
+        )
+    if source.ndim == output_dim + 1:
+        source = source[None, ...]
+        points = tuple(x[None, :] for x in points)
+
+    return source, points
+
+
+nufft1_p = core.Primitive("nufft1")
+nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p))
+nufft1_p.def_abstract_eval(partial(abstract_eval, 1))
+xla.register_translation(nufft1_p, partial(translation_rule, 1), platform="CUDA")
+ad.primitive_jvps[nufft1_p] = partial(jvp, 1, nufft1_p)
+ad.primitive_transposes[nufft1_p] = partial(transpose, 1)
+batching.primitive_batchers[nufft1_p] = partial(batch, 1, nufft1_p)
+
+
+nufft2_p = core.Primitive("nufft2")
+nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p))
+nufft2_p.def_abstract_eval(partial(abstract_eval, 2))
+xla.register_translation(nufft2_p, partial(translation_rule, 2), platform="CUDA")
+ad.primitive_jvps[nufft2_p] = partial(jvp, 2, nufft2_p)
+ad.primitive_transposes[nufft2_p] = partial(transpose, 2)
+batching.primitive_batchers[nufft2_p] = partial(batch, 2, nufft2_p)
diff --git a/tests/gpu_ops_test.py b/tests/gpu_ops_test.py
index e02ba6b..577d6e5 100644
--- a/tests/gpu_ops_test.py
+++ b/tests/gpu_ops_test.py
@@ -1,3 +1,201 @@
-def test_import():
-    from jax_finufft import jax_finufft_gpu
-    #print(vars(jax_finufft_gpu))
+from functools import partial
+from itertools import product
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from jax.test_util import check_grads
+
+from jax_finufft import cunufft1, cunufft2
+
+# TODO: decide on naming and update tests
+nufft1 = cunufft1
+nufft2 = cunufft2
+
+
+@pytest.mark.parametrize(
+    "ndim, x64, num_nonnuniform, num_uniform, iflag",
+    product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
+)
+def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
+    random = np.random.default_rng(657)
+
+    eps = 1e-10 if x64 else 1e-7
+    dtype = np.double if x64 else np.single
+    cdtype = np.cdouble if x64 else np.csingle
+
+    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
+    ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]
+
+    x = [
+        random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
+        for _ in range(ndim)
+    ]
+    x_vec = np.array(x)
+    c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
+    c = c.astype(cdtype)
+    f_expect = np.zeros(num_uniform, dtype=cdtype)
+    for coords in product(*map(range, num_uniform)):
+        k_vec = np.array([k[n] for (n, k) in zip(coords, ks)])
+        f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x_vec)))
+
+    with jax.experimental.enable_x64(x64):
+        f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag)
+        np.testing.assert_allclose(f_calc, f_expect, rtol=5e-7 if x64 else 5e-2)
+
+        f_calc = jax.jit(nufft1, static_argnums=(0,), static_argnames=("eps", "iflag"))(
+            num_uniform, c, *x, eps=eps, iflag=iflag
+        )
+        np.testing.assert_allclose(f_calc, f_expect, rtol=5e-7 if x64 else 5e-2)
+
+
+@pytest.mark.parametrize(
+    "ndim, x64, num_nonnuniform, num_uniform, iflag",
+    product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
+)
+def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
+    random = np.random.default_rng(657)
+
+    eps = 1e-10 if x64 else 1e-7
+    dtype = np.double if x64 else np.single
+    cdtype = np.cdouble if x64 else np.csingle
+
+    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
+    ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]
+    x = [
+        random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
+        for _ in range(ndim)
+    ]
+    f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform)
+    f = f.astype(cdtype)
+
+    c_expect = np.zeros(num_nonnuniform, dtype=cdtype)
+    for n in range(num_nonnuniform):
+        arg = np.copy(f)
+        for i, k in enumerate(ks):
+            coords = [None for _ in range(ndim)]
+            coords[i] = slice(None)
+            arg *= np.exp(iflag * 1j * k * x[i][n])[tuple(coords)]
+        c_expect[n] = np.sum(arg)
+
+    with jax.experimental.enable_x64(x64):
+        c_calc = nufft2(f, *x, eps=eps, iflag=iflag)
+        np.testing.assert_allclose(c_calc, c_expect, rtol=5e-7 if x64 else 5e-2)
+
+        c_calc = jax.jit(nufft2, static_argnames=("eps", "iflag"))(
+            f, *x, eps=eps, iflag=iflag
+        )
+        np.testing.assert_allclose(c_calc, c_expect, rtol=5e-7 if x64 else 5e-2)
+
+
+@pytest.mark.parametrize(
+    "ndim, num_nonnuniform, num_uniform, iflag",
+    product([1, 2, 3], [50], [35], [-1, 1]),
+)
+def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag):
+    random = np.random.default_rng(657)
+
+    eps = 1e-10
+    dtype = np.double
+    cdtype = np.cdouble
+
+    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
+
+    x = [
+        random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
+        for _ in range(ndim)
+    ]
+    c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
+    c = c.astype(cdtype)
+
+    with jax.experimental.enable_x64():
+        func = partial(nufft1, num_uniform, eps=eps, iflag=iflag)
+        check_grads(func, (c, *x), 1, modes=("fwd", "rev"))
+
+
+@pytest.mark.parametrize(
+    "ndim, num_nonnuniform, num_uniform, iflag",
+    product([1, 2, 3], [50], [35], [-1, 1]),
+)
+def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag):
+    random = np.random.default_rng(657)
+
+    eps = 1e-10
+    dtype = np.double
+    cdtype = np.cdouble
+
+    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
+
+    x = [
+        random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
+        for _ in range(ndim)
+    ]
+    f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform)
+    f = f.astype(cdtype)
+
+    with jax.experimental.enable_x64():
+        func = partial(nufft2, eps=eps, iflag=iflag)
+        check_grads(func, (f, *x), 1, modes=("fwd", "rev"))
+
+
+@pytest.mark.parametrize(
+    "ndim, num_nonnuniform, num_uniform, iflag",
+    product([1, 2, 3], [50], [35], [-1, 1]),
+)
+def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
+    random = np.random.default_rng(657)
+
+    eps = 1e-10
+    dtype = np.double
+    cdtype = np.cdouble
+
+    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
+
+    x = [
+        random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
+        for _ in range(ndim)
+    ]
+    c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
+    c = c.astype(cdtype)
+
+    num = 5
+    xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x]
+    cs = jnp.repeat(c[None], num, axis=0)
+
+    func = partial(nufft1, num_uniform, eps=eps, iflag=iflag)
+    calc = jax.vmap(func)(cs, *xs)
+    expect = func(c, *x)
+    for n in range(num):
+        np.testing.assert_allclose(calc[n], expect)
+
+
+@pytest.mark.parametrize(
+    "ndim, num_nonnuniform, num_uniform, iflag",
+    product([1, 2, 3], [50], [35], [-1, 1]),
+)
+def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):
+    random = np.random.default_rng(657)
+
+    eps = 1e-10
+    dtype = np.double
+    cdtype = np.cdouble
+
+    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
+
+    x = [
+        random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
+        for _ in range(ndim)
+    ]
+    f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform)
+    f = f.astype(cdtype)
+
+    num = 5
+    xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x]
+    fs = jnp.repeat(f[None], num, axis=0)
+
+    func = partial(nufft2, eps=eps, iflag=iflag)
+    calc = jax.vmap(func)(fs, *xs)
+    expect = func(f, *x)
+    for n in range(num):
+        np.testing.assert_allclose(calc[n], expect)

From 93ba4b3373ae8c8a9f630cec3f898a5527c075f8 Mon Sep 17 00:00:00 2001
From: Lehman Garrison
Date: Mon, 15 Nov 2021 16:34:45 -0500
Subject: [PATCH 10/12] Don't fail on no CUDA

---
 src/jax_finufft/__init__.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/jax_finufft/__init__.py b/src/jax_finufft/__init__.py
index edec140..b446d3f 100644
--- a/src/jax_finufft/__init__.py
+++ b/src/jax_finufft/__init__.py
@@ -16,4 +16,10 @@
 
 from .jax_finufft_version import version as __version__
 from .ops import nufft1, nufft2
-from .gpu_ops import cunufft1, cunufft2
+
+try:
+    # TODO: how to know when we can import GPU ops?
+    from .gpu_ops import cunufft1, cunufft2
+except ImportError as e:
+    import warnings
+    warnings.warn(f"Could not import GPU extensions due to:\n{e}")

From 65f399c4c51a565200d657da0d8d196846c503e5 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Tue, 16 Nov 2021 13:06:20 -0500
Subject: [PATCH 11/12] first pass at getting GPU ops to work

---
 src/jax_finufft/__init__.py |   9 +-
 src/jax_finufft/gpu_ops.py  | 326 ------------------------------------
 src/jax_finufft/ops.py      | 118 ++++++++-----
 tests/gpu_ops_test.py       | 201 ----------------------
 tests/ops_test.py           |  18 ++
 5 files changed, 99 insertions(+), 573 deletions(-)
 delete mode 100644 src/jax_finufft/gpu_ops.py
 delete mode 100644 tests/gpu_ops_test.py

diff --git a/src/jax_finufft/__init__.py b/src/jax_finufft/__init__.py
index b446d3f..931e67b 100644
--- a/src/jax_finufft/__init__.py
+++ b/src/jax_finufft/__init__.py
@@ -12,14 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["__version__", "nufft1", "nufft2", "cunufft1", "cunufft2"]
+__all__ = ["__version__", "nufft1", "nufft2"]
 
 from .jax_finufft_version import version as __version__
 from .ops import nufft1, nufft2
-
-try:
-    # TODO: how to know when we can import GPU ops?
-    from .gpu_ops import cunufft1, cunufft2
-except ImportError as e:
-    import warnings
-    warnings.warn(f"Could not import GPU extensions due to:\n{e}")
diff --git a/src/jax_finufft/gpu_ops.py b/src/jax_finufft/gpu_ops.py
deleted file mode 100644
index 70ba8bc..0000000
--- a/src/jax_finufft/gpu_ops.py
+++ /dev/null
@@ -1,326 +0,0 @@
-__all__ = ["cunufft1", "cunufft2"]
-
-from functools import partial, reduce
-
-import numpy as np
-from jax import core, dtypes, jit
-from jax import numpy as jnp
-from jax.abstract_arrays import ShapedArray
-from jax.interpreters import ad, batching, xla
-from jax.lib import xla_client
-
-from .jax_finufft_cpu import build_descriptor, build_descriptorf
-from . import jax_finufft_gpu
-
-for _name, _value in jax_finufft_gpu.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform='CUDA')
-
-xops = xla_client.ops
-
-
-@partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps"))
-def cunufft1(output_shape, source, *points, iflag=1, eps=1e-6):
-    iflag = int(iflag)
-    eps = float(eps)
-    ndim = len(points)
-    if not 1 <= ndim <= 3:
-        raise ValueError("Only 1-, 2-, and 3-dimensions are supported")
-
-    # Support passing a scalar output_shape
-    output_shape = np.atleast_1d(output_shape).astype(np.int64)
-    if len(output_shape) != ndim:
-        raise ValueError(f"output_shape must have shape: ({ndim},)")
-
-    # Handle broadcasting
-    expected_output_shape = source.shape[:-1] + tuple(output_shape)
-    source, points = pad_shapes(1, source, *points)
-    if points[0].shape[-1] != source.shape[-1]:
-        raise ValueError("The final dimension of 'source' must match 'points'")
-
-    return jnp.reshape(
-        nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
-        expected_output_shape,
-    )
-
-
-@partial(jit, static_argnames=("iflag", "eps"))
-def cunufft2(source, *points, iflag=-1, eps=1e-6):
-    iflag = int(iflag)
-    eps = float(eps)
-    ndim = len(points)
-    if not 1 <= ndim <= 3:
-        raise ValueError("Only 1-, 2-, and 3-dimensions are supported")
-
-    # Handle broadcasting
-    expected_output_shape = source.shape[:-ndim]
-    source, points = pad_shapes(ndim, source, *points)
-    expected_output_shape = expected_output_shape + (points[0].shape[-1],)
-
-    return jnp.reshape(
-        nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
-        expected_output_shape,
-    )
-
-
-def get_output_shape(type_, source_shape, *points_shape, output_shape):
-    if type_ == 1:
-        ndim = len(points_shape)
-        assert len(output_shape) == ndim
-        assert len(points_shape[0]) >= 2
-        assert all(
-            x[-1] == source_shape[-1] and x[:-1] == source_shape[:-2]
-            for x in points_shape
-        )
-        return tuple(source_shape[:-1]) + tuple(output_shape)
-
-    elif type_ == 2:
-        ndim = len(points_shape)
-        assert len(points_shape[0]) >= 2
-        assert all(x[:-1] == source_shape[: -ndim - 1] for x in points_shape)
-        return tuple(source_shape[:-ndim]) + (points_shape[0][-1],)
-
-    raise ValueError(f"Unsupported transformation type: {type_}")
-
-
-def abstract_eval(type_, source, *points, output_shape, **_):
-    ndim = len(points)
-    assert 1 <= ndim <= 3
-
-    source_dtype = dtypes.canonicalize_dtype(source.dtype)
-    points_dtype = [dtypes.canonicalize_dtype(x.dtype) for x in points]
-
-    # Check supported and consistent dtypes
-    single = source_dtype == np.csingle and all(x == np.single for x in points_dtype)
-    double = source_dtype == np.cdouble and all(x == np.double for x in points_dtype)
-    assert single or double
-
-    return ShapedArray(
-        get_output_shape(
-            type_, source.shape, *(x.shape for x in points), output_shape=output_shape
-        ),
-        source_dtype,
-    )
-
-
-def translation_rule(
-    type_, ctx, avals_in, avals_out, source, *points, output_shape, iflag, eps
-):
-    ndim = len(points)
-    assert 1 <= ndim <= 3
-
-    c = ctx.builder
-    source_shape_info = c.get_shape(source)
-    points_shape_info = list(map(c.get_shape, points))
-
-    # Check supported and consistent dtypes
-    source_dtype = source_shape_info.element_type()
-    single = source_dtype == np.csingle and all(
-        x.element_type() == np.single for x in points_shape_info
-    )
-    double = source_dtype == np.cdouble and all(
-        x.element_type() == np.double for x in points_shape_info
-    )
-    assert single or double
-
-    # Check shapes
source_shape = source_shape_info.dimensions() - points_shape = tuple(x.dimensions() for x in points_shape_info) - full_output_shape = get_output_shape( - type_, source_shape, *points_shape, output_shape=output_shape - ) - - # Work out the other dimenstions of the problem - n_j = np.array(points_shape[0][-1]).astype(np.int64) - if type_ == 1: - n_tot = np.prod(source_shape[:-2]).astype(np.int64) - n_transf = np.array(source_shape[-2]).astype(np.int32) - n_k = np.array(full_output_shape[-ndim:], dtype=np.int64) - else: - n_tot = np.prod(source_shape[: -ndim - 1]).astype(np.int64) - n_transf = np.array(source_shape[-ndim - 1]).astype(np.int32) - n_k = np.array(source_shape[-ndim:], dtype=np.int64) - - # The backend expects the output shape in Fortran order so we'll just - # fake it here, by sending in n_k and x in the reverse order. - n_k_full = np.zeros(3, dtype=np.int64) - n_k_full[:ndim] = n_k[::-1] - - # Dispatch to the right op - suffix = "f" if source_dtype == np.csingle else "" - op_name = f"nufft{ndim}d{type_}{suffix}".encode("ascii") - desc = globals()[f"build_descriptor{suffix}"]( - eps, iflag, n_tot, n_transf, n_j, *n_k_full - ) - - return [ - xops.CustomCallWithLayout( - c, - op_name, - # The inputs: - operands=( - xops.ConstantLiteral(c, np.frombuffer(desc, dtype=np.uint8)), - source, - *points[::-1], # Reverse order because backend uses Fortran order - ), - # The input shapes: - operand_shapes_with_layout=( - xla_client.Shape.array_shape(np.dtype(np.uint8), (len(desc),), (0,)), - xla_client.Shape.array_shape( - source_dtype, - source_shape, - tuple(range(len(source_shape) - 1, -1, -1)), - ), - ) - + tuple( - xla_client.Shape.array_shape( - x.element_type(), - x.dimensions(), - tuple(range(len(x.dimensions()) - 1, -1, -1)), - ) - for x in points_shape_info[::-1] # Reverse order, again - ), - # The output shapes: - shape_with_layout=xla_client.Shape.array_shape( - source_dtype, - full_output_shape, - tuple(range(len(full_output_shape) - 1, -1, -1)), - ), - ) - ] - - -def jvp(type_, prim, args, tangents, *, output_shape, iflag, eps): - # Type 1: - # f_k = sum_j c_j * exp(iflag * i * k * x_j) - # df_k/dx_j = iflag * i * k * c_j * exp(iflag * i * k * x_j) - - # Type 2: - # c_j = sum_k f_k * exp(iflag * i * k * x_j) - # dc_j/dx_j = sum_k iflag * i * k * f_k * exp(iflag * i * k * x_j) - - source, *points = args - dsource, *dpoints = tangents - output = prim.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps) - - # The JVP op can be written as a single transform of the same type with - output_tangents = [] - ndim = len(points) - scales = [] - arguments = [] - if type(dsource) is not ad.Zero: - if type_ == 1: - scales.append(jnp.ones_like(output)) - arguments.append(dsource) - else: - output_tangents.append( - prim.bind( - dsource, *points, output_shape=output_shape, iflag=iflag, eps=eps - ) - ) - - for dim, dx in enumerate(dpoints): - if type(dx) is ad.Zero: - continue - - n = output_shape[dim] if type_ == 1 else source.shape[-ndim + dim] - shape = np.ones(ndim, dtype=int) - shape[dim] = -1 - k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) - k = k.reshape(shape) - factor = 1j * iflag * k - - if type_ == 1: - scales.append(factor) - arguments.append(dx * source) - else: - scales.append(dx) - arguments.append(factor * source) - - if len(scales): - axis = -2 if type_ == 1 else -ndim - 1 - output_tangent = prim.bind( - jnp.concatenate(arguments, axis=axis), - *points, - output_shape=output_shape, - iflag=iflag, - eps=eps, - ) - - axis = -2 if type_ == 2 else 
-ndim - 1 - output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis) - - expand_shape = ( - output.shape[: axis + 1] + (len(scales),) + output.shape[axis + 1 :] - ) - output_tangents.append( - jnp.sum(jnp.reshape(output_tangent, expand_shape), axis=axis) - ) - - return output, reduce(ad.add_tangents, output_tangents, ad.Zero.from_value(output)) - - -def transpose(type_, doutput, source, *points, output_shape, eps, iflag): - assert ad.is_undefined_primal(source) - assert not any(map(ad.is_undefined_primal, points)) - assert type(doutput) is not ad.Zero - - if type_ == 1: - result = nufft2(doutput, *points, eps=eps, iflag=iflag) - else: - ndim = len(points) - result = nufft1( - source.aval.shape[-ndim:], doutput, *points, eps=eps, iflag=iflag - ) - - return (result,) + tuple(None for _ in range(len(points))) - - -def batch(type_, prim, args, axes, **kwargs): - ndim = len(args) - 1 - if type_ == 1: - mx = args[0].ndim - 2 - else: - mx = args[0].ndim - ndim - 1 - assert all(a < mx for a in axes) - assert all(a == axes[0] for a in axes[1:]) - return prim.bind(*args, **kwargs), axes[0] - - -def pad_shapes(output_dim, source, *points): - points = jnp.broadcast_arrays(*points) - if points[0].ndim == 0 or source.ndim == 0: - raise ValueError( - "0-dimensional arrays are not supported; are you vmap-ing somewhere " - "where you don't want to?" - ) - - if points[0].ndim == source.ndim - output_dim + 1: - new_shape = source.shape[:-output_dim] + (1,) + source.shape[-output_dim:] - source = jnp.reshape(source, new_shape) - if points[0].ndim != source.ndim - output_dim: - raise ValueError( - f"'source' must have {output_dim} more dimension than 'points'" - ) - if source.ndim == output_dim + 1: - source = source[None, ...] - points = tuple(x[None, :] for x in points) - - return source, points - - -nufft1_p = core.Primitive("nufft1") -nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p)) -nufft1_p.def_abstract_eval(partial(abstract_eval, 1)) -xla.register_translation(nufft1_p, partial(translation_rule, 1), platform="CUDA") -ad.primitive_jvps[nufft1_p] = partial(jvp, 1, nufft1_p) -ad.primitive_transposes[nufft1_p] = partial(transpose, 1) -batching.primitive_batchers[nufft1_p] = partial(batch, 1, nufft1_p) - - -nufft2_p = core.Primitive("nufft2") -nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p)) -nufft2_p.def_abstract_eval(partial(abstract_eval, 2)) -xla.register_translation(nufft2_p, partial(translation_rule, 2), platform="CUDA") -ad.primitive_jvps[nufft2_p] = partial(jvp, 2, nufft2_p) -ad.primitive_transposes[nufft2_p] = partial(transpose, 2) -batching.primitive_batchers[nufft2_p] = partial(batch, 2, nufft2_p) diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index 75834a1..ddec9f3 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -11,8 +11,16 @@ from . import jax_finufft_cpu +try: + from . 
import jax_finufft_gpu + + for _name, _value in jax_finufft_gpu.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="CUDA") +except ImportError: + jax_finufft_gpu = None + for _name, _value in jax_finufft_cpu.registrations().items(): - xla_client.register_cpu_custom_call_target(_name, _value) + xla_client.register_custom_call_target(_name, _value, platform="cpu") xops = xla_client.ops @@ -102,10 +110,15 @@ def abstract_eval(type_, source, *points, output_shape, **_): def translation_rule( - type_, ctx, avals_in, avals_out, source, *points, output_shape, iflag, eps + platform, type_, ctx, avals_in, avals_out, source, *points, output_shape, iflag, eps ): + if platform == "gpu" and jax_finufft_gpu is None: + raise ValueError("jax-finufft was not compiled with GPU support") + ndim = len(points) assert 1 <= ndim <= 3 + if platform == "gpu" and ndim == 1: + raise ValueError("1-D transforms are not yet supported on the GPU") c = ctx.builder source_shape_info = c.get_shape(source) @@ -151,41 +164,62 @@ def translation_rule( eps, iflag, n_tot, n_transf, n_j, *n_k_full ) - return [ - xops.CustomCallWithLayout( - c, - op_name, - # The inputs: - operands=( - xops.ConstantLiteral(c, np.frombuffer(desc, dtype=np.uint8)), - source, - *points[::-1], # Reverse order because backend uses Fortran order - ), - # The input shapes: - operand_shapes_with_layout=( - xla_client.Shape.array_shape(np.dtype(np.uint8), (len(desc),), (0,)), - xla_client.Shape.array_shape( - source_dtype, - source_shape, - tuple(range(len(source_shape) - 1, -1, -1)), - ), - ) - + tuple( - xla_client.Shape.array_shape( - x.element_type(), - x.dimensions(), - tuple(range(len(x.dimensions()) - 1, -1, -1)), - ) - for x in points_shape_info[::-1] # Reverse order, again - ), - # The output shapes: - shape_with_layout=xla_client.Shape.array_shape( - source_dtype, - full_output_shape, - tuple(range(len(full_output_shape) - 1, -1, -1)), - ), + # Set up most of the arguments + operands = ( + source, + *points[::-1], # Reverse order because backend uses Fortran order + ) + operand_shapes_with_layout = ( + xla_client.Shape.array_shape( + source_dtype, + source_shape, + tuple(range(len(source_shape) - 1, -1, -1)), + ), + ) + tuple( + xla_client.Shape.array_shape( + x.element_type(), + x.dimensions(), + tuple(range(len(x.dimensions()) - 1, -1, -1)), ) - ] + for x in points_shape_info[::-1] # Reverse order, again + ) + shape_with_layout = xla_client.Shape.array_shape( + source_dtype, + full_output_shape, + tuple(range(len(full_output_shape) - 1, -1, -1)), + ) + + if platform == "cpu": + return [ + xops.CustomCallWithLayout( + c, + op_name, + operands=(xops.ConstantLiteral(c, np.frombuffer(desc, dtype=np.uint8)),) + + operands, + operand_shapes_with_layout=( + xla_client.Shape.array_shape( + np.dtype(np.uint8), (len(desc),), (0,) + ), + ) + + operand_shapes_with_layout, + shape_with_layout=shape_with_layout, + ) + ] + + elif platform == "gpu": + return [ + xops.CustomCallWithLayout( + c, + op_name, + operands=operands, + operand_shapes_with_layout=operand_shapes_with_layout, + shape_with_layout=shape_with_layout, + opaque=desc, + ) + ] + + else: + raise ValueError(f"Unrecognized platform '{platform}'") def jvp(type_, prim, args, tangents, *, output_shape, iflag, eps): @@ -310,7 +344,11 @@ def pad_shapes(output_dim, source, *points): nufft1_p = core.Primitive("nufft1") nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p)) nufft1_p.def_abstract_eval(partial(abstract_eval, 1)) -xla.register_translation(nufft1_p, 
partial(translation_rule, 1), platform="cpu") +xla.register_translation(nufft1_p, partial(translation_rule, "cpu", 1), platform="cpu") +if jax_finufft_gpu is not None: + xla.register_translation( + nufft1_p, partial(translation_rule, "gpu", 1), platform="gpu" + ) ad.primitive_jvps[nufft1_p] = partial(jvp, 1, nufft1_p) ad.primitive_transposes[nufft1_p] = partial(transpose, 1) batching.primitive_batchers[nufft1_p] = partial(batch, 1, nufft1_p) @@ -319,7 +357,11 @@ def pad_shapes(output_dim, source, *points): nufft2_p = core.Primitive("nufft2") nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p)) nufft2_p.def_abstract_eval(partial(abstract_eval, 2)) -xla.register_translation(nufft2_p, partial(translation_rule, 2), platform="cpu") +xla.register_translation(nufft2_p, partial(translation_rule, "cpu", 2), platform="cpu") +if jax_finufft_gpu is not None: + xla.register_translation( + nufft2_p, partial(translation_rule, "gpu", 2), platform="gpu" + ) ad.primitive_jvps[nufft2_p] = partial(jvp, 2, nufft2_p) ad.primitive_transposes[nufft2_p] = partial(transpose, 2) batching.primitive_batchers[nufft2_p] = partial(batch, 2, nufft2_p) diff --git a/tests/gpu_ops_test.py b/tests/gpu_ops_test.py deleted file mode 100644 index 577d6e5..0000000 --- a/tests/gpu_ops_test.py +++ /dev/null @@ -1,201 +0,0 @@ -from functools import partial -from itertools import product - -import jax -import jax.numpy as jnp -import numpy as np -import pytest -from jax.test_util import check_grads - -from jax_finufft import cunufft1, cunufft2 - -# TODO: decide on naming and update tests -nufft1 = cunufft1 -nufft2 = cunufft2 - - -@pytest.mark.parametrize( - "ndim, x64, num_nonnuniform, num_uniform, iflag", - product([1, 2, 3], [False, True], [50], [75], [-1, 1]), -) -def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): - random = np.random.default_rng(657) - - eps = 1e-10 if x64 else 1e-7 - dtype = np.double if x64 else np.single - cdtype = np.cdouble if x64 else np.csingle - - num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) - ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform] - - x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - x_vec = np.array(x) - c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform) - c = c.astype(cdtype) - f_expect = np.zeros(num_uniform, dtype=cdtype) - for coords in product(*map(range, num_uniform)): - k_vec = np.array([k[n] for (n, k) in zip(coords, ks)]) - f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x_vec))) - - with jax.experimental.enable_x64(x64): - f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag) - np.testing.assert_allclose(f_calc, f_expect, rtol=5e-7 if x64 else 5e-2) - - f_calc = jax.jit(nufft1, static_argnums=(0,), static_argnames=("eps", "iflag"))( - num_uniform, c, *x, eps=eps, iflag=iflag - ) - np.testing.assert_allclose(f_calc, f_expect, rtol=5e-7 if x64 else 5e-2) - - -@pytest.mark.parametrize( - "ndim, x64, num_nonnuniform, num_uniform, iflag", - product([1, 2, 3], [False, True], [50], [75], [-1, 1]), -) -def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): - random = np.random.default_rng(657) - - eps = 1e-10 if x64 else 1e-7 - dtype = np.double if x64 else np.single - cdtype = np.cdouble if x64 else np.csingle - - num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) - ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform] - x = [ - random.uniform(-np.pi, 
np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform) - f = f.astype(cdtype) - - c_expect = np.zeros(num_nonnuniform, dtype=cdtype) - for n in range(num_nonnuniform): - arg = np.copy(f) - for i, k in enumerate(ks): - coords = [None for _ in range(ndim)] - coords[i] = slice(None) - arg *= np.exp(iflag * 1j * k * x[i][n])[tuple(coords)] - c_expect[n] = np.sum(arg) - - with jax.experimental.enable_x64(x64): - c_calc = nufft2(f, *x, eps=eps, iflag=iflag) - np.testing.assert_allclose(c_calc, c_expect, rtol=5e-7 if x64 else 5e-2) - - c_calc = jax.jit(nufft2, static_argnames=("eps", "iflag"))( - f, *x, eps=eps, iflag=iflag - ) - np.testing.assert_allclose(c_calc, c_expect, rtol=5e-7 if x64 else 5e-2) - - -@pytest.mark.parametrize( - "ndim, num_nonnuniform, num_uniform, iflag", - product([1, 2, 3], [50], [35], [-1, 1]), -) -def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag): - random = np.random.default_rng(657) - - eps = 1e-10 - dtype = np.double - cdtype = np.cdouble - - num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) - - x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform) - c = c.astype(cdtype) - - with jax.experimental.enable_x64(): - func = partial(nufft1, num_uniform, eps=eps, iflag=iflag) - check_grads(func, (c, *x), 1, modes=("fwd", "rev")) - - -@pytest.mark.parametrize( - "ndim, num_nonnuniform, num_uniform, iflag", - product([1, 2, 3], [50], [35], [-1, 1]), -) -def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag): - random = np.random.default_rng(657) - - eps = 1e-10 - dtype = np.double - cdtype = np.cdouble - - num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) - - x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform) - f = f.astype(cdtype) - - with jax.experimental.enable_x64(): - func = partial(nufft2, eps=eps, iflag=iflag) - check_grads(func, (f, *x), 1, modes=("fwd", "rev")) - - -@pytest.mark.parametrize( - "ndim, num_nonnuniform, num_uniform, iflag", - product([1, 2, 3], [50], [35], [-1, 1]), -) -def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag): - random = np.random.default_rng(657) - - eps = 1e-10 - dtype = np.double - cdtype = np.cdouble - - num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) - - x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform) - c = c.astype(cdtype) - - num = 5 - xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] - cs = jnp.repeat(c[None], num, axis=0) - - func = partial(nufft1, num_uniform, eps=eps, iflag=iflag) - calc = jax.vmap(func)(cs, *xs) - expect = func(c, *x) - for n in range(num): - np.testing.assert_allclose(calc[n], expect) - - -@pytest.mark.parametrize( - "ndim, num_nonnuniform, num_uniform, iflag", - product([1, 2, 3], [50], [35], [-1, 1]), -) -def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag): - random = np.random.default_rng(657) - - eps = 1e-10 - dtype = np.double - cdtype = np.cdouble - - num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) - - x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - f = 
random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform) - f = f.astype(cdtype) - - num = 5 - xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] - fs = jnp.repeat(f[None], num, axis=0) - - func = partial(nufft2, eps=eps, iflag=iflag) - calc = jax.vmap(func)(fs, *xs) - expect = func(f, *x) - for n in range(num): - np.testing.assert_allclose(calc[n], expect) diff --git a/tests/ops_test.py b/tests/ops_test.py index ff3e205..4022b8f 100644 --- a/tests/ops_test.py +++ b/tests/ops_test.py @@ -15,6 +15,9 @@ product([1, 2, 3], [False, True], [50], [75], [-1, 1]), ) def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): + if ndim == 1 and jax.default_backend() != "cpu": + pytest.skip("1D transforms not implemented on GPU") + random = np.random.default_rng(657) eps = 1e-10 if x64 else 1e-7 @@ -51,6 +54,9 @@ def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): product([1, 2, 3], [False, True], [50], [75], [-1, 1]), ) def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): + if ndim == 1 and jax.default_backend() != "cpu": + pytest.skip("1D transforms not implemented on GPU") + random = np.random.default_rng(657) eps = 1e-10 if x64 else 1e-7 @@ -90,6 +96,9 @@ def test_nufft2_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): product([1, 2, 3], [50], [35], [-1, 1]), ) def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag): + if ndim == 1 and jax.default_backend() != "cpu": + pytest.skip("1D transforms not implemented on GPU") + random = np.random.default_rng(657) eps = 1e-10 @@ -115,6 +124,9 @@ def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag): product([1, 2, 3], [50], [35], [-1, 1]), ) def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag): + if ndim == 1 and jax.default_backend() != "cpu": + pytest.skip("1D transforms not implemented on GPU") + random = np.random.default_rng(657) eps = 1e-10 @@ -140,6 +152,9 @@ def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag): product([1, 2, 3], [50], [35], [-1, 1]), ) def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag): + if ndim == 1 and jax.default_backend() != "cpu": + pytest.skip("1D transforms not implemented on GPU") + random = np.random.default_rng(657) eps = 1e-10 @@ -171,6 +186,9 @@ def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag): product([1, 2, 3], [50], [35], [-1, 1]), ) def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag): + if ndim == 1 and jax.default_backend() != "cpu": + pytest.skip("1D transforms not implemented on GPU") + random = np.random.default_rng(657) eps = 1e-10 From 1d94e9e90471a449ae16653049a92ce4799773ce Mon Sep 17 00:00:00 2001 From: Lehman Garrison Date: Thu, 18 Nov 2021 15:14:44 -0500 Subject: [PATCH 12/12] Fix GPU tests --- .github/workflows/tests.yml | 2 +- CMakeLists.txt | 3 ++- lib/jax_finufft_gpu.h | 7 +++++++ lib/kernels.cc.cu | 8 +++---- tests/ops_test.py | 42 ++++++++++++++++++------------------- 5 files changed, 33 insertions(+), 29 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6d503d7..480a99a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,4 +40,4 @@ jobs: python -m pip install .[test] - name: Run tests - run: python -m pytest -v tests --ignore='tests/gpu_ops_test.py' + run: python -m pytest -v tests diff --git a/CMakeLists.txt b/CMakeLists.txt index f4b3fd3..3c13737 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,9 +67,10 @@ if (CMAKE_CUDA_COMPILER) enable_language(CUDA) 
   set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
   if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-    set(CMAKE_CUDA_ARCHITECTURES "52;60;61;70;75")
+    set(CMAKE_CUDA_ARCHITECTURES "52;60;61;70;75;80")
   endif()
 
+  # Find cufft
   find_package(CUDAToolkit)
 
   set(CUFINUFFT_INCLUDE_DIRS
diff --git a/lib/jax_finufft_gpu.h b/lib/jax_finufft_gpu.h
index 8858b05..6b00ffa 100644
--- a/lib/jax_finufft_gpu.h
+++ b/lib/jax_finufft_gpu.h
@@ -31,6 +31,13 @@ void default_opts<float>(int type, int dim, cufinufft_opts* opts) {
 template <>
 void default_opts<double>(int type, int dim, cufinufft_opts* opts) {
   cufinufft_default_opts(type, dim, opts);
+
+  // double precision in 3D blows out shared memory.
+  // Fall back to a slower, non-shared memory algorithm
+  // https://github.com/flatironinstitute/cufinufft/issues/58
+  if(dim > 2){
+    opts->gpu_method = 1;
+  }
 }
 
 template <typename T>
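For context: the gpu_method fallback added above is what unblocks 3-D double-precision transforms through the Python API. A minimal sketch of the call it enables, mirroring the tolerances used in the x64 tests (not part of the patch itself; it assumes a CUDA-enabled build of jax-finufft and jaxlib):

    import jax
    import numpy as np

    from jax_finufft import nufft1

    # 3-D double-precision type-1 transform: without the gpu_method = 1
    # fallback above, cufinufft's shared-memory spreader runs out of
    # resources for this case.
    with jax.experimental.enable_x64():
        rng = np.random.default_rng(0)
        x, y, z = rng.uniform(-np.pi, np.pi, size=(3, 100))
        c = rng.normal(size=100) + 1j * rng.normal(size=100)
        f = nufft1((16, 18, 20), c, x, y, z, eps=1e-10, iflag=1)
        print(f.shape)  # (16, 18, 20)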
- run_nufft(1, descriptor, x, y, z, c, F); + run_nufft(2, descriptor, x, y, z, c, F); ThrowIfError(cudaGetLastError()); } diff --git a/tests/ops_test.py b/tests/ops_test.py index 4022b8f..2d2511d 100644 --- a/tests/ops_test.py +++ b/tests/ops_test.py @@ -27,17 +27,13 @@ def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag): num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform] - x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) - for _ in range(ndim) - ] - x_vec = np.array(x) + x = random.uniform(-np.pi, np.pi, size=(ndim,num_nonnuniform)).astype(dtype) c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform) c = c.astype(cdtype) f_expect = np.zeros(num_uniform, dtype=cdtype) for coords in product(*map(range, num_uniform)): k_vec = np.array([k[n] for (n, k) in zip(coords, ks)]) - f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x_vec))) + f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x))) with jax.experimental.enable_x64(x64): f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag) @@ -170,15 +166,16 @@ def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag): c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform) c = c.astype(cdtype) - num = 5 - xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] - cs = jnp.repeat(c[None], num, axis=0) + with jax.experimental.enable_x64(): + num = 5 + xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] + cs = jnp.repeat(c[None], num, axis=0) - func = partial(nufft1, num_uniform, eps=eps, iflag=iflag) - calc = jax.vmap(func)(cs, *xs) - expect = func(c, *x) - for n in range(num): - np.testing.assert_allclose(calc[n], expect) + func = partial(nufft1, num_uniform, eps=eps, iflag=iflag) + calc = jax.vmap(func)(cs, *xs) + expect = func(c, *x) + for n in range(num): + np.testing.assert_allclose(calc[n], expect) @pytest.mark.parametrize( @@ -204,12 +201,13 @@ def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag): f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform) f = f.astype(cdtype) - num = 5 - xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] - fs = jnp.repeat(f[None], num, axis=0) + with jax.experimental.enable_x64(): + num = 5 + xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] + fs = jnp.repeat(f[None], num, axis=0) - func = partial(nufft2, eps=eps, iflag=iflag) - calc = jax.vmap(func)(fs, *xs) - expect = func(f, *x) - for n in range(num): - np.testing.assert_allclose(calc[n], expect) + func = partial(nufft2, eps=eps, iflag=iflag) + calc = jax.vmap(func)(fs, *xs) + expect = func(f, *x) + for n in range(num): + np.testing.assert_allclose(calc[n], expect)