Skip to content

Conversation

zhigenzhao
Copy link

Summary

Refactor diffusion policy action generation to properly maintain observation queues when
predict_action_chunk() is called outside of select_action().

Problem

The original implementation of predict_action_chunk() assumed it would only be called
through select_action(), which handles input normalization and queue population. When
predict_action_chunk() was called directly (outside the select_action() workflow), it
would:

  • Skip input normalization
  • Not populate observation queues properly
  • Fail to handle image feature stacking
  • Lead to inconsistent behavior depending on call context

Solution

  • Extract _get_action_chunk() method: Created a stateless method that handles the
    core action generation logic from prepared observations
  • Refactor predict_action_chunk(): Now properly normalizes inputs, handles image
    features, and populates queues before delegating to _get_action_chunk()
  • Update select_action(): Modified to use _get_action_chunk() directly when
    generating new actions, avoiding duplicate queue operations

Changes

  • Added _get_action_chunk() method for stateless action generation from prepared
    observations
  • Enhanced predict_action_chunk() to handle input normalization, image feature stacking,
    and queue population
  • Updated select_action() to use the new stateless method when the action queue is empty
  • Ensures consistent behavior regardless of whether actions are generated through
    select_action() or predict_action_chunk()

- Extract _get_action_chunk() method for stateless action generation
- Refactor predict_action_chunk() to properly normalize inputs and populate queues
- Update select_action() to use _get_action_chunk() directly when needed
@Copilot Copilot AI review requested due to automatic review settings September 11, 2025 13:27
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR refactors the diffusion policy action generation logic to ensure consistent behavior when predict_action_chunk() is called independently versus through select_action(). The refactoring extracts the core action generation logic into a stateless method and improves input handling.

Key changes:

  • Extract stateless _get_action_chunk() method for core action generation logic
  • Enhanced predict_action_chunk() to properly handle input normalization and queue population
  • Updated select_action() to use the new stateless method to avoid duplicate operations

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

# Normalize and prepare batch
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
Copy link
Preview

Copilot AI Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shallow copy approach may not be sufficient if the batch contains nested mutable objects. Consider using copy.deepcopy() or ensure that all batch values are immutable to prevent unintended side effects.

Copilot uses AI. Check for mistakes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant