Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/lerobot/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,28 @@ def reset(self):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
    """Run diffusion inference on an already-prepared batch and return unnormalized actions.

    Stateless with respect to the observation queues: the caller is responsible
    for stacking/normalizing observations before handing them in.
    """
    raw_actions = self.diffusion.generate_actions(batch)
    # TODO(rcadene): make above methods return output dictionary?
    return self.unnormalize_outputs({ACTION: raw_actions})[ACTION]

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
    """Predict a chunk of actions given environment observations.

    Normalizes the incoming observations, pushes them onto the internal
    observation queues, stacks the n latest observations per key, and runs
    diffusion inference on the result.

    Args:
        batch: Mapping from observation keys to tensors for the current step.

    Returns:
        Unnormalized action chunk tensor produced by the diffusion model.
    """
    # Normalize and prepare batch
    batch = self.normalize_inputs(batch)
    if self.config.image_features:
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        # NOTE(review): shallow copy protects the caller's dict from the new key,
        # but shares the tensor values — do not mutate them in place here.
        batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)

    # Populate queues with current batch
    self._queues = populate_queues(self._queues, batch)

    # Stack the n latest observations from the queues along a new time dim
    prepared_batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
    return self._get_action_chunk(prepared_batch)

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
Expand Down Expand Up @@ -145,7 +156,11 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
self._queues = populate_queues(self._queues, batch)

if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
# Create prepared batch for action generation
prepared_batch = {
k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues
}
actions = self._get_action_chunk(prepared_batch)
self._queues[ACTION].extend(actions.transpose(0, 1))

action = self._queues[ACTION].popleft()
Expand Down