Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ endif()
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(nanobind CONFIG REQUIRED)

# Find XLA FFI headers using JAX's official API
execute_process(
COMMAND "${Python_EXECUTABLE}" "-c" "from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_FFI_INCLUDE_DIR
RESULT_VARIABLE JAX_FFI_RESULT
)

if(NOT JAX_FFI_RESULT EQUAL 0)
message(FATAL_ERROR "Could not find JAX FFI headers. Please install last version of jax with: pip install jax")
endif()
message(STATUS "Found XLA FFI headers at: ${XLA_FFI_INCLUDE_DIR}")

if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
Expand All @@ -87,7 +100,7 @@ if (NOT FFTW_INCLUDE_DIRS)
get_target_property(FFTW_INCLUDE_DIRS fftw3 INTERFACE_INCLUDE_DIRECTORIES)
endif()

target_include_directories(jax_finufft_cpu PRIVATE ${FFTW_INCLUDE_DIRS})
target_include_directories(jax_finufft_cpu PRIVATE ${FFTW_INCLUDE_DIRS} ${XLA_FFI_INCLUDE_DIR})
install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .)

if(FINUFFT_USE_OPENMP)
Expand Down Expand Up @@ -115,6 +128,7 @@ if(FINUFFT_USE_CUDA)
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
target_include_directories(jax_finufft_gpu PUBLIC ${CUFINUFFT_INCLUDE_DIRS})
target_include_directories(jax_finufft_gpu PUBLIC ${CUFINUFFT_VENDORED_INCLUDE_DIRS})
target_include_directories(jax_finufft_gpu PRIVATE ${XLA_FFI_INCLUDE_DIR})
target_link_libraries(jax_finufft_gpu PRIVATE cufinufft)
install(TARGETS jax_finufft_gpu LIBRARY DESTINATION .)
endif()
802 changes: 688 additions & 114 deletions lib/jax_finufft_cpu.cc

Large diffs are not rendered by default.

11 changes: 0 additions & 11 deletions lib/jax_finufft_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,6 @@ float* z_index<3, float>(float* z, int64_t index) {
return &(z[index]);
}

template <typename T>
struct descriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
finufft_opts opts;
};

} // namespace cpu
} // namespace jax_finufft

Expand Down
467 changes: 405 additions & 62 deletions lib/jax_finufft_gpu.cc

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions lib/kernel_helpers.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// This header is not specific to our application and you'll probably want something like this
// for any extension you're building. This includes the infrastructure needed to serialize
// descriptors that are used with the "opaque" parameter of the GPU custom call. In our example
// we'll use this parameter to pass the size of our problem.
// Helper utilities for serializing descriptors used with XLA custom calls.
// This provides the infrastructure for the "opaque" parameter used by the GPU
// backend which still uses the legacy API.

#ifndef _JAX_FINUFFT_KERNEL_HELPERS_H_
#define _JAX_FINUFFT_KERNEL_HELPERS_H_

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand All @@ -19,12 +19,12 @@ typename std::enable_if<sizeof(To) == sizeof(From) && std::is_trivially_copyable
std::is_trivially_copyable<To>::value,
To>::type
bit_cast(const From& src) noexcept {
static_assert(
std::is_trivially_constructible<To>::value,
"This implementation additionally requires destination type to be trivially constructible");
static_assert(std::is_trivially_constructible<To>::value,
"This implementation additionally requires destination type "
"to be trivially constructible");

To dst;
memcpy(&dst, &src, sizeof(To));
std::memcpy(&dst, &src, sizeof(To));
return dst;
}

Expand Down
Loading
Loading