-
Notifications
You must be signed in to change notification settings - Fork 139
Fix shape issues in jax tridiagonal solve #1414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
a95e1e1
d0239a3
a093550
6924f4e
4ee85c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,21 @@ def solve(a, b): | |
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1) | ||
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1) | ||
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1) | ||
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower) | ||
# jax requires dl and du to have the same shape as d | ||
dl = jax.numpy.pad(dl, (1, 0)) | ||
du = jax.numpy.pad(du, (0, 1)) | ||
# if b is a vector, broadcast it to be a matrix | ||
benmaier marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
b_is_vec = len(b.shape) == 1 | ||
|
||
if b_is_vec: | ||
b = jax.numpy.expand_dims(b, -1) | ||
|
||
res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b) | ||
|
||
if b_is_vec: | ||
# if b is a vector, return a vector | ||
return res.flatten() | ||
benmaier marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
else: | ||
benmaier marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
return res | ||
|
||
else: | ||
if assume_a not in ("gen", "sym", "her", "pos"): | ||
|
Uh oh!
There was an error while loading. Please reload this page.