
Commit 015a31f

Author: Jesse Grabowski (committed)

Implement L_op instead of grad in Eigh

1 parent: 271bf5c · commit: 015a31f

File tree: 2 files changed, +19 / -6 lines


pytensor/tensor/nlinalg.py (5 additions, 6 deletions)

@@ -346,7 +346,7 @@ def infer_shape(self, fgraph, node, shapes):
         n = shapes[0][0]
         return [(n,), (n, n)]

-    def L_op(self, inputs, outputs, grad_outputs):
+    def L_op(self, inputs, outputs, output_grads):
         raise NotImplementedError(
             "Gradients for Eig is not implemented because it always returns complex values, "
             "for which autodiff is not yet supported in PyTensor (PRs welcome :) ).\n"
@@ -404,7 +404,7 @@ def perform(self, node, inputs, outputs):
         (w, v) = outputs
         w[0], v[0] = np.linalg.eigh(x, self.UPLO)

-    def grad(self, inputs, g_outputs):
+    def L_op(self, inputs, outputs, output_grads):
         r"""The gradient function should return

         .. math:: \sum_n\left(W_n\frac{\partial\,w_n}
@@ -428,10 +428,9 @@ def grad(self, inputs, g_outputs):

         """
         (x,) = inputs
-        w, v = self(x)
-        # Replace gradients wrt disconnected variables with
-        # zeros. This is a work-around for issue #1063.
-        gw, gv = _zero_disconnected([w, v], g_outputs)
+        w, v = outputs
+        gw, gv = _zero_disconnected([w, v], output_grads)
+
         return [EighGrad(self.UPLO)(x, w, v, gw, gv)]

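The substance of this change is the switch from grad to L_op in Eigh: L_op also receives the node's already-computed outputs, so the removed `w, v = self(x)` re-application of the Op (and its disconnected-gradient workaround comment) is no longer needed and `w, v = outputs` reuses the variables already in the graph. The toy Op below is a minimal sketch of that signature difference under the standard PyTensor Op API; it is not part of this commit, and SquareOp and every name inside it are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class SquareOp(Op):
    """Hypothetical toy Op (y = x**2), used only to illustrate the L_op signature."""

    __props__ = ()

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = x**2

    def L_op(self, inputs, outputs, output_grads):
        # Unlike grad(inputs, output_grads), L_op also receives the node's
        # existing outputs, so the Op does not need to be re-applied -- the
        # `w, v = self(x)` pattern removed from Eigh above.
        (x,) = inputs
        (y,) = outputs          # y = x**2, already present in the graph
        (gy,) = output_grads
        return [2 * x * gy]


x = pt.vector("x")
y = SquareOp()(x)
g = pytensor.grad(y.sum(), x)        # pytensor.grad dispatches to each Op's L_op
f = pytensor.function([x], g)
print(f(np.array([1.0, 2.0, 3.0])))  # -> [2. 4. 6.]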
tests/tensor/test_nlinalg.py (14 additions, 0 deletions)

@@ -453,6 +453,20 @@ def test_eval(self):
 class TestEigh(TestEig):
     op = staticmethod(eigh)

+    def test_eval(self):
+        A = matrix(dtype=self.dtype)
+        fn = function([A], self.op(A))
+
+        # Symmetric input (real eigenvalues)
+        A_val = self.rng.normal(size=(5, 5)).astype(self.dtype)
+        A_val = A_val + A_val.T
+
+        w, v = fn(A_val)
+        w_np, v_np = np.linalg.eigh(A_val)
+        np.testing.assert_allclose(w, w_np)
+        np.testing.assert_allclose(v, v_np)
+        assert_array_almost_equal(np.dot(A_val, v), w * v)
+
     def test_uplo(self):
         S = self.S
         a = matrix(dtype=self.dtype)
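The added test only exercises the forward pass of eigh. As a companion, here is a small usage sketch, not part of the commit, of the path the L_op change serves: taking the gradient of a scalar function of eigh's outputs with respect to a symmetric input. The names x, loss, f, and the particular loss expression are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import eigh

x = pt.matrix("x")
w, v = eigh(x)                   # eigenvalues and eigenvectors of a symmetric matrix
loss = w.sum() + (v**2).sum()    # arbitrary scalar so a gradient is defined
g = pytensor.grad(loss, x)       # backward pass goes through Eigh.L_op / EighGrad
f = pytensor.function([x], g)

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
a = a + a.T                      # eigh assumes a symmetric input
print(f(a).shape)                # (4, 4): gradient with respect to the input matrix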
