From 0e449daa5ed62a73f42ed48aa261fb4f9bc33c8c Mon Sep 17 00:00:00 2001 From: Luca Carminati Date: Fri, 18 Jul 2025 11:12:48 +0200 Subject: [PATCH] Fix Binary reshaping --- test/test_specs.py | 2 ++ torchrl/data/tensor_specs.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index dd9c2ffcd27..def9cd3032c 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3913,6 +3913,7 @@ def test_expand(self): disc = Categorical(shape=(-1, 1, 2), n=4) moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4]) mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4]) + binary = Binary(shape=(-1, 1, 2)) spec = Composite( unb=unb, @@ -3922,6 +3923,7 @@ def test_expand(self): disc=disc, moneh=moneh, mdisc=mdisc, + binary=binary, shape=(-1, 1, 2), ) assert spec.shape == (-1, 1, 2) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 5f36b597f99..a91159ec9ba 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4430,12 +4430,14 @@ def __init__( raise ValueError( f"'n' must be zero for spec {self.__class__} when using an empty shape" ) - else: + elif n is not None: if shape[-1] != n: raise ValueError( f"The last value of the shape must match 'n' for spec {self.__class__}. " f"Got n={n} and shape={shape}." ) + else: + n = shape[-1] super().__init__(n=2, shape=shape, device=device, dtype=dtype) self.encode = self._encode_eager @@ -4449,7 +4451,7 @@ def expand(self, *shape): f"shape of the {self.__class__.__name__} spec in expand()." ) return self.__class__( - n=self.shape[-1] if len(self.shape) > 0 else None, + n=shape[-1] if len(shape) > 0 else None, shape=shape, device=self.device, dtype=self.dtype, @@ -4457,7 +4459,7 @@ def expand(self, *shape): def _reshape(self, shape): return self.__class__( - n=self.shape[-1] if len(self.shape) > 0 else None, + n=shape[-1] if len(shape) > 0 else None, shape=shape, device=self.device, dtype=self.dtype, @@ -4470,7 +4472,7 @@ def _unflatten(self, dim, sizes): .shape ) return self.__class__( - n=self.shape[-1] if len(self.shape) > 0 else None, + n=shape[-1] if len(shape) > 0 else None, shape=shape, device=self.device, dtype=self.dtype, @@ -4481,7 +4483,7 @@ def squeeze(self, dim=None): if shape is None: return self return self.__class__( - n=self.shape[-1] if len(self.shape) > 0 else None, + n=shape[-1] if len(shape) > 0 else None, shape=shape, device=self.device, dtype=self.dtype, @@ -4490,7 +4492,7 @@ def squeeze(self, dim=None): def unsqueeze(self, dim: int): shape = _unsqueezed_shape(self.shape, dim) return self.__class__( - n=self.shape[-1] if len(self.shape) > 0 else None, + n=shape[-1] if len(shape) > 0 else None, shape=shape, device=self.device, dtype=self.dtype, @@ -4510,7 +4512,7 @@ def unbind(self, dim: int = 0): shape = tuple(s for i, s in enumerate(self.shape) if i != dim) return tuple( self.__class__( - n=self.shape[-1] if len(self.shape) > 0 else None, + n=shape[-1] if len(shape) > 0 else None, shape=shape, device=self.device, dtype=self.dtype,