Labels: NVIDIA GPU (issues specific to NVIDIA GPUs), performance (make things lean and fast)
Description
Currently JAX uses a simple heuristic for jax.lax.linalg.eigh on GPUs: for matrices with more than 32 columns it dispatches to the unbatched divide-and-conquer algorithm, and to the batched Jacobi algorithm otherwise. Since the divide-and-conquer implementation is not batched, this produces a steep performance cliff, as shown below (a ~1000x difference between 32x32 and 33x33 matrices). Perhaps the heuristic should account for the batch size, too?
import jax
import jax.numpy as jnp
import timeit

def run_and_time(n):
    a = jnp.ones((1024, n, n))
    run = lambda: jax.jit(jax.lax.linalg.eigh)(a)[0].block_until_ready()
    run()  # Warmup
    print(f"Time per batch (ms): {timeit.timeit(run, number=1) * 1000:.2f} (shape: {a.shape})")

for i in range(30, 35):
    run_and_time(i)
Result:
Time per batch (ms): 0.43 (shape: (1024, 30, 30))
Time per batch (ms): 0.38 (shape: (1024, 31, 31))
Time per batch (ms): 0.39 (shape: (1024, 32, 32))
Time per batch (ms): 562.58 (shape: (1024, 33, 33))
Time per batch (ms): 575.23 (shape: (1024, 34, 34))
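One way to address this would be a dispatch rule that weighs the batch size against the matrix size. The sketch below is purely illustrative: the function and threshold names are hypothetical, not JAX's actual internals, and the crossover values are placeholders that would need to be tuned by benchmarking. The idea is that the unbatched divide-and-conquer path effectively serializes one solve per matrix, so for large batches of moderately sized matrices the batched Jacobi path can remain faster well above 32 columns.

```python
# Hypothetical batch-aware dispatch rule for eigh on GPU.
# All names and thresholds here are illustrative assumptions, not JAX code.

def choose_eigh_algorithm(batch_size: int, n: int,
                          jacobi_max_n: int = 32,
                          batch_threshold: int = 8,
                          batched_jacobi_max_n: int = 512) -> str:
    """Pick an eigendecomposition kernel for a (batch_size, n, n) input."""
    if n <= jacobi_max_n:
        # Small matrices: batched Jacobi wins regardless of batch size.
        return "batched_jacobi"
    if batch_size >= batch_threshold and n <= batched_jacobi_max_n:
        # Large batch of moderate matrices: unbatched divide-and-conquer
        # would serialize batch_size solves, so prefer batched Jacobi.
        return "batched_jacobi"
    # Few large matrices: divide-and-conquer has better asymptotic cost.
    return "divide_and_conquer"
```

With a rule along these lines, the (1024, 33, 33) case from the benchmark above would stay on the batched Jacobi path instead of falling off the cliff, while a single large matrix would still use divide-and-conquer.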
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.2.dev20250823+ede5074c3
jaxlib: 0.7.2.dev20250827
numpy: 2.2.6
python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
device info: NVIDIA GB200-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='6.11.0-1013-nvidia-64k', version='#13-Ubuntu SMP PREEMPT_DYNAMIC Thu Jul 24 23:34:45 UTC 2025', machine='aarch64')
JAX_TOOLBOX_REF=main
XLA_FLAGS=
$ nvidia-smi
Wed Aug 27 04:47:26 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82 Driver Version: 580.82 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GB200 On | 00000008:01:00.0 Off | 0 |
| N/A 47C P0 242W / 1200W | 734MiB / 192527MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GB200 On | 00000009:01:00.0 Off | 0 |
| N/A 47C P0 226W / 1200W | 713MiB / 192527MiB | 3% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA GB200 On | 00000018:01:00.0 Off | 0 |
| N/A 48C P0 220W / 1200W | 710MiB / 192527MiB | 3% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA GB200 On | 00000019:01:00.0 Off | 0 |
| N/A 47C P0 233W / 1200W | 710MiB / 192527MiB | 2% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+