jax.lax.linalg.eigh performance cliff for larger batch sizes (on GPUs) #31368

@jaro-sevcik

Description

Currently JAX uses a simple heuristic for jax.lax.linalg.eigh on GPUs: for matrices with more than 32 columns it uses an unbatched divide-and-conquer algorithm, and batched Jacobi otherwise. Since the divide-and-conquer implementation is not batched, this can result in a big performance cliff, as shown below (more than a 1000x difference between 32x32 and 33x33 matrices at batch size 1024). Perhaps the heuristic should be adjusted to account for batch size, too?

import jax
import jax.numpy as jnp
import timeit

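def run_and_time(n):
    # Batch of 1024 symmetric n x n matrices.
    a = jnp.ones((1024, n, n))
    # block_until_ready() waits for the device computation to finish before timing stops.
    run = lambda: jax.jit(jax.lax.linalg.eigh)(a)[0].block_until_ready()
    run()  # Warmup, so that compilation is not included in the measurement
    print(f"Time per batch (ms): {timeit.timeit(run, number=1) * 1000:.2f} (shape: {a.shape})")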

for i in range(30, 35):
    run_and_time(i)

Result:

Time per batch (ms): 0.43 (shape: (1024, 30, 30))
Time per batch (ms): 0.38 (shape: (1024, 31, 31))
Time per batch (ms): 0.39 (shape: (1024, 32, 32))
Time per batch (ms): 562.58 (shape: (1024, 33, 33))
Time per batch (ms): 575.23 (shape: (1024, 34, 34))
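
To make the suggestion concrete, here is a minimal sketch of what a batch-aware selection rule could look like. It is not JAX's actual dispatch code: the function names, the algorithm labels, and the `min_batch` and `max_batched_cols` thresholds are all hypothetical placeholders. Only the 32-column cutoff comes from the behavior described above. The sketch also assumes that the batched Jacobi path stays usable somewhat above 32 columns, which would need to be verified and benchmarked.

```python
# Hypothetical sketch of an algorithm-selection heuristic; not JAX's real code.

JACOBI_MAX_COLS = 32  # cutoff described in this issue


def current_choice(n: int) -> str:
    """Selection rule as described in this issue: depends on matrix size only."""
    if n <= JACOBI_MAX_COLS:
        return "batched_jacobi"
    return "unbatched_divide_and_conquer"


def batch_aware_choice(
    batch: int,
    n: int,
    min_batch: int = 16,         # assumed placeholder, needs benchmarking
    max_batched_cols: int = 64,  # assumed placeholder, needs benchmarking
) -> str:
    """Possible rule that also considers the batch size.

    `batch` is the product of all leading (batch) dimensions.
    """
    if n <= JACOBI_MAX_COLS:
        return "batched_jacobi"
    # For large batches of moderately sized matrices, one batched call may
    # beat `batch` sequential divide-and-conquer calls.
    if batch >= min_batch and n <= max_batched_cols:
        return "batched_jacobi"
    return "unbatched_divide_and_conquer"


if __name__ == "__main__":
    for shape in [(1024, 32), (1024, 33), (1, 33), (1024, 512)]:
        print(shape, current_choice(shape[1]), batch_aware_choice(*shape))
```

With these placeholder thresholds, the (1024, 33, 33) case from the measurements above would stay on the batched path, while a single 33x33 matrix or a batch of much larger matrices would still use divide-and-conquer. The right crossover points would have to come from measurements on the target GPUs.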

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2.dev20250823+ede5074c3
jaxlib: 0.7.2.dev20250827
numpy:  2.2.6
python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
device info: NVIDIA GB200-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='6.11.0-1013-nvidia-64k', version='#13-Ubuntu SMP PREEMPT_DYNAMIC Thu Jul 24 23:34:45 UTC 2025', machine='aarch64')
JAX_TOOLBOX_REF=main
XLA_FLAGS=

$ nvidia-smi
Wed Aug 27 04:47:26 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82                 Driver Version: 580.82         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GB200                   On  |   00000008:01:00.0 Off |                    0 |
| N/A   47C    P0            242W / 1200W |     734MiB / 192527MiB |      1%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GB200                   On  |   00000009:01:00.0 Off |                    0 |
| N/A   47C    P0            226W / 1200W |     713MiB / 192527MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GB200                   On  |   00000018:01:00.0 Off |                    0 |
| N/A   48C    P0            220W / 1200W |     710MiB / 192527MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GB200                   On  |   00000019:01:00.0 Off |                    0 |
| N/A   47C    P0            233W / 1200W |     710MiB / 192527MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

Metadata

Labels: NVIDIA GPU (issues specific to NVIDIA GPUs), performance (make things lean and fast)
