Skip to content

Commit bbd12a7

Browse files
committed
Add error message for incompatible static shape in Dot Op
1 parent 262d3aa commit bbd12a7

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

pytensor/tensor/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3025,6 +3025,11 @@ def make_node(self, *inputs):
30253025
)
30263026

30273027
sx, sy = (input.type.shape for input in inputs)
3028+
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
3029+
raise ValueError(
3030+
f"Incompatible shared dimension for dot product: {sx}, {sy}"
3031+
)
3032+
30283033
if len(sy) == 2:
30293034
sz = sx[:-1] + sy[-1:]
30303035
elif len(sy) == 1:

0 commit comments

Comments
 (0)