Skip to content

Commit 7067e2d

Browse files
committed
Refactor and clean up unused code
1 parent ed3f1fb commit 7067e2d

File tree

7 files changed

+90
-77
lines changed

7 files changed

+90
-77
lines changed

src/lerobot/policies/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
from .act.configuration_act import ACTConfig as ACTConfig
1616
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
17+
from .octo.configuration_octo import OctoConfig as OctoConfig
1718
from .pi0.configuration_pi0 import PI0Config as PI0Config
1819
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
1920
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
2021
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
21-
from .octo.configuration_octo import OctoConfig as OctoConfig

src/lerobot/policies/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
from lerobot.envs.utils import env_to_policy_features
2727
from lerobot.policies.act.configuration_act import ACTConfig
2828
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
29+
from lerobot.policies.octo.configuration_octo import OctoConfig
2930
from lerobot.policies.pi0.configuration_pi0 import PI0Config
3031
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
3132
from lerobot.policies.pretrained import PreTrainedPolicy
3233
from lerobot.policies.sac.configuration_sac import SACConfig
3334
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
3435
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
35-
from lerobot.policies.octo.configuration_octo import OctoConfig
3636
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
3737
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
3838

src/lerobot/policies/octo/configuration_octo.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from dataclasses import dataclass, field
1616

