@@ -4449,15 +4449,15 @@ def expand(self, *shape):
4449
4449
f"shape of the { self .__class__ .__name__ } spec in expand()."
4450
4450
)
4451
4451
return self .__class__ (
4452
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4452
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4453
4453
shape = shape ,
4454
4454
device = self .device ,
4455
4455
dtype = self .dtype ,
4456
4456
)
4457
4457
4458
4458
def _reshape (self , shape ):
4459
4459
return self .__class__ (
4460
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4460
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4461
4461
shape = shape ,
4462
4462
device = self .device ,
4463
4463
dtype = self .dtype ,
@@ -4470,7 +4470,7 @@ def _unflatten(self, dim, sizes):
4470
4470
.shape
4471
4471
)
4472
4472
return self .__class__ (
4473
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4473
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4474
4474
shape = shape ,
4475
4475
device = self .device ,
4476
4476
dtype = self .dtype ,
@@ -4481,7 +4481,7 @@ def squeeze(self, dim=None):
4481
4481
if shape is None :
4482
4482
return self
4483
4483
return self .__class__ (
4484
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4484
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4485
4485
shape = shape ,
4486
4486
device = self .device ,
4487
4487
dtype = self .dtype ,
@@ -4490,7 +4490,7 @@ def squeeze(self, dim=None):
4490
4490
def unsqueeze (self , dim : int ):
4491
4491
shape = _unsqueezed_shape (self .shape , dim )
4492
4492
return self .__class__ (
4493
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4493
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4494
4494
shape = shape ,
4495
4495
device = self .device ,
4496
4496
dtype = self .dtype ,
@@ -4510,7 +4510,7 @@ def unbind(self, dim: int = 0):
4510
4510
shape = tuple (s for i , s in enumerate (self .shape ) if i != dim )
4511
4511
return tuple (
4512
4512
self .__class__ (
4513
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4513
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4514
4514
shape = shape ,
4515
4515
device = self .device ,
4516
4516
dtype = self .dtype ,
0 commit comments