43
43
44
44
import torch
45
45
import torch .nn as nn
46
-
47
46
from torch import Tensor
48
47
49
- from lerobot .constants import ACTION , OBS_STATE
48
+ from lerobot .constants import ACTION
50
49
from lerobot .policies .normalize import Normalize , Unnormalize
51
50
from lerobot .policies .octo .configuration_octo import OctoConfig
52
-
53
- from lerobot .policies .pretrained import PreTrainedPolicy
54
- from lerobot .policies .utils import populate_queues , log_model_loading_keys
55
-
51
+ from lerobot .policies .octo .diffusion import DiffusionActionHead
56
52
from lerobot .policies .octo .tokenizers import TextProcessor
57
53
from lerobot .policies .octo .transformer import OctoWithoutHead
58
- from lerobot .policies .octo . diffusion import DiffusionActionHead
59
-
54
+ from lerobot .policies .pretrained import PreTrainedPolicy
55
+ from lerobot . policies . utils import log_model_loading_keys , populate_queues
60
56
61
57
# TODO(lilkm): Be aware of the normalization in the image tokenizer (normalize_images function)
62
58
59
+
63
60
class OctoPolicy (PreTrainedPolicy ):
64
61
"""Wrapper class around Octo model to train and run inference within LeRobot."""
65
62
@@ -110,16 +107,16 @@ def reset(self):
110
107
self ._queues = {
111
108
ACTION : deque (maxlen = self .config .n_action_steps ),
112
109
}
113
-
110
+
114
111
def _apply_selective_freezing (self ):
115
112
"""Apply selective freezing based on configuration settings."""
116
- if hasattr (self .model .octo_transformer , ' task_tokenizers' ):
113
+ if hasattr (self .model .octo_transformer , " task_tokenizers" ):
117
114
for name , tokenizer in self .model .octo_transformer .task_tokenizers .items ():
118
- if name == ' language_instruction' :
115
+ if name == " language_instruction" :
119
116
for param in tokenizer .parameters ():
120
117
param .requires_grad = False
121
- print (f "✓ T5 language encoder frozen (always frozen during finetuning)" )
122
-
118
+ print ("✓ T5 language encoder frozen (always frozen during finetuning)" )
119
+
123
120
# If train_action_head_only is True, freeze everything except the action head
124
121
if self .config .train_action_head_only :
125
122
# Freeze transformer
@@ -133,10 +130,10 @@ def _apply_selective_freezing(self):
133
130
if self .config .freeze_transformer :
134
131
for param in self .model .octo_transformer .parameters ():
135
132
param .requires_grad = False
136
-
133
+
137
134
if self .config .freeze_vision_encoder :
138
135
# Freeze vision encoder components in the transformer
139
- if hasattr (self .model .octo_transformer , ' observation_tokenizers' ):
136
+ if hasattr (self .model .octo_transformer , " observation_tokenizers" ):
140
137
for tokenizer in self .model .octo_transformer .observation_tokenizers .values ():
141
138
for param in tokenizer .parameters ():
142
139
param .requires_grad = False
@@ -170,18 +167,18 @@ def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
170
167
# 1. Replace "action_head." with "head."
171
168
if "action_head." in new_key :
172
169
new_key = new_key .replace ("action_head." , "head." )
173
-
170
+
174
171
# 2. Adjust the transformer nesting to match the LeRobot model.
175
172
# The checkpoint has `transformer.transformer` but LeRobot expects
176
173
# `transformer.transformer.transformer`.
177
174
if "octo_transformer.transformer.transformer." in new_key :
178
- new_key = new_key .replace (
179
- "octo_transformer.transformer.transformer." ,
180
- "octo_transformer.transformer.transformer.transformer."
181
- )
175
+ new_key = new_key .replace (
176
+ "octo_transformer.transformer.transformer." ,
177
+ "octo_transformer.transformer.transformer.transformer." ,
178
+ )
182
179
183
180
transformed_dict [new_key ] = value
184
-
181
+
185
182
return transformed_dict
186
183
187
184
@classmethod
@@ -190,8 +187,9 @@ def _load_as_safetensor(
190
187
) -> "OctoPolicy" :
191
188
"""Override to apply key transformations before loading."""
192
189
from safetensors .torch import load_file
190
+
193
191
from lerobot .utils .utils import init_logging
194
-
192
+
195
193
init_logging ()
196
194
# Load the state dict from file safely
197
195
state_dict = load_file (model_file , device = map_location )
@@ -205,7 +203,7 @@ def _load_as_safetensor(
205
203
# Log message
206
204
log_model_loading_keys (msg .missing_keys , msg .unexpected_keys )
207
205
return model
208
-
206
+
209
207
@classmethod
210
208
def from_pretrained (cls , * args , ** kwargs ):
211
209
"""Override the from_pretrained method to display important information."""
@@ -217,15 +215,17 @@ def from_pretrained(cls, *args, **kwargs):
217
215
)
218
216
return super ().from_pretrained (* args , ** kwargs )
219
217
220
- def _prepare_batch (self , batch : dict [str , Tensor ], raw_tasks : Optional [Sequence [str ]] = None ) -> dict [str , Tensor ]:
218
+ def _prepare_batch (
219
+ self , batch : dict [str , Tensor ], raw_tasks : Optional [Sequence [str ]] = None
220
+ ) -> dict [str , Tensor ]:
221
221
"""
222
222
Prepare batch for model input.
223
223
Transforms a batch from the LeRobotDataset format to the format expected by the OctoModel.
224
224
"""
225
225
batch = self .normalize_inputs (batch )
226
226
# Get device from any available tensor in the batch
227
227
device = next (iter (batch .values ())).device
228
-
228
+
229
229
image_primary = batch ["observation.images.front" ].to (device )
230
230
image_wrist = batch ["observation.images.wrist" ].to (device )
231
231
proprio = batch ["observation.state" ].to (device )
@@ -254,34 +254,36 @@ def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[
254
254
# Create timestep_pad_mask - all True since we have real data (no padding)
255
255
timestep_pad_mask = torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device )
256
256
257
- task_completed = torch .zeros ((batch_size , window_size , action_horizon ), dtype = torch .bool , device = device )
257
+ task_completed = torch .zeros (
258
+ (batch_size , window_size , action_horizon ), dtype = torch .bool , device = device
259
+ )
258
260
259
261
# Create pad_mask_dict for observations
260
262
obs_pad_mask_dict = {
261
- ' image_primary' : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
262
- ' image_wrist' : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
263
- ' proprio' : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
264
- ' timestep' : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
263
+ " image_primary" : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
264
+ " image_wrist" : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
265
+ " proprio" : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
266
+ " timestep" : torch .ones ((batch_size , window_size ), dtype = torch .bool , device = device ),
265
267
}
266
268
267
269
observations = {
268
- ' image_primary' : image_primary ,
269
- ' image_wrist' : image_wrist ,
270
- ' proprio' : proprio ,
271
- ' timestep' : timestep ,
272
- ' timestep_pad_mask' : timestep_pad_mask ,
273
- ' task_completed' : task_completed ,
274
- ' pad_mask_dict' : obs_pad_mask_dict
270
+ " image_primary" : image_primary ,
271
+ " image_wrist" : image_wrist ,
272
+ " proprio" : proprio ,
273
+ " timestep" : timestep ,
274
+ " timestep_pad_mask" : timestep_pad_mask ,
275
+ " task_completed" : task_completed ,
276
+ " pad_mask_dict" : obs_pad_mask_dict ,
275
277
}
276
278
277
279
language_instruction = self .text_processor .encode (raw_tasks )
278
280
language_instruction = {k : v .to (device ) for k , v in language_instruction .items ()}
279
281
280
282
tasks = {
281
- ' language_instruction' : language_instruction ,
282
- ' pad_mask_dict' : {
283
- ' language_instruction' : torch .ones (batch_size , dtype = torch .bool , device = device )
284
- }
283
+ " language_instruction" : language_instruction ,
284
+ " pad_mask_dict" : {
285
+ " language_instruction" : torch .ones (batch_size , dtype = torch .bool , device = device )
286
+ },
285
287
}
286
288
287
289
# Handle actions only if they're present (during training)
@@ -295,10 +297,10 @@ def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[
295
297
# actions to be the target for the diffusion model.
296
298
# raw_actions has shape [batch_size, num_timestamps, action_dim]
297
299
# We need shape [batch_size, window_size, action_horizon, action_dim]
298
-
300
+
299
301
# Select the first `action_horizon` actions from the sequence.
300
302
actions = raw_actions [:, :action_horizon ]
301
-
303
+
302
304
# Add the window_size dimension.
303
305
actions = actions .unsqueeze (1 )
304
306
@@ -326,8 +328,10 @@ def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[
326
328
# return batch
327
329
328
330
def create_tasks (
329
- self , goals : Optional [Dict [str , torch .Tensor ]] = None , texts : Optional [Sequence [str ]] = None ,
330
- device : Optional [torch .device ] = None
331
+ self ,
332
+ goals : Optional [Dict [str , torch .Tensor ]] = None ,
333
+ texts : Optional [Sequence [str ]] = None ,
334
+ device : Optional [torch .device ] = None ,
331
335
):
332
336
"""Creates tasks dict from goals and texts."""
333
337
assert goals is not None or texts is not None
@@ -348,9 +352,15 @@ def create_tasks(
348
352
else :
349
353
batch_size = len (texts )
350
354
# Create dummy goals if none are provided
351
- tasks .update ({"image_primary" : torch .zeros ((batch_size , 256 , 256 , 3 ), dtype = torch .uint8 , device = device )})
355
+ tasks .update (
356
+ {"image_primary" : torch .zeros ((batch_size , 256 , 256 , 3 ), dtype = torch .uint8 , device = device )}
357
+ )
352
358
tasks ["pad_mask_dict" ].update (
353
- {k : torch .zeros (batch_size , dtype = torch .bool , device = device ) for k in tasks .keys () if k != "pad_mask_dict" }
359
+ {
360
+ k : torch .zeros (batch_size , dtype = torch .bool , device = device )
361
+ for k in tasks .keys ()
362
+ if k != "pad_mask_dict"
363
+ }
354
364
)
355
365
356
366
if texts is not None :
@@ -359,14 +369,18 @@ def create_tasks(
359
369
# Move to the correct device
360
370
encoded = {k : v .to (device ) for k , v in encoded .items ()}
361
371
tasks ["language_instruction" ] = encoded
362
- tasks ["pad_mask_dict" ]["language_instruction" ] = torch .ones (len (texts ), dtype = torch .bool , device = device )
372
+ tasks ["pad_mask_dict" ]["language_instruction" ] = torch .ones (
373
+ len (texts ), dtype = torch .bool , device = device
374
+ )
363
375
else :
364
376
batch_size = next (iter (goals .values ())).shape [0 ]
365
377
dummy_texts = ["" ] * batch_size
366
378
encoded = self .text_processor .encode (dummy_texts )
367
379
encoded = {k : v .to (device ) for k , v in encoded .items ()}
368
380
tasks ["language_instruction" ] = encoded
369
- tasks ["pad_mask_dict" ]["language_instruction" ] = torch .zeros (batch_size , dtype = torch .bool , device = device )
381
+ tasks ["pad_mask_dict" ]["language_instruction" ] = torch .zeros (
382
+ batch_size , dtype = torch .bool , device = device
383
+ )
370
384
371
385
return tasks
372
386
@@ -401,10 +415,10 @@ def predict_action_chunk(self, batch: dict[str, Tensor], tasks: Optional[Sequenc
401
415
def select_action (self , batch : dict [str , Tensor ]) -> Tensor :
402
416
"""Select a single action given environment observations."""
403
417
self .eval ()
404
-
418
+
405
419
# First, populate queues with the original, simple batch
406
420
self ._queues = populate_queues (self ._queues , batch , exclude_keys = [ACTION ])
407
-
421
+
408
422
# Then, prepare the complex batch for the model
409
423
prepared_batch = self ._prepare_batch (batch )
410
424
@@ -489,7 +503,6 @@ def forward(
489
503
timestep_pad_mask : torch .Tensor ,
490
504
embodiment_action_dim : Optional [int ] = None ,
491
505
) -> torch .Tensor :
492
-
493
506
transformer_outputs = self .octo_transformer (observations , tasks , timestep_pad_mask )
494
507
actions = self .head .predict_action (
495
508
transformer_outputs = transformer_outputs , embodiment_action_dim = embodiment_action_dim
0 commit comments