1717
from lerobot.configs.policies import PreTrainedConfig
18-
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
18+
from lerobot.configs.types import NormalizationMode
1919
from lerobot.optim.optimizers import AdamWConfig
2020
from lerobot.optim.schedulers import (
2121
CosineDecayWithWarmupSchedulerConfig,
@@ -31,12 +31,12 @@ class OctoConfig(PreTrainedConfig):
3131
num_layers: int = 12
3232
num_heads: int = 12
3333
mlp_dim: int = 3072
34-
34+
3535
# Input / output structure
3636
n_obs_steps: int = 1
3737
chunk_size: int = 10 # max_horizon in octo
3838
n_action_steps: int = 4 # action_horizon in octo
39-
39+
4040
# Normalization
4141
normalization_mapping: dict[str, NormalizationMode] = field(
4242
default_factory=lambda: {
@@ -47,22 +47,22 @@ class OctoConfig(PreTrainedConfig):
4747
)
4848

4949
push_to_hub: bool = False
50-
50+
5151
# Image preprocessing
5252
resize_primary_image: tuple[int, int] = (256, 256)
5353
resize_wrist_image: tuple[int, int] = (128, 128)
54-
54+
5555
# Language model
5656
language_model_name: str = "t5-base"
5757
language_max_length: int = 16
5858
freeze_language_encoder: bool = True
59-
59+
6060
# Transformer settings
6161
repeat_task_tokens: bool = True
6262
dropout_rate: float = 0.0
6363
attention_dropout_rate: float = 0.0
6464
add_position_embedding: bool = False
65-
65+
6666
# Diffusion settings
6767
diffusion_steps: int = 20
6868
n_diffusion_samples: int = 1
@@ -73,26 +73,26 @@ class OctoConfig(PreTrainedConfig):
7373
num_blocks: int = 3
7474
hidden_dim: int = 256
7575
use_layer_norm: bool = True
76-
76+
7777
# Finetuning settings
7878
freeze_transformer: bool = False
7979
freeze_vision_encoder: bool = True
8080
train_action_head_only: bool = False
81-
81+
8282
# Training presets
8383
optimizer_lr: float = 1e-4
8484
optimizer_betas: tuple[float, float] = (0.9, 0.999)
8585
optimizer_eps: float = 1e-8
8686
optimizer_weight_decay: float = 1e-4
8787
optimizer_grad_clip_norm: float = 10.0
88-
88+
8989
scheduler_warmup_steps: int = 1_000
9090
scheduler_decay_steps: int = 100_000
9191
scheduler_decay_lr: float = 1e-5
9292

9393
def __post_init__(self):
9494
super().__post_init__()
95-
95+
9696
# Set architecture parameters based on model_name
9797
if self.model_name == "octo-base":
9898
self.token_embedding_size = 768
@@ -106,7 +106,7 @@ def __post_init__(self):
106106
self.mlp_dim = 1536
107107
else:
108108
raise ValueError(f"Unknown model name: {self.model_name}")
109-
109+
110110
# Input validation
111111
if self.n_action_steps > self.chunk_size:
112112
raise ValueError(

src/lerobot/policies/octo/diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import torch
2121
import torch.nn as nn
22-
2322
from einops import rearrange
2423

2524
from lerobot.policies.octo.base import TokenGroup

src/lerobot/policies/octo/modeling_octo.py

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,20 @@
4343

4444
import torch
4545
import torch.nn as nn
46-
4746
from torch import Tensor
4847

49-
from lerobot.constants import ACTION, OBS_STATE
48+
from lerobot.constants import ACTION
5049
from lerobot.policies.normalize import Normalize, Unnormalize
5150
from lerobot.policies.octo.configuration_octo import OctoConfig
52-
53-
from lerobot.policies.pretrained import PreTrainedPolicy
54-
from lerobot.policies.utils import populate_queues, log_model_loading_keys
55-
51+
from lerobot.policies.octo.diffusion import DiffusionActionHead
5652
from lerobot.policies.octo.tokenizers import TextProcessor
5753
from lerobot.policies.octo.transformer import OctoWithoutHead
58-
from lerobot.policies.octo.diffusion import DiffusionActionHead
59-
54+
from lerobot.policies.pretrained import PreTrainedPolicy
55+
from lerobot.policies.utils import log_model_loading_keys, populate_queues
6056

6157
# TODO(lilkm): Be aware of normalization the image tokenizer (normalize_images function)
6258

59+
6360
class OctoPolicy(PreTrainedPolicy):
6461
"""Wrapper class around Octo model to train and run inference within LeRobot."""
6562

@@ -110,16 +107,16 @@ def reset(self):
110107
self._queues = {
111108
ACTION: deque(maxlen=self.config.n_action_steps),
112109
}
113-
110+
114111
def _apply_selective_freezing(self):
115112
"""Apply selective freezing based on configuration settings."""
116-
if hasattr(self.model.octo_transformer, 'task_tokenizers'):
113+
if hasattr(self.model.octo_transformer, "task_tokenizers"):
117114
for name, tokenizer in self.model.octo_transformer.task_tokenizers.items():
118-
if name == 'language_instruction':
115+
if name == "language_instruction":
119116
for param in tokenizer.parameters():
120117
param.requires_grad = False
121-
print(f"✓ T5 language encoder frozen (always frozen during finetuning)")
122-
118+
print("✓ T5 language encoder frozen (always frozen during finetuning)")
119+
123120
# If train_action_head_only is True, freeze everything except the action head
124121
if self.config.train_action_head_only:
125122
# Freeze transformer
@@ -133,10 +130,10 @@ def _apply_selective_freezing(self):
133130
if self.config.freeze_transformer:
134131
for param in self.model.octo_transformer.parameters():
135132
param.requires_grad = False
136-
133+
137134
if self.config.freeze_vision_encoder:
138135
# Freeze vision encoder components in the transformer
139-
if hasattr(self.model.octo_transformer, 'observation_tokenizers'):
136+
if hasattr(self.model.octo_transformer, "observation_tokenizers"):
140137
for tokenizer in self.model.octo_transformer.observation_tokenizers.values():
141138
for param in tokenizer.parameters():
142139
param.requires_grad = False
@@ -170,18 +167,18 @@ def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
170167
# 1. Replace "action_head." with "head."
171168
if "action_head." in new_key:
172169
new_key = new_key.replace("action_head.", "head.")
173-
170+
174171
# 2. Adjust the transformer nesting to match the LeRobot model.
175172
# The checkpoint has `transformer.transformer` but LeRobot expects
176173
# `transformer.transformer.transformer`.
177174
if "octo_transformer.transformer.transformer." in new_key:
178-
new_key = new_key.replace(
179-
"octo_transformer.transformer.transformer.",
180-
"octo_transformer.transformer.transformer.transformer."
181-
)
175+
new_key = new_key.replace(
176+
"octo_transformer.transformer.transformer.",
177+
"octo_transformer.transformer.transformer.transformer.",
178+
)
182179

183180
transformed_dict[new_key] = value
184-
181+
185182
return transformed_dict
186183

187184
@classmethod
@@ -190,8 +187,9 @@ def _load_as_safetensor(
190187
) -> "OctoPolicy":
191188
"""Override to apply key transformations before loading."""
192189
from safetensors.torch import load_file
190+
193191
from lerobot.utils.utils import init_logging
194-
192+
195193
init_logging()
196194
# Load the state dict from file safely
197195
state_dict = load_file(model_file, device=map_location)
@@ -205,7 +203,7 @@ def _load_as_safetensor(
205203
# Log message
206204
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
207205
return model
208-
206+
209207
@classmethod
210208
def from_pretrained(cls, *args, **kwargs):
211209
"""Override the from_pretrained method to display important information."""
@@ -217,15 +215,17 @@ def from_pretrained(cls, *args, **kwargs):
217215
)
218216
return super().from_pretrained(*args, **kwargs)
219217

220-
def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[str]] = None) -> dict[str, Tensor]:
218+
def _prepare_batch(
219+
self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[str]] = None
220+
) -> dict[str, Tensor]:
221221
"""
222222
Prepare batch for model input.
223223
Transforms a batch from the LeRobotDataset format to the format expected by the OctoModel.
224224
"""
225225
batch = self.normalize_inputs(batch)
226226
# Get device from any available tensor in the batch
227227
device = next(iter(batch.values())).device
228-
228+
229229
image_primary = batch["observation.images.front"].to(device)
230230
image_wrist = batch["observation.images.wrist"].to(device)
231231
proprio = batch["observation.state"].to(device)
@@ -254,34 +254,36 @@ def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[
254254
# Create timestep_pad_mask - all True since we have real data (no padding)
255255
timestep_pad_mask = torch.ones((batch_size, window_size), dtype=torch.bool, device=device)
256256

257-
task_completed = torch.zeros((batch_size, window_size, action_horizon), dtype=torch.bool, device=device)
257+
task_completed = torch.zeros(
258+
(batch_size, window_size, action_horizon), dtype=torch.bool, device=device
259+
)
258260

259261
# Create pad_mask_dict for observations
260262
obs_pad_mask_dict = {
261-
'image_primary': torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
262-
'image_wrist': torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
263-
'proprio': torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
264-
'timestep': torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
263+
"image_primary": torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
264+
"image_wrist": torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
265+
"proprio": torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
266+
"timestep": torch.ones((batch_size, window_size), dtype=torch.bool, device=device),
265267
}
266268

267269
observations = {
268-
'image_primary': image_primary,
269-
'image_wrist': image_wrist,
270-
'proprio': proprio,
271-
'timestep': timestep,
272-
'timestep_pad_mask': timestep_pad_mask,
273-
'task_completed': task_completed,
274-
'pad_mask_dict': obs_pad_mask_dict
270+
"image_primary": image_primary,
271+
"image_wrist": image_wrist,
272+
"proprio": proprio,
273+
"timestep": timestep,
274+
"timestep_pad_mask": timestep_pad_mask,
275+
"task_completed": task_completed,
276+
"pad_mask_dict": obs_pad_mask_dict,
275277
}
276278

277279
language_instruction = self.text_processor.encode(raw_tasks)
278280
language_instruction = {k: v.to(device) for k, v in language_instruction.items()}
279281

280282
tasks = {
281-
'language_instruction': language_instruction,
282-
'pad_mask_dict': {
283-
'language_instruction': torch.ones(batch_size, dtype=torch.bool, device=device)
284-
}
283+
"language_instruction": language_instruction,
284+
"pad_mask_dict": {
285+
"language_instruction": torch.ones(batch_size, dtype=torch.bool, device=device)
286+
},
285287
}
286288

287289
# Handle actions only if they're present (during training)
@@ -295,10 +297,10 @@ def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[
295297
# actions to be the target for the diffusion model.
296298
# raw_actions has shape [batch_size, num_timestamps, action_dim]
297299
# We need shape [batch_size, window_size, action_horizon, action_dim]
298-
300+
299301
# Select the first `action_horizon` actions from the sequence.
300302
actions = raw_actions[:, :action_horizon]
301-
303+
302304
# Add the window_size dimension.
303305
actions = actions.unsqueeze(1)
304306

@@ -326,8 +328,10 @@ def _prepare_batch(self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[
326328
# return batch
327329

328330
def create_tasks(
329-
self, goals: Optional[Dict[str, torch.Tensor]] = None, texts: Optional[Sequence[str]] = None,
330-
device: Optional[torch.device] = None
331+
self,
332+
goals: Optional[Dict[str, torch.Tensor]] = None,
333+
texts: Optional[Sequence[str]] = None,
334+
device: Optional[torch.device] = None,
331335
):
332336
"""Creates tasks dict from goals and texts."""
333337
assert goals is not None or texts is not None
@@ -348,9 +352,15 @@ def create_tasks(
348352
else:
349353
batch_size = len(texts)
350354
# Create dummy goals if none are provided
351-
tasks.update({"image_primary": torch.zeros((batch_size, 256, 256, 3), dtype=torch.uint8, device=device)})
355+
tasks.update(
356+
{"image_primary": torch.zeros((batch_size, 256, 256, 3), dtype=torch.uint8, device=device)}
357+
)
352358
tasks["pad_mask_dict"].update(
353-
{k: torch.zeros(batch_size, dtype=torch.bool, device=device) for k in tasks.keys() if k != "pad_mask_dict"}
359+
{
360+
k: torch.zeros(batch_size, dtype=torch.bool, device=device)
361+
for k in tasks.keys()
362+
if k != "pad_mask_dict"
363+
}
354364
)
355365

356366
if texts is not None:
@@ -359,14 +369,18 @@ def create_tasks(
359369
# Move to the correct device
360370
encoded = {k: v.to(device) for k, v in encoded.items()}
361371
tasks["language_instruction"] = encoded
362-
tasks["pad_mask_dict"]["language_instruction"] = torch.ones(len(texts), dtype=torch.bool, device=device)
372+
tasks["pad_mask_dict"]["language_instruction"] = torch.ones(
373+
len(texts), dtype=torch.bool, device=device
374+
)
363375
else:
364376
batch_size = next(iter(goals.values())).shape[0]
365377
dummy_texts = [""] * batch_size
366378
encoded = self.text_processor.encode(dummy_texts)
367379
encoded = {k: v.to(device) for k, v in encoded.items()}
368380
tasks["language_instruction"] = encoded
369-
tasks["pad_mask_dict"]["language_instruction"] = torch.zeros(batch_size, dtype=torch.bool, device=device)
381+
tasks["pad_mask_dict"]["language_instruction"] = torch.zeros(
382+
batch_size, dtype=torch.bool, device=device
383+
)
370384

371385
return tasks
372386

@@ -401,10 +415,10 @@ def predict_action_chunk(self, batch: dict[str, Tensor], tasks: Optional[Sequenc
401415
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
402416
"""Select a single action given environment observations."""
403417
self.eval()
404-
418+
405419
# First, populate queues with the original, simple batch
406420
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
407-
421+
408422
# Then, prepare the complex batch for the model
409423
prepared_batch = self._prepare_batch(batch)
410424

@@ -489,7 +503,6 @@ def forward(
489503
timestep_pad_mask: torch.Tensor,
490504
embodiment_action_dim: Optional[int] = None,
491505
) -> torch.Tensor:
492-
493506
transformer_outputs = self.octo_transformer(observations, tasks, timestep_pad_mask)
494507
actions = self.head.predict_action(
495508
transformer_outputs=transformer_outputs, embodiment_action_dim=embodiment_action_dim

0 commit comments

Comments
 (0)