|
| 1 | +from typing import Literal |
| 2 | + |
| 3 | +import scipy.sparse |
| 4 | + |
| 5 | +from pytensor.graph import Apply |
| 6 | +from pytensor.sparse import as_sparse_or_tensor_variable, matrix |
| 7 | +from pytensor.tensor import TensorVariable |
| 8 | +from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype |
| 9 | + |
| 10 | + |
| 11 | +class SparseBlockDiagonal(BaseBlockDiagonal): |
| 12 | + __props__ = ( |
| 13 | + "n_inputs", |
| 14 | + "format", |
| 15 | + ) |
| 16 | + |
| 17 | + def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"): |
| 18 | + super().__init__(n_inputs) |
| 19 | + self.format = format |
| 20 | + |
| 21 | + def make_node(self, *matrices): |
| 22 | + matrices = self._validate_and_prepare_inputs( |
| 23 | + matrices, as_sparse_or_tensor_variable |
| 24 | + ) |
| 25 | + dtype = _largest_common_dtype(matrices) |
| 26 | + out_type = matrix(format=self.format, dtype=dtype) |
| 27 | + |
| 28 | + return Apply(self, matrices, [out_type]) |
| 29 | + |
| 30 | + def perform(self, node, inputs, output_storage, params=None): |
| 31 | + dtype = node.outputs[0].type.dtype |
| 32 | + output_storage[0][0] = scipy.sparse.block_diag( |
| 33 | + inputs, format=self.format |
| 34 | + ).astype(dtype) |
| 35 | + |
| 36 | + |
| 37 | +def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"): |
| 38 | + r""" |
| 39 | + Construct a block diagonal matrix from a sequence of input matrices. |
| 40 | +
|
| 41 | + Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal: |
| 42 | +
|
| 43 | + [[A, 0, 0], |
| 44 | + [0, B, 0], |
| 45 | + [0, 0, C]] |
| 46 | +
|
| 47 | + Parameters |
| 48 | + ---------- |
| 49 | + A, B, C ... : tensors |
| 50 | + Input tensors to form the block diagonal matrix. last two dimensions of the inputs will be used, and all |
| 51 | + inputs should have at least 2 dimensins. |
| 52 | +
|
| 53 | + Note that the input matrices need not be sparse themselves, and will be automatically converted to the |
| 54 | + requested format if they are not. |
| 55 | +
|
| 56 | + format: str, optional |
| 57 | + The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False. |
| 58 | +
|
| 59 | + Returns |
| 60 | + ------- |
| 61 | + out: sparse matrix tensor |
| 62 | + Symbolic sparse matrix in the specified format. |
| 63 | +
|
| 64 | + Examples |
| 65 | + -------- |
| 66 | + Create a sparse block diagonal matrix from two sparse 2x2 matrices: |
| 67 | +
|
| 68 | + .. testcode:: |
| 69 | + import numpy as np |
| 70 | + from pytensor.sparse import block_diag |
| 71 | + from scipy.sparse import csr_matrix |
| 72 | +
|
| 73 | + A = csr_matrix([[1, 2], [3, 4]]) |
| 74 | + B = csr_matrix([[5, 6], [7, 8]]) |
| 75 | + result_sparse = block_diag(A, B, format='csr') |
| 76 | +
|
| 77 | + print(result_sparse) |
| 78 | + print(result_sparse.toarray().eval()) |
| 79 | +
|
| 80 | + .. testoutput:: |
| 81 | +
|
| 82 | + SparseVariable{csr,int64} |
| 83 | + [[1 2 0 0] |
| 84 | + [3 4 0 0] |
| 85 | + [0 0 5 6] |
| 86 | + [0 0 7 8]] |
| 87 | +
|
| 88 | + """ |
| 89 | + if len(matrices) == 1: |
| 90 | + return matrices |
| 91 | + |
| 92 | + _sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format) |
| 93 | + return _sparse_block_diagonal(*matrices) |
0 commit comments