
Update pip wheels to use jax.ffi #175

@michael-0brien

Description

As of JAX 0.8.0, released a few days ago, it is necessary to use the jax.ffi interface. TL;DR: after a fresh install with pip, a call to nufft1 raises

AttributeError: mlir.custom_call was removed in JAX v0.8.0; use the APIs provided by jax.ffi instead

Details

Environment setup:

uv venv --python 3.11
uv pip install jax_finufft

Code:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Output of jax.print_environment_info():

jax:    0.8.0
jaxlib: 0.8.0
numpy:  2.3.4
python: 3.11.4 (main, Jun 15 2023, 07:55:38) [Clang 14.0.3 (clang-1403.0.22.14.1)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node=..., release='24.6.0', version='Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:51 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T8112', machine='arm64')
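Until wheels built against jax.ffi are published, it may help to detect the affected JAX versions before the first nufft call. The following is a minimal sketch and is not part of jax_finufft. The helpers `parse_release` and `legacy_custom_call_removed` are hypothetical. The only fact the sketch encodes is the cutoff from the error above: mlir.custom_call was removed in JAX v0.8.0. It reads the installed version through `importlib.metadata` rather than importing jax.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Per the AttributeError quoted in this issue, mlir.custom_call was removed in JAX v0.8.0.
REMOVED_IN = (0, 8, 0)


def parse_release(ver: str) -> tuple[int, ...]:
    """Return the leading numeric release components, e.g. '0.8.0rc1' -> (0, 8, 0)."""
    match = re.match(r"^(\d+(?:\.\d+)*)", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def legacy_custom_call_removed(jax_version: str) -> bool:
    """True if this JAX version is at or past the release that removed mlir.custom_call.

    Pre-releases such as '0.8.0rc1' are treated as 0.8.0, i.e. conservatively
    assumed to lack the old API.
    """
    release = parse_release(jax_version)
    # Pad to three components so that '0.8' compares equal to (0, 8, 0).
    release = release + (0,) * (3 - len(release))
    return release >= REMOVED_IN


if __name__ == "__main__":
    try:
        installed = version("jax")
    except PackageNotFoundError:
        print("jax is not installed")
    else:
        if legacy_custom_call_removed(installed):
            print(f"jax {installed}: mlir.custom_call is gone; extensions must use jax.ffi")
        else:
            print(f"jax {installed}: predates the removal of mlir.custom_call")
```

Comparing the numeric release tuple avoids a dependency on `packaging`. The trade-off is that local or post-release suffixes are ignored.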