28 changes: 15 additions & 13 deletions advanced_source/pendulum.py
@@ -100,7 +100,7 @@
 from tensordict.nn import TensorDictModule
 from torch import nn

-from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Bounded, Composite, Unbounded
 from torchrl.envs import (
     CatTensors,
     EnvBase,
@@ -403,14 +403,14 @@ def _reset(self, tensordict):

     def _make_spec(self, td_params):
         # Under the hood, this will populate self.output_spec["observation"]
-        self.observation_spec = CompositeSpec(
-            th=BoundedTensorSpec(
+        self.observation_spec = Composite(
+            th=Bounded(
                 low=-torch.pi,
                 high=torch.pi,
                 shape=(),
                 dtype=torch.float32,
             ),
-            thdot=BoundedTensorSpec(
+            thdot=Bounded(
                 low=-td_params["params", "max_speed"],
                 high=td_params["params", "max_speed"],
                 shape=(),
@@ -426,24 +426,26 @@ def _make_spec(self, td_params):
         self.state_spec = self.observation_spec.clone()
         # action-spec will be automatically wrapped in input_spec when
         # `self.action_spec = spec` is called
-        self.action_spec = BoundedTensorSpec(
+        self.action_spec = Bounded(
             low=-td_params["params", "max_torque"],
             high=td_params["params", "max_torque"],
             shape=(1,),
             dtype=torch.float32,
         )
-        self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))
+        self.reward_spec = Unbounded(shape=(*td_params.shape, 1))


 def make_composite_from_td(td):
     # custom function to convert a ``tensordict`` into a matching spec structure
     # of unbounded values.
-    composite = CompositeSpec(
+    composite = Composite(
         {
-            key: make_composite_from_td(tensor)
-            if isinstance(tensor, TensorDictBase)
-            else UnboundedContinuousTensorSpec(
-                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
-            )
+            key: (
+                make_composite_from_td(tensor)
+                if isinstance(tensor, TensorDictBase)
+                else Unbounded(
+                    dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
+                )
+            )
             for key, tensor in td.items()
         },
@@ -687,7 +689,7 @@ def _reset(
     # is of type ``Composite``
     @_apply_to_composite
     def transform_observation_spec(self, observation_spec):
-        return BoundedTensorSpec(
+        return Bounded(
             low=-1,
             high=1,
             shape=observation_spec.shape,
@@ -711,7 +713,7 @@ def _reset(
     # is of type ``Composite``
     @_apply_to_composite
     def transform_observation_spec(self, observation_spec):
-        return BoundedTensorSpec(
+        return Bounded(
             low=-1,
             high=1,
             shape=observation_spec.shape,
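The changes to pendulum.py above are a pure rename: torchrl's ``BoundedTensorSpec``, ``CompositeSpec``, and ``UnboundedContinuousTensorSpec`` become the short-form ``Bounded``, ``Composite``, and ``Unbounded``. As a minimal sketch of the renamed API (assuming a torchrl release where the short names exist; the long names linger as deprecated aliases), specs can be built and sampled like this:

import torch
from torchrl.data import Bounded, Composite, Unbounded

specs = Composite(
    # a scalar angle constrained to [-pi, pi]
    th=Bounded(low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32),
    # an unconstrained scalar reward
    reward=Unbounded(shape=(1,), dtype=torch.float32),
)
sample = specs.rand()       # draw a random tensordict satisfying each spec
assert specs.is_in(sample)  # validate the sample against the bounds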
85 changes: 83 additions & 2 deletions intermediate_source/per_sample_grads.py
@@ -168,9 +168,90 @@ def compute_loss(params, buffers, sample, target):
 # we can double check that the results using ``grad`` and ``vmap`` match the
 # results of hand processing each one individually:

-for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
-    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
+# Create a float64 baseline for more precise comparison
+def compute_grad_fp64(sample, target):
+    # Convert the input to float64 for higher precision
+    sample_fp64 = sample.to(torch.float64)
+    target_fp64 = target  # integer class labels need no dtype change
+
+    # Build a float64 copy of ``model`` with identical parameters; a freshly
+    # initialized network would make the baseline comparison meaningless
+    model_fp64 = SimpleCNN().to(device=device)
+    model_fp64.load_state_dict(model.state_dict())
+    model_fp64 = model_fp64.to(torch.float64)
+
+    sample_fp64 = sample_fp64.unsqueeze(0)  # prepend batch dimension
+    target_fp64 = target_fp64.unsqueeze(0)
+
+    prediction = model_fp64(sample_fp64)
+    loss = loss_fn(prediction, target_fp64)
+
+    return torch.autograd.grad(loss, list(model_fp64.parameters()))
+
+
+def compute_fp64_baseline(data, targets, indices):
+    """Compute the float64 gradients for a single sample."""
+    # Only compute for the sample with the largest difference to save computation
+    sample_idx = indices[0]  # the first index is the sample (batch) dimension
+    sample_grad = compute_grad_fp64(data[sample_idx], targets[sample_idx])
+    return sample_grad
+
+
+for i, (per_sample_grad, ft_per_sample_grad) in enumerate(
+    zip(per_sample_grads, ft_per_sample_grads.values())
+):
+    is_close = torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
+    if not is_close:
+        # Calculate and print the maximum absolute difference
+        abs_diff = (per_sample_grad - ft_per_sample_grad).abs()
+        max_diff = abs_diff.max().item()
+        mean_diff = abs_diff.mean().item()
+        print(f"Gradient {i} mismatch:")
+        print(f"  Max absolute difference: {max_diff}")
+        print(f"  Mean absolute difference: {mean_diff}")
+        print(f"  Shape of tensors: {per_sample_grad.shape}")
+
+        # Convert the flat argmax index into a multi-dimensional index
+        # (``torch.unravel_index`` does this in recent PyTorch; done by hand here)
+        flat_idx = abs_diff.argmax().item()
+        indices = []
+        for dim in reversed(abs_diff.shape):
+            indices.insert(0, flat_idx % dim)
+            flat_idx //= dim
+        print(f"  Max difference at index: {indices}")
+        print(f"  Manual gradient value: {per_sample_grad[tuple(indices)].item()}")
+        print(
+            f"  Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}"
+        )
+
+        # Compute a float64 baseline for the sample with the largest difference
+        print("\nComputing float64 baseline for comparison...")
+        try:
+            fp64_grads = compute_fp64_baseline(data, targets, indices)
+            fp64_value = fp64_grads[i][
+                tuple(indices[1:])
+            ].item()  # skip the batch dimension
+            print(f"  Float64 baseline value: {fp64_value}")
+
+            # Compare both methods against the float64 baseline
+            manual_diff = abs(per_sample_grad[tuple(indices)].item() - fp64_value)
+            functional_diff = abs(
+                ft_per_sample_grad[tuple(indices)].item() - fp64_value
+            )
+            print(f"  Manual method vs float64 difference: {manual_diff}")
+            print(f"  Functional method vs float64 difference: {functional_diff}")
+
+            if manual_diff < functional_diff:
+                print("  Manual method is closer to float64 baseline")
+            else:
+                print("  Functional method is closer to float64 baseline")
+        except Exception as e:
+            print(f"  Error computing float64 baseline: {e}")
+
+    # Keep the original assertion
+    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

 ######################################################################
 # A quick note: there are limitations around what types of functions can be
 # transformed by ``vmap``. The best functions to transform are ones that are pure
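To illustrate the purity guidance in the note above, here is a minimal sketch assuming ``torch.func.vmap`` from PyTorch 2.0 or later. A function whose output depends only on its inputs vectorizes cleanly; a function that mutates outside state still runs, but ``vmap`` traces it once over the whole batch, so the side effect does not fire once per sample the way a Python loop would:

import torch
from torch.func import vmap

def pure_fn(x):
    # output determined solely by the input: safe to vmap
    return x.sin() + x.cos()

hits = []

def impure_fn(x):
    # mutates state outside the function: an unsupported side effect
    hits.append(x.shape)
    return x.sin()

x = torch.randn(3, 5)
print(vmap(pure_fn)(x).shape)    # torch.Size([3, 5])
print(vmap(impure_fn)(x).shape)  # still torch.Size([3, 5]), but...
print(len(hits))                 # 1, not 3: the function ran once, batched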
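A side note on the tolerances used in the assertion above: ``torch.allclose(a, b, atol=..., rtol=...)`` passes when ``|a - b| <= atol + rtol * |b|`` holds elementwise, so for gradient entries near zero the absolute term ``atol=3e-3`` does nearly all the work. A small self-contained check of that reading:

import torch

a = torch.tensor([1.0000, 0.0020])
b = torch.tensor([1.0001, 0.0000])
atol, rtol = 3e-3, 1e-5

# the explicit formula agrees with torch.allclose
assert torch.allclose(a, b, atol=atol, rtol=rtol)
assert bool((a - b).abs().le(atol + rtol * b.abs()).all())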