Skip to content

Conversation

jaro-sevcik
Copy link
Contributor

@jaro-sevcik jaro-sevcik commented Aug 27, 2025

Let us use a batched syev for jax.lax.linalg.eigh to speed up the function
for non-trivial batch sizes.

Fixes #31368

Depends on openxla/xla#30810

@hawkinsp
Copy link
Collaborator

This LGTM. Do you also want to change the C++ code to select jacobi always when passed "default" as the algorithm?

@hawkinsp
Copy link
Collaborator

There are some test failures. PTAL? Most look like small tolerance bumps but e.g. the rank deficient one seems large.

@jaro-sevcik jaro-sevcik force-pushed the use-jacobi-for-lax-eigh branch 2 times, most recently from e7590e9 to 370da03 Compare September 1, 2025 14:37
@jaro-sevcik jaro-sevcik changed the title [gpu] Always use the Jacobi algorithm for lax.linalg.eigh [gpu] Use batched syev for lax.linalg.eigh Sep 1, 2025
@jaro-sevcik jaro-sevcik force-pushed the use-jacobi-for-lax-eigh branch 2 times, most recently from a538331 to 3ed5f81 Compare September 3, 2025 14:29
Use syevBatched for eigh
@jaro-sevcik jaro-sevcik force-pushed the use-jacobi-for-lax-eigh branch from 3ed5f81 to 158d4bd Compare September 4, 2025 19:36
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Sep 10, 2025
@copybara-service copybara-service bot merged commit ee59f81 into jax-ml:main Sep 10, 2025
23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

jax.lax.linalg.eigh performance cliff for larger batch sizes (on GPUs)
2 participants