@@ -4430,12 +4430,14 @@ def __init__(
4430
4430
raise ValueError (
4431
4431
f"'n' must be zero for spec { self .__class__ } when using an empty shape"
4432
4432
)
4433
- else :
4433
+ elif n is not None :
4434
4434
if shape [- 1 ] != n :
4435
4435
raise ValueError (
4436
4436
f"The last value of the shape must match 'n' for spec { self .__class__ } . "
4437
4437
f"Got n={ n } and shape={ shape } ."
4438
4438
)
4439
+ else :
4440
+ n = shape [- 1 ]
4439
4441
4440
4442
super ().__init__ (n = 2 , shape = shape , device = device , dtype = dtype )
4441
4443
self .encode = self ._encode_eager
@@ -4449,15 +4451,15 @@ def expand(self, *shape):
4449
4451
f"shape of the { self .__class__ .__name__ } spec in expand()."
4450
4452
)
4451
4453
return self .__class__ (
4452
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4454
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4453
4455
shape = shape ,
4454
4456
device = self .device ,
4455
4457
dtype = self .dtype ,
4456
4458
)
4457
4459
4458
4460
def _reshape (self , shape ):
4459
4461
return self .__class__ (
4460
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4462
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4461
4463
shape = shape ,
4462
4464
device = self .device ,
4463
4465
dtype = self .dtype ,
@@ -4470,7 +4472,7 @@ def _unflatten(self, dim, sizes):
4470
4472
.shape
4471
4473
)
4472
4474
return self .__class__ (
4473
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4475
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4474
4476
shape = shape ,
4475
4477
device = self .device ,
4476
4478
dtype = self .dtype ,
@@ -4481,7 +4483,7 @@ def squeeze(self, dim=None):
4481
4483
if shape is None :
4482
4484
return self
4483
4485
return self .__class__ (
4484
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4486
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4485
4487
shape = shape ,
4486
4488
device = self .device ,
4487
4489
dtype = self .dtype ,
@@ -4490,7 +4492,7 @@ def squeeze(self, dim=None):
4490
4492
def unsqueeze (self , dim : int ):
4491
4493
shape = _unsqueezed_shape (self .shape , dim )
4492
4494
return self .__class__ (
4493
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4495
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4494
4496
shape = shape ,
4495
4497
device = self .device ,
4496
4498
dtype = self .dtype ,
@@ -4510,7 +4512,7 @@ def unbind(self, dim: int = 0):
4510
4512
shape = tuple (s for i , s in enumerate (self .shape ) if i != dim )
4511
4513
return tuple (
4512
4514
self .__class__ (
4513
- n = self . shape [- 1 ] if len (self . shape ) > 0 else None ,
4515
+ n = shape [- 1 ] if len (shape ) > 0 else None ,
4514
4516
shape = shape ,
4515
4517
device = self .device ,
4516
4518
dtype = self .dtype ,
0 commit comments