Skip to content

Commit 5761704

Browse files
committed
Small tweaks to XTensorType
1 parent 0bb15f9 commit 5761704

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pytensor/xtensor/type.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def __init__(
7171
self.name = name
7272
self.numpy_dtype = np.dtype(self.dtype)
7373
self.filter_checks_isfinite = False
74+
# broadcastable is here just for code that would work fine with XTensorType but checks for it
75+
self.broadcastable = (False,) * self.ndim
7476

7577
def clone(
7678
self,
@@ -93,6 +95,10 @@ def filter(self, value, strict=False, allow_downcast=None):
9395
self, value, strict=strict, allow_downcast=allow_downcast
9496
)
9597

98+
@staticmethod
99+
def may_share_memory(a, b):
100+
return TensorType.may_share_memory(a, b)
101+
96102
def filter_variable(self, other, allow_convert=True):
97103
if not isinstance(other, Variable):
98104
# The value is not a Variable: we cast it into
@@ -160,7 +166,7 @@ def convert_variable(self, var):
160166
return None
161167

162168
def __repr__(self):
163-
return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"
169+
return f"XTensorType({self.dtype}, shape={self.shape}, dims={self.dims})"
164170

165171
def __hash__(self):
166172
return hash((type(self), self.dtype, self.shape, self.dims))

0 commit comments

Comments
 (0)