Skip to content

Commit b33a422

Browse files
Reorganize the sparse module
1 parent 602eb04 commit b33a422

File tree

12 files changed

+4497
-4839
lines changed

12 files changed

+4497
-4839
lines changed

pytensor/sparse/basic.py

Lines changed: 165 additions & 2562 deletions
Large diffs are not rendered by default.

pytensor/sparse/linalg.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Literal
2+
3+
import scipy.sparse
4+
5+
from pytensor.graph import Apply
6+
from pytensor.sparse import as_sparse_or_tensor_variable, matrix
7+
from pytensor.tensor import TensorVariable
8+
from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
9+
10+
11+
class SparseBlockDiagonal(BaseBlockDiagonal):
12+
__props__ = (
13+
"n_inputs",
14+
"format",
15+
)
16+
17+
def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
18+
super().__init__(n_inputs)
19+
self.format = format
20+
21+
def make_node(self, *matrices):
22+
matrices = self._validate_and_prepare_inputs(
23+
matrices, as_sparse_or_tensor_variable
24+
)
25+
dtype = _largest_common_dtype(matrices)
26+
out_type = matrix(format=self.format, dtype=dtype)
27+
28+
return Apply(self, matrices, [out_type])
29+
30+
def perform(self, node, inputs, output_storage, params=None):
31+
dtype = node.outputs[0].type.dtype
32+
output_storage[0][0] = scipy.sparse.block_diag(
33+
inputs, format=self.format
34+
).astype(dtype)
35+
36+
37+
def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
38+
r"""
39+
Construct a block diagonal matrix from a sequence of input matrices.
40+
41+
Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
42+
43+
[[A, 0, 0],
44+
[0, B, 0],
45+
[0, 0, C]]
46+
47+
Parameters
48+
----------
49+
A, B, C ... : tensors
50+
Input tensors to form the block diagonal matrix. last two dimensions of the inputs will be used, and all
51+
inputs should have at least 2 dimensins.
52+
53+
Note that the input matrices need not be sparse themselves, and will be automatically converted to the
54+
requested format if they are not.
55+
56+
format: str, optional
57+
The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.
58+
59+
Returns
60+
-------
61+
out: sparse matrix tensor
62+
Symbolic sparse matrix in the specified format.
63+
64+
Examples
65+
--------
66+
Create a sparse block diagonal matrix from two sparse 2x2 matrices:
67+
68+
.. testcode::
69+
import numpy as np
70+
from pytensor.sparse import block_diag
71+
from scipy.sparse import csr_matrix
72+
73+
A = csr_matrix([[1, 2], [3, 4]])
74+
B = csr_matrix([[5, 6], [7, 8]])
75+
result_sparse = block_diag(A, B, format='csr')
76+
77+
print(result_sparse)
78+
print(result_sparse.toarray().eval())
79+
80+
.. testoutput::
81+
82+
SparseVariable{csr,int64}
83+
[[1 2 0 0]
84+
[3 4 0 0]
85+
[0 0 5 6]
86+
[0 0 7 8]]
87+
88+
"""
89+
if len(matrices) == 1:
90+
return matrices
91+
92+
_sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
93+
return _sparse_block_diagonal(*matrices)

0 commit comments

Comments
 (0)