Skip to content

Commit b9fc4f8

Browse files
Preserve static shape information in block_diag (#1529)
1 parent 68d8dc7 commit b9fc4f8

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pytensor/tensor/slinalg.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1651,7 +1651,18 @@ class BlockDiagonal(BaseBlockDiagonal):
16511651
# Build the Apply node for the block-diagonal op over `matrices`.
# This hunk changes how the output type is constructed: instead of a matrix
# type with no shape argument, the output type now receives a static shape
# derived from the inputs' static shapes.
def make_node(self, *matrices):
16521652
# Helper defined outside this hunk; presumably validates the inputs and
# converts each one with pt.as_tensor -- confirm against its definition.
matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
16531653
# Output dtype is chosen by _largest_common_dtype (helper not shown here).
dtype = _largest_common_dtype(matrices)
1654-
# Removed line: the output matrix type was built without any shape argument.
out_type = pytensor.tensor.matrix(dtype=dtype)
1654+
1655+
# Regroup the per-matrix static shapes by dimension: zip(*...) yields one
# tuple per dimension holding that dimension's size for every input
# (for 2-D inputs: all row counts, then all column counts).
shapes_by_dim = tuple(zip(*(m.type.shape for m in matrices)))
1656+
# Each output dimension is the sum of the input sizes along that dimension.
# If any input size along a dimension is None, the sum is not computed and
# the output size for that dimension is None as well.
out_shape = tuple(
1657+
[
1658+
sum(dim_shapes)
1659+
if not any(shape is None for shape in dim_shapes)
1660+
else None
1661+
for dim_shapes in shapes_by_dim
1662+
]
1663+
)
1664+
1665+
# Added line: the output matrix type now carries the computed static shape.
out_type = pytensor.tensor.matrix(shape=out_shape, dtype=dtype)
16551666
# One output variable of the computed type; inputs are the prepared matrices.
return Apply(self, matrices, [out_type])
16561667

16571668
def perform(self, node, inputs, output_storage, params=None):

tests/tensor/test_slinalg.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,11 +1040,28 @@ def test_block_diagonal():
10401040
A = np.array([[1.0, 2.0], [3.0, 4.0]])
10411041
B = np.array([[5.0, 6.0], [7.0, 8.0]])
10421042
result = block_diag(A, B)
1043+
assert result.type.shape == (4, 4)
10431044
assert result.owner.op.core_op._props_dict() == {"n_inputs": 2}
10441045

10451046
np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
10461047

10471048

1049+
# Checks the static shape reported on the result type of block_diag for
# inputs with fully known, partially known, and mixed-unknown static shapes.
def test_block_diagonal_static_shape():
1050+
# Case 1: both inputs fully static -> rows 5 + 3 = 8, columns 5 + 10 = 15.
A = pt.dmatrix("A", shape=(5, 5))
1051+
B = pt.dmatrix("B", shape=(3, 10))
1052+
result = block_diag(A, B)
1053+
assert result.type.shape == (8, 15)
1054+
1055+
# Case 2: B has an unknown column count -> rows still 8, columns None.
A = pt.dmatrix("A", shape=(5, 5))
1056+
B = pt.dmatrix("B", shape=(3, None))
1057+
result = block_diag(A, B)
1058+
assert result.type.shape == (8, None)
1059+
1060+
# Case 3: A has an unknown row count; B is reused from case 2 (unknown
# column count), so neither output dimension is known.
A = pt.dmatrix("A", shape=(None, 5))
1061+
result = block_diag(A, B)
1062+
assert result.type.shape == (None, None)
1063+
1064+
10481065
def test_block_diagonal_grad():
10491066
A = np.array([[1.0, 2.0], [3.0, 4.0]])
10501067
B = np.array([[5.0, 6.0], [7.0, 8.0]])

0 commit comments

Comments
 (0)