|
5 | 5 |
|
6 | 6 | import numpy as np
|
7 | 7 |
|
8 |
| -import pytensor.tensor as pt |
9 | 8 | from pytensor import scalar as ps
|
10 | 9 | from pytensor.compile.builders import OpFromGraph
|
11 | 10 | from pytensor.gradient import DisconnectedType
|
12 | 11 | from pytensor.graph.basic import Apply
|
13 | 12 | from pytensor.graph.op import Op
|
14 |
| -from pytensor.ifelse import ifelse |
15 | 13 | from pytensor.npy_2_compat import normalize_axis_tuple
|
16 |
| -from pytensor.raise_op import Assert |
17 | 14 | from pytensor.tensor import TensorLike
|
18 | 15 | from pytensor.tensor import basic as ptb
|
19 | 16 | from pytensor.tensor import math as ptm
|
@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"):
|
468 | 465 | return Eigh(UPLO)(a)
|
469 | 466 |
|
470 | 467 |
|
471 |
class QRFull(Op):
    """
    Full QR Decomposition.

    Computes the QR decomposition of a matrix: factor the matrix ``a`` as
    ``q @ r``, where ``q`` is orthonormal and ``r`` is upper-triangular.
    The numerical work is delegated to :func:`numpy.linalg.qr`.

    Parameters
    ----------
    mode : {'reduced', 'complete', 'r', 'raw'}
        Forwarded unchanged to :func:`numpy.linalg.qr`. It also decides how
        many outputs the node has: a single matrix for ``'r'``, a matrix and
        a vector for ``'raw'``, and two matrices otherwise.

    """

    __props__ = ("mode",)

    def __init__(self, mode):
        self.mode = mode

    def make_node(self, x):
        """Build the symbolic node for the decomposition of the matrix `x`."""
        x = as_tensor_variable(x)

        assert x.ndim == 2, "The input of qr function should be a matrix."

        in_dtype = x.type.numpy_dtype
        # np.linalg.qr keeps floating and complex dtypes and promotes every
        # other input (integers, booleans) to float64, so the outputs are
        # declared accordingly. Deriving the dtype from the input item size
        # would label complex inputs as real and int32 as float32, and would
        # fail for one-byte dtypes because "f1" is not a valid dtype.
        if in_dtype.kind in "fc":
            out_dtype = in_dtype
        else:
            out_dtype = np.dtype("float64")

        if self.mode == "r":
            # Only the triangular factor is returned in this mode.
            return Apply(self, [x], [matrix(dtype=out_dtype)])

        q = matrix(dtype=out_dtype)
        if self.mode == "raw":
            # The second output of the 'raw' mode is one-dimensional.
            r = vector(dtype=out_dtype)
        else:
            r = matrix(dtype=out_dtype)

        return Apply(self, [x], [q, r])

    def perform(self, node, inputs, outputs):
        """Evaluate the decomposition with NumPy and fill the output storage."""
        (x,) = inputs
        assert x.ndim == 2, "The input of qr function should be a matrix."
        res = np.linalg.qr(x, mode=self.mode)
        if self.mode != "r":
            outputs[0][0], outputs[1][0] = res
        else:
            # Mode 'r' yields a single array rather than a pair.
            outputs[0][0] = res

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        Raises
        ------
        NotImplementedError
            If the Op was built with ``mode="raw"``.

        References
        ----------
        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """

        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        m, n = A.shape

        def _H(x: ptb.TensorVariable):
            # Conjugate transpose over the last two axes.
            return x.conj().mT

        def _copyltu(x: ptb.TensorVariable):
            # Lower triangle (diagonal included) plus the conjugate transpose
            # of the strictly lower triangle.
            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

        elif self.mode == "r":
            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            Q, R = qr(A, mode="reduced")
            dQ = Q.zeros_like()
            dR = cast(ptb.TensorVariable, output_grads[0])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
            if self.mode == "complete":
                # Guard the graph at run time: the assertion requires m <= n.
                qr_assert_op = Assert(
                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
                )
                R = qr_assert_op(R, ptm.le(m, n))

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            # A disconnected output contributes a zero gradient of its own shape.
            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

        # gradient expression when m >= n
        M = R @ _H(dR) - _H(dQ) @ Q
        K = dQ + Q @ _copyltu(M)
        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

        # gradient expression when m < n
        Y = A[:, m:]
        U = R[:, :m]
        dU, dV = dR[:, :m], dR[:, m:]
        dQ_Yt_dV = dQ + Y @ _H(dV)
        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
        Y_bar = Q @ dV
        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

        # Both expressions are built symbolically; the shape comparison picks
        # the one that applies.
        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
591 |
| - |
592 |
| - |
593 |
def qr(a, mode="reduced"):
    """
    Compute the QR decomposition of a matrix.

    The matrix `a` is factored as ``q @ r`` with `q` orthonormal and `r`
    upper-triangular.

    Parameters
    ----------
    a : array_like, shape (M, N)
        Matrix to be factored.

    mode : {'reduced', 'complete', 'r', 'raw'}, optional
        Selects what is returned. With K = min(M, N):

        'reduced'
          q, r with dimensions (M, K), (K, N)

        'complete'
          q, r with dimensions (M, M), (M, N)

        'r'
          r only, with dimensions (K, N)

        'raw'
          h, tau with dimensions (N, M), (K,)

        The array h returned in 'raw' mode is transposed for calling
        Fortran.

        The default is 'reduced'.

    Returns
    -------
    q : matrix of float or complex, optional
        A matrix with orthonormal columns. When mode = 'complete' the
        result is an orthogonal/unitary matrix depending on whether or
        not a is real/complex. The determinant may be either +/- 1 in
        that case.
    r : matrix of float or complex, optional
        The upper-triangular matrix.

    """
    # Build the Op for the requested mode, then apply it to the input.
    decomposition = QRFull(mode)
    return decomposition(a)
636 |
| - |
637 |
| - |
638 | 468 | class SVD(Op):
|
639 | 469 | """
|
640 | 470 | Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
|
@@ -1291,7 +1121,6 @@ def kron(a, b):
|
1291 | 1121 | "det",
|
1292 | 1122 | "eig",
|
1293 | 1123 | "eigh",
|
1294 |
| - "qr", |
1295 | 1124 | "svd",
|
1296 | 1125 | "lstsq",
|
1297 | 1126 | "matrix_power",
|
|
0 commit comments