Skip to content

Commit 0e449da

Browse files
Luca Carminativmoens
authored andcommitted
Fix Binary reshaping
1 parent 4001d9c commit 0e449da

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

test/test_specs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3913,6 +3913,7 @@ def test_expand(self):
39133913
disc = Categorical(shape=(-1, 1, 2), n=4)
39143914
moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4])
39153915
mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4])
3916+
binary = Binary(shape=(-1, 1, 2))
39163917

39173918
spec = Composite(
39183919
unb=unb,
@@ -3922,6 +3923,7 @@ def test_expand(self):
39223923
disc=disc,
39233924
moneh=moneh,
39243925
mdisc=mdisc,
3926+
binary=binary,
39253927
shape=(-1, 1, 2),
39263928
)
39273929
assert spec.shape == (-1, 1, 2)

torchrl/data/tensor_specs.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4430,12 +4430,14 @@ def __init__(
44304430
raise ValueError(
44314431
f"'n' must be zero for spec {self.__class__} when using an empty shape"
44324432
)
4433-
else:
4433+
elif n is not None:
44344434
if shape[-1] != n:
44354435
raise ValueError(
44364436
f"The last value of the shape must match 'n' for spec {self.__class__}. "
44374437
f"Got n={n} and shape={shape}."
44384438
)
4439+
else:
4440+
n = shape[-1]
44394441

44404442
super().__init__(n=2, shape=shape, device=device, dtype=dtype)
44414443
self.encode = self._encode_eager
@@ -4449,15 +4451,15 @@ def expand(self, *shape):
44494451
f"shape of the {self.__class__.__name__} spec in expand()."
44504452
)
44514453
return self.__class__(
4452-
n=self.shape[-1] if len(self.shape) > 0 else None,
4454+
n=shape[-1] if len(shape) > 0 else None,
44534455
shape=shape,
44544456
device=self.device,
44554457
dtype=self.dtype,
44564458
)
44574459

44584460
def _reshape(self, shape):
44594461
return self.__class__(
4460-
n=self.shape[-1] if len(self.shape) > 0 else None,
4462+
n=shape[-1] if len(shape) > 0 else None,
44614463
shape=shape,
44624464
device=self.device,
44634465
dtype=self.dtype,
@@ -4470,7 +4472,7 @@ def _unflatten(self, dim, sizes):
44704472
.shape
44714473
)
44724474
return self.__class__(
4473-
n=self.shape[-1] if len(self.shape) > 0 else None,
4475+
n=shape[-1] if len(shape) > 0 else None,
44744476
shape=shape,
44754477
device=self.device,
44764478
dtype=self.dtype,
@@ -4481,7 +4483,7 @@ def squeeze(self, dim=None):
44814483
if shape is None:
44824484
return self
44834485
return self.__class__(
4484-
n=self.shape[-1] if len(self.shape) > 0 else None,
4486+
n=shape[-1] if len(shape) > 0 else None,
44854487
shape=shape,
44864488
device=self.device,
44874489
dtype=self.dtype,
@@ -4490,7 +4492,7 @@ def squeeze(self, dim=None):
44904492
def unsqueeze(self, dim: int):
44914493
shape = _unsqueezed_shape(self.shape, dim)
44924494
return self.__class__(
4493-
n=self.shape[-1] if len(self.shape) > 0 else None,
4495+
n=shape[-1] if len(shape) > 0 else None,
44944496
shape=shape,
44954497
device=self.device,
44964498
dtype=self.dtype,
@@ -4510,7 +4512,7 @@ def unbind(self, dim: int = 0):
45104512
shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
45114513
return tuple(
45124514
self.__class__(
4513-
n=self.shape[-1] if len(self.shape) > 0 else None,
4515+
n=shape[-1] if len(shape) > 0 else None,
45144516
shape=shape,
45154517
device=self.device,
45164518
dtype=self.dtype,

0 commit comments

Comments
 (0)