Refactor action generation logic in diffusion policy #1917
+22
−7
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Refactor diffusion policy action generation to properly maintain observation queues when
predict_action_chunk()
is called outside ofselect_action()
.Problem
The original implementation of
predict_action_chunk()
assumed it would only be calledthrough
select_action()
, which handles input normalization and queue population. Whenpredict_action_chunk()
was called directly (outside theselect_action()
workflow), itwould:
Solution
_get_action_chunk()
method: Created a stateless method that handles thecore action generation logic from prepared observations
predict_action_chunk()
: Now properly normalizes inputs, handles imagefeatures, and populates queues before delegating to
_get_action_chunk()
select_action()
: Modified to use_get_action_chunk()
directly whengenerating new actions, avoiding duplicate queue operations
Changes
_get_action_chunk()
method for stateless action generation from preparedobservations
predict_action_chunk()
to handle input normalization, image feature stacking,and queue population
select_action()
to use the new stateless method when the action queue is emptyselect_action()
orpredict_action_chunk()