
Commit cd51b5f: Refactor QR

1 parent 2ab60b3

4 files changed: +512 -272 lines changed

pytensor/tensor/nlinalg.py

Lines changed: 0 additions & 171 deletions
@@ -5,15 +5,12 @@
 
 import numpy as np
 
-import pytensor.tensor as pt
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
-from pytensor.ifelse import ifelse
 from pytensor.npy_2_compat import normalize_axis_tuple
-from pytensor.raise_op import Assert
 from pytensor.tensor import TensorLike
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"):
     return Eigh(UPLO)(a)
 
 
-class QRFull(Op):
-    """
-    Full QR Decomposition.
-
-    Computes the QR decomposition of a matrix.
-    Factor the matrix a as qr, where q is orthonormal
-    and r is upper-triangular.
-
-    """
-
-    __props__ = ("mode",)
-
-    def __init__(self, mode):
-        self.mode = mode
-
-    def make_node(self, x):
-        x = as_tensor_variable(x)
-
-        assert x.ndim == 2, "The input of qr function should be a matrix."
-
-        in_dtype = x.type.numpy_dtype
-        out_dtype = np.dtype(f"f{in_dtype.itemsize}")
-
-        q = matrix(dtype=out_dtype)
-
-        if self.mode != "raw":
-            r = matrix(dtype=out_dtype)
-        else:
-            r = vector(dtype=out_dtype)
-
-        if self.mode != "r":
-            q = matrix(dtype=out_dtype)
-            outputs = [q, r]
-        else:
-            outputs = [r]
-
-        return Apply(self, [x], outputs)
-
-    def perform(self, node, inputs, outputs):
-        (x,) = inputs
-        assert x.ndim == 2, "The input of qr function should be a matrix."
-        res = np.linalg.qr(x, self.mode)
-        if self.mode != "r":
-            outputs[0][0], outputs[1][0] = res
-        else:
-            outputs[0][0] = res
-
-    def L_op(self, inputs, outputs, output_grads):
-        """
-        Reverse-mode gradient of the QR function.
-
-        References
-        ----------
-        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
-        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
-        """
-
-        from pytensor.tensor.slinalg import solve_triangular
-
-        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
-        m, n = A.shape
-
-        def _H(x: ptb.TensorVariable):
-            return x.conj().mT
-
-        def _copyltu(x: ptb.TensorVariable):
-            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))
-
-        if self.mode == "raw":
-            raise NotImplementedError("Gradient of qr not implemented for mode=raw")
-
-        elif self.mode == "r":
-            # We need all the components of the QR to compute the gradient of A even if we only
-            # use the upper triangular component in the cost function.
-            Q, R = qr(A, mode="reduced")
-            dQ = Q.zeros_like()
-            dR = cast(ptb.TensorVariable, output_grads[0])
-
-        else:
-            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            if self.mode == "complete":
-                qr_assert_op = Assert(
-                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
-                )
-                R = qr_assert_op(R, ptm.le(m, n))
-
-            new_output_grads = []
-            is_disconnected = [
-                isinstance(x.type, DisconnectedType) for x in output_grads
-            ]
-            if all(is_disconnected):
-                # This should never be reached by Pytensor
-                return [DisconnectedType()()]  # pragma: no cover
-
-            for disconnected, output_grad, output in zip(
-                is_disconnected, output_grads, [Q, R], strict=True
-            ):
-                if disconnected:
-                    new_output_grads.append(output.zeros_like())
-                else:
-                    new_output_grads.append(output_grad)
-
-            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
-
-        # gradient expression when m >= n
-        M = R @ _H(dR) - _H(dQ) @ Q
-        K = dQ + Q @ _copyltu(M)
-        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
-
-        # gradient expression when m < n
-        Y = A[:, m:]
-        U = R[:, :m]
-        dU, dV = dR[:, :m], dR[:, m:]
-        dQ_Yt_dV = dQ + Y @ _H(dV)
-        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
-        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
-        Y_bar = Q @ dV
-        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
-
-        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
-
-
-def qr(a, mode="reduced"):
-    """
-    Computes the QR decomposition of a matrix.
-    Factor the matrix a as qr, where q
-    is orthonormal and r is upper-triangular.
-
-    Parameters
-    ----------
-    a : array_like, shape (M, N)
-        Matrix to be factored.
-
-    mode : {'reduced', 'complete', 'r', 'raw'}, optional
-        If K = min(M, N), then
-
-        'reduced'
-            returns q, r with dimensions (M, K), (K, N)
-
-        'complete'
-            returns q, r with dimensions (M, M), (M, N)
-
-        'r'
-            returns r only with dimensions (K, N)
-
-        'raw'
-            returns h, tau with dimensions (N, M), (K,)
-
-        Note that array h returned in 'raw' mode is
-        transposed for calling Fortran.
-
-        Default mode is 'reduced'
-
-    Returns
-    -------
-    q : matrix of float or complex, optional
-        A matrix with orthonormal columns. When mode = 'complete' the
-        result is an orthogonal/unitary matrix depending on whether or
-        not a is real/complex. The determinant may be either +/- 1 in
-        that case.
-    r : matrix of float or complex, optional
-        The upper-triangular matrix.
-
-    """
-    return QRFull(mode)(a)
-
-
 class SVD(Op):
     """
     Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
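
For reference (not part of this commit): the reverse-mode rule that the removed `L_op` implements for the m >= n case, from refs [1] and [2] above, can be checked numerically with NumPy/SciPy alone. In the sketch below, `qr_grad`, the 4x3 test matrix, and the cotangents `dQ`, `dR` are illustration choices rather than anything from the codebase; the check compares against central finite differences of the scalar `(dQ * Q).sum() + (dR * R).sum()`.

```python
# A sketch for reference only: the m >= n branch of the removed QRFull.L_op,
# restated in plain NumPy/SciPy and checked against finite differences.
import numpy as np
from scipy.linalg import solve_triangular


def _H(x):
    # Conjugate transpose, as in the removed helper.
    return np.conj(x).T


def _copyltu(x):
    # Lower triangle plus the conjugate transpose of the strict lower triangle.
    return np.tril(x, k=0) + _H(np.tril(x, k=-1))


def qr_grad(A, dQ, dR):
    # A_bar = (R^{-1} (dQ + Q @ copyltu(M))^H)^H with M = R dR^H - dQ^H Q,
    # following the removed code and refs [1], [2].
    Q, R = np.linalg.qr(A, mode="reduced")
    M = R @ _H(dR) - _H(dQ) @ Q
    K = dQ + Q @ _copyltu(M)
    return _H(solve_triangular(R, _H(K)))


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))  # m=4 >= n=3 (arbitrary test matrix)
Q, R = np.linalg.qr(A, mode="reduced")
dQ, dR = rng.normal(size=Q.shape), rng.normal(size=R.shape)  # arbitrary cotangents


def f(mat):
    # Scalar whose gradient w.r.t. `mat` is what qr_grad should return.
    Qm, Rm = np.linalg.qr(mat, mode="reduced")
    return (dQ * Qm).sum() + (dR * Rm).sum()


eps = 1e-6
num = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        Ap, Am = A.copy(), A.copy()
        Ap[i, j] += eps
        Am[i, j] -= eps
        num[i, j] = (f(Ap) - f(Am)) / (2 * eps)

print(np.allclose(num, qr_grad(A, dQ, dR)))  # expected: True
```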
pytensor/tensor/nlinalg.py (continued)

@@ -1291,7 +1121,6 @@ def kron(a, b):
     "det",
     "eig",
     "eigh",
-    "qr",
     "svd",
     "lstsq",
     "matrix_power",
