@@ -71,6 +71,8 @@ def __init__(
71
71
self .name = name
72
72
self .numpy_dtype = np .dtype (self .dtype )
73
73
self .filter_checks_isfinite = False
74
+ # `broadcastable` exists only so that code which would otherwise work fine with XTensorType, but explicitly checks for this attribute, keeps working
75
+ self .broadcastable = (False ,) * self .ndim
74
76
75
77
def clone (
76
78
self ,
@@ -93,6 +95,10 @@ def filter(self, value, strict=False, allow_downcast=None):
93
95
self , value , strict = strict , allow_downcast = allow_downcast
94
96
)
95
97
98
@staticmethod
def may_share_memory(a, b):
    """Return whether ``a`` and ``b`` may share underlying memory.

    The check is delegated unchanged to ``TensorType.may_share_memory``;
    this type adds no sharing semantics of its own.
    """
    shared = TensorType.may_share_memory(a, b)
    return shared
96
102
def filter_variable (self , other , allow_convert = True ):
97
103
if not isinstance (other , Variable ):
98
104
# The value is not a Variable: we cast it into
@@ -160,7 +166,7 @@ def convert_variable(self, var):
160
166
return None
161
167
162
168
def __repr__(self):
    """Return a readable representation: dtype, then shape, then dims."""
    # str.format renders each value via str(), exactly as the f-string did.
    return "XTensorType({}, shape={}, dims={})".format(
        self.dtype, self.shape, self.dims
    )
164
170
165
171
def __hash__(self):
    """Hash over the identity-defining fields: concrete type, dtype, shape, dims."""
    # Build the key explicitly so the hashed fields are easy to audit.
    key = (type(self), self.dtype, self.shape, self.dims)
    return hash(key)
0 commit comments