Skip to content

Commit d0010ce

Browse files
committed
refactor(policies): Remove unnormalization step from action predictions
- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors.
1 parent 7b189dd commit d0010ce

File tree

2 files changed

+0
-2
lines changed

2 files changed

+0
-2
lines changed

src/lerobot/policies/tdmpc/modeling_tdmpc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
129129

130130
actions = torch.clamp(actions, -1, +1)
131131

132-
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
133132
return actions
134133

135134
@torch.no_grad()

src/lerobot/policies/vqbet/modeling_vqbet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def reset(self):
118118
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
119119
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
120120
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
121-
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
122121
return actions
123122

124123
@torch.no_grad

0 commit comments

Comments
 (0)