diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index afcc08a612..75e17e1e21 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -107,7 +107,6 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: from pytensor.tensor import ( blas, blas_c, - blas_scipy, sharedvar, xlogx, ) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index f189766c9c..931c7009b3 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node): | pytensor.tensor.blas.Gemv | pytensor.tensor.blas_c.CGemv | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer - | pytensor.tensor.blas_scipy.ScipyGer, + | pytensor.tensor.blas_c.CGer, ) ): # Ops that will work inplace on the Alloc. So if they diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index fc8afcea50..a183431a0e 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -83,6 +83,7 @@ from pathlib import Path import numpy as np +from scipy.linalg import get_blas_funcs from pytensor.graph import vectorize_graph from pytensor.npy_2_compat import normalize_axis_tuple @@ -288,18 +289,17 @@ def make_node(self, A, alpha, x, y): return Apply(self, inputs, [A.type()]) - def perform(self, node, inp, out): - cA, calpha, cx, cy = inp - (cZ,) = out - if self.destructive: - A = cA - else: - A = cA.copy() - if calpha != 1: - A += calpha * np.outer(cx, cy) - else: - A += np.outer(cx, cy) - cZ[0] = A + def perform(self, node, inputs, output_storage): + A, alpha, x, y = inputs + if A.size: + # GER doesn't handle zero-sized inputs + ger_func = get_blas_funcs("ger", dtype=A.dtype) + if A.flags["C_CONTIGUOUS"]: + # Work on transposed system to avoid copying + A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T + else: + A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) + output_storage[0][0] = A def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] @@ -1128,16 +1128,8 @@ def make_node(self, x, y): outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))] return Apply(self, [x, y], outputs) - def perform(self, node, inp, out): - x, y = inp - (z,) = out - try: - z[0] = np.asarray(np.dot(x, y)) - except ValueError as e: - # The error raised by numpy has no shape information, we mean to - # add that - e.args = (*e.args, x.shape, y.shape) - raise + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.dot(*inputs) def infer_shape(self, fgraph, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py deleted file mode 100644 index bb3ccf9354..0000000000 --- a/pytensor/tensor/blas_scipy.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Implementations of BLAS Ops based on scipy's BLAS bindings. -""" - -from pytensor.tensor.blas import Ger - - -class ScipyGer(Ger): - def perform(self, node, inputs, output_storage): - from scipy.linalg.blas import get_blas_funcs - - cA, calpha, cx, cy = inputs - (cZ,) = output_storage - # N.B. some versions of scipy (e.g. mine) don't actually work - # in-place on a, even when I tell it to. - A = cA - local_ger = get_blas_funcs("ger", dtype=cA.dtype) - if A.size == 0: - # We don't have to compute anything, A is empty. - # We need this special case because Numpy considers it - # C-contiguous, which is confusing. - if not self.destructive: - # Sometimes numpy thinks empty matrices can share memory, - # so here to stop DebugMode from complaining. - A = A.copy() - elif A.flags["C_CONTIGUOUS"]: - A = local_ger(calpha, cy, cx, a=A.T, overwrite_a=int(self.destructive)).T - else: - A = local_ger(calpha, cx, cy, a=A, overwrite_a=int(self.destructive)) - cZ[0] = A - - -scipy_ger_no_inplace = ScipyGer(False) -scipy_ger_inplace = ScipyGer(True) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 743da35c84..dd4280a149 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -40,12 +40,13 @@ get_normalized_batch_axes, scalar_elemwise, ) -from pytensor.tensor.shape import shape, specify_broadcastable +from pytensor.tensor.shape import shape, specify_shape from pytensor.tensor.type import ( DenseTensorType, complex_dtypes, continuous_dtypes, discrete_dtypes, + float_dtypes, int_dtypes, tensor, uint_dtypes, @@ -2986,9 +2987,7 @@ def clip(x, min, max): class Dot(Op): """ - Computes the dot product of two variables. For two matrices, this is - equivalent to matrix multiplication. For two vectors, this is the inner - product. + Computes the dot product of two matrices variables Notes ----- @@ -3001,97 +3000,57 @@ class Dot(Op): """ + gufunc_signature = "(m,n),(n,p)->(m,p)" + gufunc_spec = ("matmul", 2, 1) __props__ = () - # the rationale for Dot22 is related to getting GEMM Ops into the - # graph. See Dot22 in tensor.blas for details. - - def make_node(self, *inputs): - inputs = list(map(as_tensor_variable, inputs)) + def make_node(self, x, y): + x = as_tensor_variable(x) + y = as_tensor_variable(y) - if len(inputs) != 2: - raise TypeError(f"Two arguments required, {len(inputs)} given ") - if inputs[0].ndim not in (1, 2): + if x.type.ndim != 2: raise TypeError( - "Input 0 (0-indexed) must have ndim of " - f"1 or 2, {int(inputs[0].ndim)} given. Consider calling " - "pytensor.tensor.dot instead." + f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions" ) - if inputs[1].ndim not in (1, 2): + if y.type.ndim != 2: raise TypeError( - "Input 1 (0-indexed) must have ndim of " - f"1 or 2, {int(inputs[1].ndim)} given. Consider calling " - "pytensor.tensor.dot instead." + f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions" ) - sx, sy = (input.type.shape for input in inputs) + sx, sy = x.type.shape, y.type.shape if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]: raise ValueError( f"Incompatible shared dimension for dot product: {sx}, {sy}" ) + sz = sx[:-1] + sy[-1:] + outputs = [tensor(dtype=ps.upcast(x.type.dtype, y.type.dtype), shape=sz)] + return Apply(self, [x, y], outputs) - if len(sy) == 2: - sz = sx[:-1] + sy[-1:] - elif len(sy) == 1: - sz = sx[:-1] - - i_dtypes = [input.type.dtype for input in inputs] - outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)] - return Apply(self, inputs, outputs) - - def perform(self, node, inp, out): - x, y = inp - (z,) = out - - # the asarray is here because dot between two vectors - # gives a numpy float object but we need to return a 0d - # ndarray - z[0] = np.asarray(np.dot(x, y)) + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.matmul(*inputs) def grad(self, inp, grads): x, y = inp (gz,) = grads - xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim - - # grad is scalar, so x is vector and y is vector - if gdim == 0: - xgrad = gz * y - ygrad = gz * x - - # x is vector, y is matrix, grad is vector - elif xdim == 1 and ydim == 2: - xgrad = dot(gz, y.T) - ygrad = outer(x.T, gz) - - # x is matrix, y is vector, grad is vector - elif xdim == 2 and ydim == 1: - xgrad = outer(gz, y.T) - ygrad = dot(x.T, gz) - # x is matrix, y is matrix, grad is matrix - elif xdim == ydim == 2: - xgrad = dot(gz, y.T) - ygrad = dot(x.T, gz) + xgrad = self(gz, y.T) + ygrad = self(x.T, gz) # If x or y contain broadcastable dimensions but only one of # them know that a matching dimensions is broadcastable, the # above code don't always return the right broadcast pattern. # This cause problem down the road. See gh-1461. - if xgrad.broadcastable != x.broadcastable: - xgrad = specify_broadcastable( - xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b) - ) - if ygrad.broadcastable != y.broadcastable: - ygrad = specify_broadcastable( - ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b) - ) - - rval = xgrad, ygrad + if xgrad.type.shape != x.type.shape: + xgrad = specify_shape(xgrad, x.type.shape) + if ygrad.type.shape != y.type.shape: + ygrad = specify_shape(ygrad, y.type.shape) - for elem in rval: - assert elem.dtype.find("float") != -1 + if xgrad.type.dtype not in float_dtypes: + raise TypeError("Dot grad x output must be a float type") + if ygrad.type.dtype not in float_dtypes: + raise TypeError("Dot grad y output must be a float type") - return rval + return xgrad, ygrad def R_op(self, inputs, eval_points): # R_op for a \dot b evaluated at c for a and d for b is @@ -3116,24 +3075,7 @@ def R_op(self, inputs, eval_points): def infer_shape(self, fgraph, node, shapes): xshp, yshp = shapes - x, y = node.inputs - - # vector / vector - if x.ndim == 1 and y.ndim == 1: - return [()] - # matrix / vector - if x.ndim == 2 and y.ndim == 1: - return [xshp[:-1]] - # vector / matrix - if x.ndim == 1 and y.ndim == 2: - return [yshp[-1:]] - # matrix / matrix - if x.ndim == 2 and y.ndim == 2: - return [xshp[:-1] + yshp[-1:]] - raise NotImplementedError() - - def __str__(self): - return "dot" + return [[xshp[0], yshp[1]]] _dot = Dot() @@ -3215,7 +3157,24 @@ def dense_dot(a, b): elif a.ndim > 2 or b.ndim > 2: return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]]) else: - return _dot(a, b) + row_vector = a.ndim == 1 + if row_vector: + # Promote to row matrix + a = a[None] + + col_vector = b.ndim == 1 + if col_vector: + # Promote to column matrix + b = b[:, None] + + out = _dot(a, b) + if row_vector: + # If we promoted a to a row matrix, we need to squeeze the first dimension + out = out.squeeze(0) + if col_vector: + # If we promoted b to a column matrix, we need to squeeze the last dimension + out = out.squeeze(-1) + return out def tensordot( @@ -3921,11 +3880,7 @@ def logsumexp(x, axis=None, keepdims=False): return log(sum(exp(x), axis=axis, keepdims=keepdims)) -_matmul = Blockwise( - _dot, - signature="(m,k),(k,n)->(m,n)", - gufunc_spec=("numpy.matmul", 2, 1), -) +_matmul = Blockwise(_dot, name="Matmul") def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): @@ -3975,7 +3930,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None if x1.type.ndim == 0 or x2.type.ndim == 0: raise ValueError("matmul operand cannot be scalar") if x1.type.ndim == 1 and x2.type.ndim == 1: - out = _dot(x1, x2) + out = vecdot(x1, x2) elif x1.type.ndim == 1: out = vecmat(x1, x2) elif x2.type.ndim == 1: @@ -4139,23 +4094,7 @@ def vecmat( @_vectorize_node.register(Dot) def vectorize_node_dot(op, node, batched_x, batched_y): - old_x, old_y = node.inputs - old_x_ndim = old_x.type.ndim - old_y_ndim = old_y.type.ndim - match (old_x_ndim, old_y_ndim): - case (1, 1): - batch_fn = vecdot - case (2, 1): - batch_fn = matvec - case (1, 2): - batch_fn = vecmat - case (2, 2): - batch_fn = matmul - case _: - raise ValueError( - f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D." - ) - return batch_fn(batched_x, batched_y).owner + return matmul(batched_x, batched_y).owner def nan_to_num(x, nan=0.0, posinf=None, neginf=None): diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 6d411d3827..2293e1d8dd 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -1,7 +1,6 @@ import pytensor.tensor.rewriting.basic import pytensor.tensor.rewriting.blas import pytensor.tensor.rewriting.blas_c -import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.blockwise import pytensor.tensor.rewriting.einsum import pytensor.tensor.rewriting.elemwise diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 74b4d235dc..685cec5785 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -107,7 +107,6 @@ ) from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.type import ( - DenseTensorType, TensorType, integer_dtypes, values_eq_approx_remove_inf_nan, @@ -580,12 +579,6 @@ def print_profile(cls, stream, prof, level=0): def local_dot_to_dot22(fgraph, node): # This works for tensor.outer too because basic.outer is a macro that # produces a dot(dimshuffle,dimshuffle) of form 4 below - if not isinstance(node.op, Dot): - return - - if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): - return False - x, y = node.inputs if y.type.dtype != x.type.dtype: # TODO: upcast one so the types match @@ -593,16 +586,7 @@ def local_dot_to_dot22(fgraph, node): return if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): - if x.ndim == 2 and y.ndim == 2: - new_out = [_dot22(*node.inputs)] - elif x.ndim == 2 and y.ndim == 1: - new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] - elif x.ndim == 1 and y.ndim == 2: - new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] - elif x.ndim == 1 and y.ndim == 1: - new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] - else: - return + new_out = [_dot22(*node.inputs)] copy_stack_trace(node.outputs, new_out) return new_out @@ -636,93 +620,89 @@ def local_inplace_ger(fgraph, node): @node_rewriter([gemm_no_inplace]) def local_gemm_to_gemv(fgraph, node): """GEMM acting on row or column matrices -> GEMV.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if z.broadcastable == x.broadcastable == (True, False): - r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) - new_out = [r.dimshuffle("x", 0)] - elif z.broadcastable == y.broadcastable == (False, True): - r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) - new_out = [r.dimshuffle(0, "x")] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out + z, a, x, y, b = node.inputs + if z.broadcastable == x.broadcastable == (True, False): + r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) + new_out = [r.dimshuffle("x", 0)] + elif z.broadcastable == y.broadcastable == (False, True): + r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) + new_out = [r.dimshuffle(0, "x")] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out @node_rewriter([gemm_no_inplace]) def local_gemm_to_ger(fgraph, node): """GEMM computing an outer-product -> GER.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if x.broadcastable[1] and y.broadcastable[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - try: - bval = ptb.get_underlying_scalar_constant_value(b) - except NotScalarConstantError: - # b isn't a constant, GEMM is doing useful pre-scaling - return - - if bval == 1: # best case a natural GER - rval = ger(z, a, xv, yv) - new_out = [rval] - elif bval == 0: # GER on zeros_like should be faster than GEMM - zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype) - rval = ger(zeros, a, xv, yv) - new_out = [rval] - else: - # if bval is another constant, then z is being usefully - # pre-scaled and GER isn't really the right tool for the job. - return - copy_stack_trace(node.outputs, new_out) - return new_out - + z, a, x, y, b = node.inputs + if x.broadcastable[1] and y.broadcastable[0]: + # x and y are both vectors so this might qualifies for a GER + xv = x.dimshuffle(0) + yv = y.dimshuffle(1) + try: + bval = ptb.get_underlying_scalar_constant_value(b) + except NotScalarConstantError: + # b isn't a constant, GEMM is doing useful pre-scaling + return -# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline -# working -@node_rewriter([_dot22]) -def local_dot22_to_ger_or_gemv(fgraph, node): - """dot22 computing an outer-product -> GER.""" - if node.op == _dot22: - x, y = node.inputs - xb = x.broadcastable - yb = y.broadcastable - one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype)) - zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype)) - if xb[1] and yb[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) - rval = ger(zeros, one, xv, yv) + if bval == 1: # best case a natural GER + rval = ger(z, a, xv, yv) + new_out = [rval] + elif bval == 0: # GER on zeros_like should be faster than GEMM + zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype) + rval = ger(zeros, a, xv, yv) new_out = [rval] - elif xb[0] and yb[1]: - # x and y are both vectors so this qualifies for a sdot / ddot - # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not - xv = x.dimshuffle(1) - zeros = ptb.AllocEmpty(x.dtype)(1) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif xb[0] and not yb[0] and not yb[1]: - # x is vector, y is matrix so try gemv - xv = x.dimshuffle(1) - zeros = ptb.AllocEmpty(x.dtype)(y.shape[1]) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif not xb[0] and not xb[1] and yb[1]: - # x is matrix, y is vector, try gemv - yv = y.dimshuffle(0) - zeros = ptb.AllocEmpty(x.dtype)(x.shape[0]) - rval = gemv_no_inplace(zeros, one, x, yv, zero) - new_out = [rval.dimshuffle(0, "x")] else: + # if bval is another constant, then z is being usefully + # pre-scaled and GER isn't really the right tool for the job. return copy_stack_trace(node.outputs, new_out) return new_out +# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline working +@node_rewriter([_dot22]) +def local_dot22_to_ger_or_gemv(fgraph, node): + """dot22 computing an outer-product -> GER.""" + x, y = node.inputs + xb = x.broadcastable + yb = y.broadcastable + one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype)) + zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype)) + if xb[1] and yb[0]: + # x and y are both vectors so this might qualifies for a GER + xv = x.dimshuffle(0) + yv = y.dimshuffle(1) + zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) + rval = ger(zeros, one, xv, yv) + new_out = [rval] + elif xb[0] and yb[1]: + # x and y are both vectors so this qualifies for a sdot / ddot + # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not + xv = x.dimshuffle(1) + zeros = ptb.AllocEmpty(x.dtype)(1) + rval = gemv_no_inplace(zeros, one, y.T, xv, zero) + new_out = [rval.dimshuffle("x", 0)] + elif xb[0] and not yb[0] and not yb[1]: + # x is vector, y is matrix so try gemv + xv = x.dimshuffle(1) + zeros = ptb.AllocEmpty(x.dtype)(y.shape[1]) + rval = gemv_no_inplace(zeros, one, y.T, xv, zero) + new_out = [rval.dimshuffle("x", 0)] + elif not xb[0] and not xb[1] and yb[1]: + # x is matrix, y is vector, try gemv + yv = y.dimshuffle(0) + zeros = ptb.AllocEmpty(x.dtype)(x.shape[0]) + rval = gemv_no_inplace(zeros, one, x, yv, zero) + new_out = [rval.dimshuffle(0, "x")] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out + + ################################# # # Set up the BlasOpt optimizer diff --git a/pytensor/tensor/rewriting/blas_scipy.py b/pytensor/tensor/rewriting/blas_scipy.py deleted file mode 100644 index 2ed0279e45..0000000000 --- a/pytensor/tensor/rewriting/blas_scipy.py +++ /dev/null @@ -1,37 +0,0 @@ -from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.blas import ger, ger_destructive -from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace -from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb - - -@node_rewriter([ger, ger_destructive]) -def use_scipy_ger(fgraph, node): - if node.op == ger: - return [scipy_ger_no_inplace(*node.inputs)] - - -@node_rewriter([scipy_ger_no_inplace]) -def make_ger_destructive(fgraph, node): - if node.op == scipy_ger_no_inplace: - return [scipy_ger_inplace(*node.inputs)] - - -use_scipy_blas = in2out(use_scipy_ger) -make_scipy_blas_destructive = in2out(make_ger_destructive) - - -# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof -# sucks [citation needed], but it is almost always present. -# C implementations should be scheduled earlier than this, so that they take -# precedence. Once the original Ger is replaced, then these optimizations -# have no effect. -blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) - -# this matches the InplaceBlasOpt defined in blas.py -optdb.register( - "make_scipy_blas_destructive", - make_scipy_blas_destructive, - "fast_run", - "inplace", - position=50.2, -) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 45ce2a4605..8367642c4c 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -23,7 +23,6 @@ diag, diagonal, ) -from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod @@ -103,12 +102,12 @@ def transinv_to_invtrans(fgraph, node): @register_stabilize -@node_rewriter([Dot, Dot22]) +@node_rewriter([Dot]) def inv_as_solve(fgraph, node): """ This utilizes a boolean `symmetric` tag on the matrices. """ - if isinstance(node.op, Dot | Dot22): + if isinstance(node.op, Dot): l, r = node.inputs if ( l.owner @@ -277,15 +276,7 @@ def cholesky_ldotlt(fgraph, node): A = node.inputs[0] if not ( - A.owner is not None - and ( - ( - isinstance(A.owner.op, Dot | Dot22) - # This rewrite only applies to matrix Dot - and A.owner.inputs[0].type.ndim == 2 - ) - or (A.owner.op == _matmul) - ) + A.owner is not None and (isinstance(A.owner.op, Dot) or (A.owner.op == _matmul)) ): return diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 2de62e7e11..9d9a959e77 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -19,7 +19,6 @@ node_rewriter, ) from pytensor.graph.rewriting.utils import get_clients_at_depth -from pytensor.raise_op import assert_op from pytensor.tensor.basic import ( Alloc, Join, @@ -34,6 +33,7 @@ ones_like, register_infer_shape, switch, + zeros, zeros_like, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -44,12 +44,10 @@ Prod, Sum, _conj, - _dot, _matmul, add, digamma, dot, - eq, erf, erfc, exp, @@ -130,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): return consts, origconsts, nonconsts -@register_canonicalize -@register_stabilize +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") @node_rewriter([Dot]) def local_0_dot_x(fgraph, node): - if not isinstance(node.op, Dot): - return False - - x = node.inputs[0] - y = node.inputs[1] - replace = ( + x, y = node.inputs + if ( get_underlying_scalar_constant_value( x, only_process_constants=True, raise_not_constant=False ) @@ -148,26 +142,12 @@ def local_0_dot_x(fgraph, node): y, only_process_constants=True, raise_not_constant=False ) == 0 - ) - - if replace: - constant_zero = constant(0, dtype=node.outputs[0].type.dtype) - if x.ndim == 2 and y.ndim == 2: - constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) - return [alloc(constant_zero, x.shape[0], y.shape[1])] - elif x.ndim == 1 and y.ndim == 2: - constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) - return [alloc(constant_zero, y.shape[1])] - elif x.ndim == 2 and y.ndim == 1: - constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) - return [alloc(constant_zero, x.shape[0])] - elif x.ndim == 1 and y.ndim == 1: - constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) - return [constant_zero] + ): + return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)] @register_canonicalize -@node_rewriter([DimShuffle]) +@node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``. @@ -176,22 +156,25 @@ def local_lift_transpose_through_dot(fgraph, node): and to later merge consecutive `DimShuffle`\s. """ - if not ( - is_matrix_transpose(node.outputs[0]) - and node.inputs[0].owner - and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul)) - ): - return False + clients = fgraph.clients[node.out] - x, y = node.inputs[0].owner.inputs + if len(clients) != 1: + # If the dot is used in more than one place, we don't want to duplicate it + return None - if x.ndim >= y.ndim >= 2: - # Output is dot product of transposed inputs in reverse order - ret = [dot_op(y.mT, x.mT)] + [(client, _)] = clients - # Copy over stack trace to output from result of dot-product - copy_stack_trace(node.inputs[0], ret) - return ret + if not (isinstance(client.op, DimShuffle) and is_matrix_transpose(client.out)): + return None + + x, y = node.inputs + # Output is dot product of transposed inputs in reverse order + ret = node.op(y.mT, x.mT) + + # Copy over stack trace to output from result of dot-product + copy_stack_trace(node.out, ret) + + return {client.out: ret} def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool): @@ -344,57 +327,34 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node): @register_canonicalize @register_specialize -@node_rewriter([_matmul]) -def local_blockwise_dot_to_mul(fgraph, node): - """Rewrite blockwise dots that correspond to multiplication without summation. +@node_rewriter([_matmul, Dot]) +def local_dot_to_mul(fgraph, node): + """Rewrite dots that correspond to multiplication without summation. - We don't touch the regular dot, to not interfere with the BLAS optimizations. + We don't touch outer product without batch-dimensions, to allow rewriting into GER, + which seems more performant in that case. + + # TODO: Once we blockwise Blas operations we shouldn't do it for outer product with batch-dimensions either + # TODO: We may still want to canonicalize outer dot as mul, and detect that for GER. """ a, b = node.inputs a_static_shape = a.type.shape b_static_shape = b.type.shape - core_a_ndim = len(node.op.inputs_sig[0]) - core_b_ndim = len(node.op.inputs_sig[1]) - if core_a_ndim > 2 or core_b_ndim > 2: - # Shouldn't happen, but here just in case + # Check if we have matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n) + if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1): return None - if core_b_ndim == 1: - if a_static_shape[-1] == 1 or b_static_shape[-1] == 1: - if core_a_ndim == 1: - # inner product: (..., 1) * (..., 1) -> (...) - # just squeeze the last dimensions of a and b - new_a = a.squeeze(-1) - new_b = b.squeeze(-1) - else: - # matrix vector product: (..., m, 1) * (..., 1) -> (..., m) - # the last dimension of b is already aligned for the elemwise multiplication - # after we squeeze the last dimension of a - new_a = a.squeeze(-1) - new_b = b - else: - return None - - else: - if a_static_shape[-1] == 1 or b_static_shape[-2] == 1: - if core_a_ndim == 1: - # vector_matrix product: (..., 1) * (..., 1, n) -> (..., n) - # the last dimension of a is already aligned for the elemwise multiplication - # after we squeeze the one to last dimension of b - new_a = a - new_b = b.squeeze(-2) - else: - # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n) - # the dimensions of a and b are already aligned for the elemwise multiplication - new_a = a - new_b = b - else: - return None + # If it's a core Dot we only rewrite if there's no outer product + # (1, 1) * (1, n) or (m, 1) * (1, 1) + # Otherwise we leave as is, so GER can be used instead + if isinstance(node.op, Dot) and not ( + a_static_shape[-2] == 1 or b_static_shape[-1] == 1 + ): + return None - new_a = copy_stack_trace(a, new_a) - new_b = copy_stack_trace(b, new_b) - new_out = copy_stack_trace(node.out, mul(new_a, new_b)) + new_out = mul(a, b) + copy_stack_trace(node.out, new_out) return [new_out] diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index e70c97d81e..bfb78a98e5 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -158,26 +158,11 @@ def local_subtensor_of_dot(fgraph, node): a = a.type.clone(shape=a.type.shape[batch_ndim:])() b = b.type.clone(shape=b.type.shape[batch_ndim:])() - a_ndim = a.ndim - b_ndim = b.ndim - num_a_indices = min(a_ndim - 1, len(idx_list)) - a_indices = idx_list[:num_a_indices] - b_indices = idx_list[num_a_indices:] - - # This is necessary because np.dot sums the last index of a with the second to last of b - # so we want to skip the second-to-last index into b. - # This wasn't necessary for a, because we just omitted the last index. - # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] - # (dot also handles b.ndim < 2 as a special case) - if b_ndim > 1 and len(b_indices) >= b_ndim - 1: - b_indices = ( - b_indices[: b_ndim - 2] - + (slice(None, None, None),) - + b_indices[b_ndim - 2 :] - ) + a_indices = idx_list[:1] + b_indices = (slice(None), *idx_list[1:]) a_sub = a[tuple(a_indices)] - b_sub = b[tuple(b_indices)] if b_indices else b + b_sub = b[tuple(b_indices)] r = dot(a_sub, b_sub) if batch_ndim: diff --git a/tests/graph/rewriting/test_kanren.py b/tests/graph/rewriting/test_kanren.py index 1b5ffb1564..a1dc310ce5 100644 --- a/tests/graph/rewriting/test_kanren.py +++ b/tests/graph/rewriting/test_kanren.py @@ -37,51 +37,51 @@ def clear_assoccomm(): def test_kanren_basic(): A_pt = pt.matrix("A") - x_pt = pt.vector("x") + B_pt = pt.matrix("B") - y_pt = pt.dot(A_pt, x_pt) + y_pt = pt.dot(A_pt, B_pt) q = var() - res = list(run(None, q, eq(y_pt, etuple(_dot, q, x_pt)))) + res = list(run(None, q, eq(y_pt, etuple(_dot, q, B_pt)))) assert res == [A_pt] def test_KanrenRelationSub_filters(): - x_pt = pt.vector("x") - y_pt = pt.vector("y") - z_pt = pt.vector("z") A_pt = pt.matrix("A") + B_pt = pt.matrix("B") + C_pt = pt.matrix("C") + D_pt = pt.matrix("D") fact(commutative, _dot) fact(commutative, pt.add) fact(associative, pt.add) - Z_pt = A_pt.dot((x_pt + y_pt) + z_pt) + Z_pt = A_pt.dot((B_pt + C_pt) + D_pt) fgraph = FunctionGraph(outputs=[Z_pt], clone=False) def distributes(in_lv, out_lv): - A_lv, x_lv, y_lv, z_lv = vars(4) + A_lv, B_lv, C_lv, D_lv = vars(4) return lall( # lhs == A * (x + y + z) eq_assoccomm( - etuple(_dot, A_lv, etuple(pt.add, x_lv, etuple(pt.add, y_lv, z_lv))), + etuple(_dot, A_lv, etuple(pt.add, B_lv, etuple(pt.add, C_lv, D_lv))), in_lv, ), # This relation does nothing but provide us with a means of # generating associative-commutative matches in the `kanren` # output. - eq((A_lv, x_lv, y_lv, z_lv), out_lv), + eq((A_lv, B_lv, C_lv, D_lv), out_lv), ) def results_filter(results): _results = [eval_if_etuple(v) for v in results] # Make sure that at least a couple permutations are present - assert (A_pt, x_pt, y_pt, z_pt) in _results - assert (A_pt, y_pt, x_pt, z_pt) in _results - assert (A_pt, z_pt, x_pt, y_pt) in _results + assert (A_pt, B_pt, C_pt, D_pt) in _results + assert (A_pt, C_pt, B_pt, D_pt) in _results + assert (A_pt, D_pt, B_pt, C_pt) in _results return None @@ -121,13 +121,13 @@ def relation(in_lv, out_lv): def test_KanrenRelationSub_dot(): """Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached.""" - x_pt = pt.vector("x") - c_pt = pt.vector("c") - d_pt = pt.vector("d") A_pt = pt.matrix("A") B_pt = pt.matrix("B") + C_pt = pt.matrix("C") + D_pt = pt.matrix("D") + E_pt = pt.matrix("E") - Z_pt = A_pt.dot(x_pt + B_pt.dot(c_pt + d_pt)) + Z_pt = A_pt.dot(E_pt + B_pt.dot(C_pt + D_pt)) fgraph = FunctionGraph(outputs=[Z_pt], clone=False) @@ -137,15 +137,15 @@ def distributes(in_lv, out_lv): return lall( # lhs == A * (x + b) eq( - etuple(_dot, var("A"), etuple(pt.add, var("x"), var("b"))), + etuple(_dot, var("A"), etuple(pt.add, var("E"), var("B"))), in_lv, ), # rhs == A * x + A * b eq( etuple( pt.add, - etuple(_dot, var("A"), var("x")), - etuple(_dot, var("A"), var("b")), + etuple(_dot, var("A"), var("E")), + etuple(_dot, var("A"), var("B")), ), out_lv, ), diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 3b880616df..95cf6ec557 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -631,7 +631,7 @@ def test_Dot(x, y): x, x_test_value = x y, y_test_value = y - g = ptm.Dot()(x, y) + g = ptm.dot(x, y) compare_numba_and_py( [x, y], diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 0684015f4e..8a6b46e7b9 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4714,14 +4714,15 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): == 1 ) - # For now rewrite only applies to Batched Dots + # For now we do not rewrite only the case of unbatched outer + core_outer = (not batched) and (a_shape == (3, 1)) and (b_shape == (1, 3)) rewritten_out = rewrite_graph(out) assert rewritten_out.type.shape == out.type.shape assert sum( isinstance(var.owner.op, (Blockwise | Dot)) for var in ancestors([rewritten_out]) if var.owner - ) == (0 if batched else 1) + ) == (1 if core_outer else 0) a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype) b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype) diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 1332266e3d..0052bbba4b 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -9,7 +9,6 @@ import pytensor import pytensor.scalar as ps import pytensor.tensor as pt -import pytensor.tensor.blas_scipy from pytensor.compile.function import function from pytensor.compile.io import In from pytensor.compile.mode import Mode diff --git a/tests/tensor/test_blas_c.py b/tests/tensor/test_blas_c.py index b6ba1987b9..e7d426eb9f 100644 --- a/tests/tensor/test_blas_c.py +++ b/tests/tensor/test_blas_c.py @@ -8,7 +8,6 @@ from pytensor.tensor.basic import AllocEmpty from pytensor.tensor.blas import Ger from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv -from pytensor.tensor.blas_scipy import ScipyGer from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector from tests import unittest_tools from tests.tensor.test_blas import BaseGemv, TestBlasStrides @@ -68,8 +67,6 @@ def test_eq(self): assert CGer(False) == CGer(False) assert CGer(False) != CGer(True) - assert CGer(True) != ScipyGer(True) - assert CGer(False) != ScipyGer(False) assert CGer(True) != Ger(True) assert CGer(False) != Ger(False) @@ -486,3 +483,26 @@ def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmar np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y)) benchmark(fn, test_A, test_x, test_y) + + +@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"]) +@pytest.mark.parametrize("n", [2**7, 2**9, 2**13]) +def test_ger_benchmark(n, inplace, benchmark): + alpha = pt.dscalar("alpha") + x = pt.dvector("x") + y = pt.dvector("y") + A = pt.dmatrix("A") + + out = alpha * pt.outer(x, y) + A + + fn = pytensor.function( + [alpha, x, y, pytensor.In(A, mutable=inplace)], out, trust_input=True + ) + + rng = np.random.default_rng([2274, n]) + alpha_test = rng.normal(size=()) + x_test = rng.normal(size=(n,)) + y_test = rng.normal(size=(n,)) + A_test = rng.normal(size=(n, n)) + + benchmark(fn, alpha_test, x_test, y_test, A_test) diff --git a/tests/tensor/test_blas_scipy.py b/tests/tensor/test_blas_scipy.py deleted file mode 100644 index 716eab7bbe..0000000000 --- a/tests/tensor/test_blas_scipy.py +++ /dev/null @@ -1,75 +0,0 @@ -import pickle - -import numpy as np - -import pytensor -from pytensor import tensor as pt -from pytensor.tensor.blas_scipy import ScipyGer -from pytensor.tensor.math import outer -from pytensor.tensor.type import tensor -from tests.tensor.test_blas import TestBlasStrides, gemm_no_inplace -from tests.unittest_tools import OptimizationTestMixin - - -class TestScipyGer(OptimizationTestMixin): - def setup_method(self): - self.mode = pytensor.compile.get_default_mode() - self.mode = self.mode.including("fast_run") - self.mode = self.mode.excluding("c_blas") # c_blas trumps scipy Ops - dtype = self.dtype = "float64" # optimization isn't dtype-dependent - self.A = tensor(dtype=dtype, shape=(None, None)) - self.a = tensor(dtype=dtype, shape=()) - self.x = tensor(dtype=dtype, shape=(None,)) - self.y = tensor(dtype=dtype, shape=(None,)) - self.Aval = np.ones((2, 3), dtype=dtype) - self.xval = np.asarray([1, 2], dtype=dtype) - self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype) - - def function(self, inputs, outputs): - return pytensor.function(inputs, outputs, self.mode) - - def run_f(self, f): - f(self.Aval, self.xval, self.yval) - f(self.Aval[::-1, ::-1], self.xval[::-1], self.yval[::-1]) - - def b(self, bval): - return pt.as_tensor_variable(np.asarray(bval, dtype=self.dtype)) - - def test_outer(self): - f = self.function([self.x, self.y], outer(self.x, self.y)) - self.assertFunctionContains(f, ScipyGer(destructive=True)) - - def test_A_plus_outer(self): - f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y)) - self.assertFunctionContains(f, ScipyGer(destructive=False)) - self.run_f(f) # DebugMode tests correctness - - def test_A_plus_scaled_outer(self): - f = self.function( - [self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y) - ) - self.assertFunctionContains(f, ScipyGer(destructive=False)) - self.run_f(f) # DebugMode tests correctness - - def test_scaled_A_plus_scaled_outer(self): - f = self.function( - [self.A, self.x, self.y], 0.2 * self.A + 0.1 * outer(self.x, self.y) - ) - self.assertFunctionContains(f, gemm_no_inplace) - self.run_f(f) # DebugMode tests correctness - - def test_pickle(self): - out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y) - f = pytensor.function([self.A, self.a, self.x, self.y], out) - new_f = pickle.loads(pickle.dumps(f)) - - assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer) - assert np.allclose( - f(self.Aval, 1.0, self.xval, self.yval), - new_f(self.Aval, 1.0, self.xval, self.yval), - ) - - -class TestBlasStridesScipy(TestBlasStrides): - mode = pytensor.compile.get_default_mode() - mode = mode.including("fast_run").excluding("gpu", "c_blas") diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 8df50db765..935b9ada52 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -1998,50 +1998,20 @@ def test_list(self): assert mean(ll).eval() == 1 -def test_dot_numpy_inputs(): - """Test the `PyTensor.tensor.dot` interface function with NumPy inputs.""" - a = np.ones(2) - b = np.ones(2) - res = dot(a, b) - assert isinstance(res, Variable) - assert isinstance(res.owner.op, Dot) - - class TestDot: - def test_Op_dims(self): + def test_valid_ndim(self): d0 = scalar() d1 = vector() d2 = matrix() d3 = tensor3() - with pytest.raises(TypeError): - _dot(d0, d0) - with pytest.raises(TypeError): - _dot(d0, d1) with pytest.raises(TypeError): _dot(d0, d2) with pytest.raises(TypeError): - _dot(d0, d3) - with pytest.raises(TypeError): - _dot(d1, d0) - _dot(d1, d1) - _dot(d1, d2) - with pytest.raises(TypeError): - _dot(d1, d3) - with pytest.raises(TypeError): - _dot(d2, d0) - _dot(d2, d1) - _dot(d2, d2) - with pytest.raises(TypeError): - _dot(d2, d3) - with pytest.raises(TypeError): - _dot(d3, d0) - with pytest.raises(TypeError): - _dot(d3, d1) + _dot(d1, d2) with pytest.raises(TypeError): _dot(d3, d2) - with pytest.raises(TypeError): - _dot(d3, d3) + _dot(d2, d2) # Fine def test_grad(self): rng = np.random.default_rng(seed=utt.fetch_seed()) @@ -2089,6 +2059,14 @@ def is_super_shape(var1, var2): g = grad(z.sum(), y) assert is_super_shape(y, g) + def test_dot_numpy_inputs(self): + """Test the `PyTensor.tensor.dot` interface function with NumPy inputs.""" + a = np.ones((2, 2)) + b = np.ones((2, 2)) + res = dot(a, b) + assert isinstance(res, Variable) + assert isinstance(res.owner.op, Dot) + def test_matrix_vector_ops(): """Test vecdot, matvec, and vecmat helper functions.""" @@ -2796,7 +2774,7 @@ def test_Dot(self): bdvec_val = random(4, rng=rng) self._compile_and_check( [advec, bdvec], - [Dot()(advec, bdvec)], + [dot(advec, bdvec)], [advec_val, bdvec_val], (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv), ) @@ -2808,7 +2786,7 @@ def test_Dot(self): bdmat_val = random(5, 3, rng=rng) self._compile_and_check( [admat, bdmat], - [Dot()(admat, bdmat)], + [dot(admat, bdmat)], [admat_val, bdmat_val], (Dot, blas.Dot22), ) @@ -2817,7 +2795,7 @@ def test_Dot(self): bdmat_val = random(4, 5, rng=rng) self._compile_and_check( [advec, bdmat], - [Dot()(advec, bdmat)], + [dot(advec, bdmat)], [advec_val, bdmat_val], (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv), ) @@ -2826,7 +2804,7 @@ def test_Dot(self): admat_val = random(5, 4, rng=rng) self._compile_and_check( [admat, bdvec], - [Dot()(admat, bdvec)], + [dot(admat, bdvec)], [admat_val, bdvec_val], (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv), ) diff --git a/tests/test_printing.py b/tests/test_printing.py index 4dd4f3866d..95c3c938cf 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -333,7 +333,7 @@ def test_debugprint(): def test_debugprint_id_type(): - a_at = dvector() + a_at = dmatrix() b_at = dmatrix() d_at = b_at.dot(a_at) @@ -344,10 +344,10 @@ def test_debugprint_id_type(): s = s.getvalue() exp_res = f"""Add [id {e_at.auto_name}] - ├─ dot [id {d_at.auto_name}] + ├─ Dot [id {d_at.auto_name}] │ ├─ [id {b_at.auto_name}] - │ └─ [id {a_at.auto_name}] - └─ [id {a_at.auto_name}] + │ └─ [id {a_at.auto_name}] + └─ [id {a_at.auto_name}] """ assert [l.strip() for l in s.split("\n")] == [ diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 376532f8ab..afd720ff8a 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -312,5 +312,7 @@ def test_dot_errors(): x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) # Doesn't fail until the rewrite - with pytest.raises(ValueError, match="not aligned"): + with pytest.raises( + ValueError, match="Input operand 1 has a mismatch in its core dimension 0" + ): fn(x_test, y_test)