Skip to content

Commit f072295

Browse files
author
Luca Carminati
committed
Fix Binary reshaping
1 parent 4001d9c commit f072295

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

test/test_specs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3913,6 +3913,7 @@ def test_expand(self):
39133913
disc = Categorical(shape=(-1, 1, 2), n=4)
39143914
moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4])
39153915
mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4])
3916+
binary = Binary(shape=(-1, 1, 2))
39163917

39173918
spec = Composite(
39183919
unb=unb,
@@ -3922,6 +3923,7 @@ def test_expand(self):
39223923
disc=disc,
39233924
moneh=moneh,
39243925
mdisc=mdisc,
3926+
binary=binary,
39253927
shape=(-1, 1, 2),
39263928
)
39273929
assert spec.shape == (-1, 1, 2)

torchrl/data/tensor_specs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4449,15 +4449,15 @@ def expand(self, *shape):
44494449
f"shape of the {self.__class__.__name__} spec in expand()."
44504450
)
44514451
return self.__class__(
4452-
n=self.shape[-1] if len(self.shape) > 0 else None,
4452+
n=shape[-1] if len(shape) > 0 else None,
44534453
shape=shape,
44544454
device=self.device,
44554455
dtype=self.dtype,
44564456
)
44574457

44584458
def _reshape(self, shape):
44594459
return self.__class__(
4460-
n=self.shape[-1] if len(self.shape) > 0 else None,
4460+
n=shape[-1] if len(shape) > 0 else None,
44614461
shape=shape,
44624462
device=self.device,
44634463
dtype=self.dtype,
@@ -4470,7 +4470,7 @@ def _unflatten(self, dim, sizes):
44704470
.shape
44714471
)
44724472
return self.__class__(
4473-
n=self.shape[-1] if len(self.shape) > 0 else None,
4473+
n=shape[-1] if len(shape) > 0 else None,
44744474
shape=shape,
44754475
device=self.device,
44764476
dtype=self.dtype,
@@ -4481,7 +4481,7 @@ def squeeze(self, dim=None):
44814481
if shape is None:
44824482
return self
44834483
return self.__class__(
4484-
n=self.shape[-1] if len(self.shape) > 0 else None,
4484+
n=shape[-1] if len(shape) > 0 else None,
44854485
shape=shape,
44864486
device=self.device,
44874487
dtype=self.dtype,
@@ -4490,7 +4490,7 @@ def squeeze(self, dim=None):
44904490
def unsqueeze(self, dim: int):
44914491
shape = _unsqueezed_shape(self.shape, dim)
44924492
return self.__class__(
4493-
n=self.shape[-1] if len(self.shape) > 0 else None,
4493+
n=shape[-1] if len(shape) > 0 else None,
44944494
shape=shape,
44954495
device=self.device,
44964496
dtype=self.dtype,
@@ -4510,7 +4510,7 @@ def unbind(self, dim: int = 0):
45104510
shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
45114511
return tuple(
45124512
self.__class__(
4513-
n=self.shape[-1] if len(self.shape) > 0 else None,
4513+
n=shape[-1] if len(shape) > 0 else None,
45144514
shape=shape,
45154515
device=self.device,
45164516
dtype=self.dtype,

0 commit comments

Comments
 (0)