Skip to content

Commit 311059c

Browse files
committed
Fix pendulum.py issues, updated to use newer APIs
1 parent 19e68c8 commit 311059c

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

advanced_source/pendulum.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
from tensordict.nn import TensorDictModule
101101
from torch import nn
102102

103-
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
103+
from torchrl.data import Bounded, Composite, Unbounded
104104
from torchrl.envs import (
105105
CatTensors,
106106
EnvBase,
@@ -403,14 +403,14 @@ def _reset(self, tensordict):
403403

404404
def _make_spec(self, td_params):
405405
# Under the hood, this will populate self.output_spec["observation"]
406-
self.observation_spec = CompositeSpec(
407-
th=BoundedTensorSpec(
406+
self.observation_spec = Composite(
407+
th=Bounded(
408408
low=-torch.pi,
409409
high=torch.pi,
410410
shape=(),
411411
dtype=torch.float32,
412412
),
413-
thdot=BoundedTensorSpec(
413+
thdot=Bounded(
414414
low=-td_params["params", "max_speed"],
415415
high=td_params["params", "max_speed"],
416416
shape=(),
@@ -426,24 +426,26 @@ def _make_spec(self, td_params):
426426
self.state_spec = self.observation_spec.clone()
427427
# action-spec will be automatically wrapped in input_spec when
428428
# `self.action_spec = spec` will be called supported
429-
self.action_spec = BoundedTensorSpec(
429+
self.action_spec = Bounded(
430430
low=-td_params["params", "max_torque"],
431431
high=td_params["params", "max_torque"],
432432
shape=(1,),
433433
dtype=torch.float32,
434434
)
435-
self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))
435+
self.reward_spec = Unbounded(shape=(*td_params.shape, 1))
436436

437437

438438
def make_composite_from_td(td):
439439
# custom function to convert a ``tensordict`` in a similar spec structure
440440
# of unbounded values.
441-
composite = CompositeSpec(
441+
composite = Composite(
442442
{
443-
key: make_composite_from_td(tensor)
444-
if isinstance(tensor, TensorDictBase)
445-
else UnboundedContinuousTensorSpec(
446-
dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
443+
key: (
444+
make_composite_from_td(tensor)
445+
if isinstance(tensor, TensorDictBase)
446+
else Unbounded(
447+
dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
448+
)
447449
)
448450
for key, tensor in td.items()
449451
},
@@ -687,7 +689,7 @@ def _reset(
687689
# is of type ``Composite``
688690
@_apply_to_composite
689691
def transform_observation_spec(self, observation_spec):
690-
return BoundedTensorSpec(
692+
return Bounded(
691693
low=-1,
692694
high=1,
693695
shape=observation_spec.shape,
@@ -711,7 +713,7 @@ def _reset(
711713
# is of type ``Composite``
712714
@_apply_to_composite
713715
def transform_observation_spec(self, observation_spec):
714-
return BoundedTensorSpec(
716+
return Bounded(
715717
low=-1,
716718
high=1,
717719
shape=observation_spec.shape,

0 commit comments

Comments
 (0)