100
100
from tensordict .nn import TensorDictModule
101
101
from torch import nn
102
102
103
- from torchrl .data import BoundedTensorSpec , CompositeSpec , UnboundedContinuousTensorSpec
103
+ from torchrl .data import Bounded , Composite , Unbounded
104
104
from torchrl .envs import (
105
105
CatTensors ,
106
106
EnvBase ,
@@ -403,14 +403,14 @@ def _reset(self, tensordict):
403
403
404
404
def _make_spec (self , td_params ):
405
405
# Under the hood, this will populate self.output_spec["observation"]
406
- self .observation_spec = CompositeSpec (
407
- th = BoundedTensorSpec (
406
+ self .observation_spec = Composite (
407
+ th = Bounded (
408
408
low = - torch .pi ,
409
409
high = torch .pi ,
410
410
shape = (),
411
411
dtype = torch .float32 ,
412
412
),
413
- thdot = BoundedTensorSpec (
413
+ thdot = Bounded (
414
414
low = - td_params ["params" , "max_speed" ],
415
415
high = td_params ["params" , "max_speed" ],
416
416
shape = (),
@@ -426,24 +426,26 @@ def _make_spec(self, td_params):
426
426
self .state_spec = self .observation_spec .clone ()
427
427
# action-spec will be automatically wrapped in input_spec when
428
428
# `self.action_spec = spec` will be called supported
429
- self .action_spec = BoundedTensorSpec (
429
+ self .action_spec = Bounded (
430
430
low = - td_params ["params" , "max_torque" ],
431
431
high = td_params ["params" , "max_torque" ],
432
432
shape = (1 ,),
433
433
dtype = torch .float32 ,
434
434
)
435
- self .reward_spec = UnboundedContinuousTensorSpec (shape = (* td_params .shape , 1 ))
435
+ self .reward_spec = Unbounded (shape = (* td_params .shape , 1 ))
436
436
437
437
438
438
def make_composite_from_td (td ):
439
439
# custom function to convert a ``tensordict`` in a similar spec structure
440
440
# of unbounded values.
441
- composite = CompositeSpec (
441
+ composite = Composite (
442
442
{
443
- key : make_composite_from_td (tensor )
444
- if isinstance (tensor , TensorDictBase )
445
- else UnboundedContinuousTensorSpec (
446
- dtype = tensor .dtype , device = tensor .device , shape = tensor .shape
443
+ key : (
444
+ make_composite_from_td (tensor )
445
+ if isinstance (tensor , TensorDictBase )
446
+ else Unbounded (
447
+ dtype = tensor .dtype , device = tensor .device , shape = tensor .shape
448
+ )
447
449
)
448
450
for key , tensor in td .items ()
449
451
},
@@ -687,7 +689,7 @@ def _reset(
687
689
# is of type ``Composite``
688
690
@_apply_to_composite
689
691
def transform_observation_spec (self , observation_spec ):
690
- return BoundedTensorSpec (
692
+ return Bounded (
691
693
low = - 1 ,
692
694
high = 1 ,
693
695
shape = observation_spec .shape ,
@@ -711,7 +713,7 @@ def _reset(
711
713
# is of type ``Composite``
712
714
@_apply_to_composite
713
715
def transform_observation_spec (self , observation_spec ):
714
- return BoundedTensorSpec (
716
+ return Bounded (
715
717
low = - 1 ,
716
718
high = 1 ,
717
719
shape = observation_spec .shape ,
0 commit comments