Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,20 @@ repos:
args: ["-c", "pyproject.toml"]
additional_dependencies: ["bandit[toml]"]

# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.16.0
# hooks:
# - id: mypy
# args: [--python-version=3.10]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.0
hooks:
- id: mypy
# Run MyPy in phases: limit to the explicitly-configured files in pyproject.toml
# and ignore missing imports in the hook venv (heavy libs like torch aren’t installed here).
pass_filenames: false
always_run: true
args:
- --python-version=3.10
- --config-file=pyproject.toml
- --install-types
- --non-interactive

##### Docstring Checks #####
# - repo: https://github.com/akaihola/darglint2
Expand Down
60 changes: 55 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,58 @@ default.extend-ignore-identifiers-re = [
# color = true
# paths = ["src/lerobot"]

# [tool.mypy]
# python_version = "3.10"
# warn_return_any = true
# warn_unused_configs = true
# ignore_missing_imports = false
[tool.mypy]
# Minimum viable mypy config. We start with a
# narrow scope and will expand module-by-module.
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = false
follow_imports = "skip"

# Ensure package discovery from the src layout
mypy_path = ["src"]
namespace_packages = true
explicit_package_bases = true

# Start small to land CI enablement; extend gradually.
# See issue #1719 and module-specific tasks #1720–#1732.
files = [
"src/lerobot/__init__.py",
"src/lerobot/constants.py",
"src/lerobot/errors.py",
# Start policies with a leaf utility module to keep scope small
"src/lerobot/policies/utils.py",
"src/lerobot/policies/factory.py",
"src/lerobot/policies/pretrained.py",
"src/lerobot/configs/policies.py",
"src/lerobot/policies/tdmpc/configuration_tdmpc.py",
"src/lerobot/policies/act/configuration_act.py",
"src/lerobot/policies/diffusion/configuration_diffusion.py",
"src/lerobot/policies/sac/configuration_sac.py",
"src/lerobot/policies/pi0/configuration_pi0.py",
"src/lerobot/policies/pi0fast/configuration_pi0fast.py",
"src/lerobot/policies/smolvla/configuration_smolvla.py",
"src/lerobot/policies/vqbet/configuration_vqbet.py",
# Modeling files will be added progressively as they are cleaned up
"src/lerobot/policies/sac/modeling_sac.py",
"src/lerobot/policies/pi0fast/modeling_pi0fast.py",
"src/lerobot/policies/act/modeling_act.py",
"src/lerobot/policies/diffusion/modeling_diffusion.py",
"src/lerobot/policies/tdmpc/modeling_tdmpc.py",
"src/lerobot/policies/vqbet/modeling_vqbet.py",
"src/lerobot/policies/smolvla/modeling_smolvla.py",
]

[[tool.mypy.overrides]]
module = [
"draccus",
"datasets",
"datasets.*",
"jsonlines",
"torchvision",
"torchvision.*",
"scipy",
"scipy.*",
]
ignore_missing_imports = true
18 changes: 12 additions & 6 deletions src/lerobot/configs/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TypeVar
from typing import TypeVar, cast

import draccus
from huggingface_hub import hf_hub_download
Expand All @@ -30,14 +30,13 @@
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available

T = TypeVar("T", bound="PreTrainedConfig")


@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
class PreTrainedConfig(draccus.ChoiceRegistry, abc.ABC):
"""
Base configuration class for policy models.
Expand Down Expand Up @@ -75,6 +74,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):

def __post_init__(self):
self.pretrained_path = None
# Honor test/device override via environment for consistency in CI and local runs
if not self.device:
env_device = os.environ.get("LEROBOT_TEST_DEVICE")
if env_device:
self.device = env_device

if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
Expand All @@ -89,7 +94,7 @@ def __post_init__(self):

@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
return cast(str, self.get_choice_name(self.__class__))

@property
@abc.abstractmethod
Expand Down Expand Up @@ -153,7 +158,7 @@ def from_pretrained(
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
Expand Down Expand Up @@ -193,6 +198,7 @@ def from_pretrained(
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])

assert config_file is not None
with open(config_file) as f:
config = json.load(f)

Expand All @@ -204,4 +210,4 @@ def from_pretrained(

cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
return cast(T, draccus.parse(orig_config.__class__, config_file, args=cli_overrides))
7 changes: 5 additions & 2 deletions src/lerobot/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@

# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# Use explicit fallback instead of passing a Path as getenv default
_home_env = os.getenv("HF_LEROBOT_HOME")
HF_LEROBOT_HOME = (Path(_home_env) if _home_env else default_cache_path).expanduser()

# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
_calib_env = os.getenv("HF_LEROBOT_CALIBRATION")
HF_LEROBOT_CALIBRATION = (Path(_calib_env) if _calib_env else default_calibration_path).expanduser()
10 changes: 5 additions & 5 deletions src/lerobot/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class DeviceNotConnectedError(ConnectionError):
"""Exception raised when the device is not connected."""

def __init__(self, message="This device is not connected. Try calling `connect()` first."):
def __init__(self, message: str = "This device is not connected. Try calling `connect()` first.") -> None:
self.message = message
super().__init__(self.message)

Expand All @@ -26,8 +26,8 @@ class DeviceAlreadyConnectedError(ConnectionError):

def __init__(
self,
message="This device is already connected. Try not calling `connect()` twice.",
):
message: str = "This device is already connected. Try not calling `connect()` twice.",
) -> None:
self.message = message
super().__init__(self.message)

Expand All @@ -37,7 +37,7 @@ class InvalidActionError(ValueError):

def __init__(
self,
message="The action is invalid. Check the value follows what it is expected from the action space.",
):
message: str = "The action is invalid. Check the value follows what it is expected from the action space.",
) -> None:
self.message = message
super().__init__(self.message)
63 changes: 37 additions & 26 deletions src/lerobot/policies/act/modeling_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

import math
from collections import deque
from collections.abc import Callable
from collections.abc import Callable, Sequence
from itertools import chain
from typing import Any, TypedDict, cast

import einops
import numpy as np
Expand All @@ -39,6 +40,19 @@
from lerobot.policies.pretrained import PreTrainedPolicy


class OptimizerKwargs(TypedDict, total=False):
params: list[nn.Parameter]
lr: float
weight_decay: float
momentum: float
betas: tuple[float, float]
eps: float


class ForwardKwargs(TypedDict, total=False):
pass


class ACTPolicy(PreTrainedPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Expand Down Expand Up @@ -79,7 +93,7 @@ def __init__(

self.reset()

def get_optim_params(self) -> dict:
def get_optim_params(self) -> Sequence[OptimizerKwargs]:
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
# Should we remove this and just `return self.parameters()`?
return [
Expand Down Expand Up @@ -136,7 +150,6 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()

batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
Expand All @@ -146,7 +159,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions

def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
def forward(self, batch: dict[str, Tensor], **kwargs: object) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
Expand Down Expand Up @@ -313,21 +326,19 @@ def __init__(self, config: ACTConfig):
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.config.robot_state_feature:
self.vae_encoder_robot_state_input_proj = nn.Linear(
self.config.robot_state_feature.shape[0], config.dim_model
)
_robot_ft = self.config.robot_state_feature
if _robot_ft is not None:
self.vae_encoder_robot_state_input_proj = nn.Linear(_robot_ft.shape[0], config.dim_model)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
self.config.action_feature.shape[0],
config.dim_model,
)
_action_ft = self.config.action_feature
assert _action_ft is not None
self.vae_encoder_action_input_proj = nn.Linear(_action_ft.shape[0], config.dim_model)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.config.robot_state_feature:
if _robot_ft is not None:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
Expand All @@ -352,14 +363,12 @@ def __init__(self, config: ACTConfig):

# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
self.config.robot_state_feature.shape[0], config.dim_model
)
if self.config.env_state_feature:
self.encoder_env_state_input_proj = nn.Linear(
self.config.env_state_feature.shape[0], config.dim_model
)
_robot_ft = self.config.robot_state_feature
if _robot_ft is not None:
self.encoder_robot_state_input_proj = nn.Linear(_robot_ft.shape[0], config.dim_model)
_env_ft = self.config.env_state_feature
if _env_ft is not None:
self.encoder_env_state_input_proj = nn.Linear(_env_ft.shape[0], config.dim_model)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.config.image_features:
self.encoder_img_feat_input_proj = nn.Conv2d(
Expand All @@ -380,7 +389,9 @@ def __init__(self, config: ACTConfig):
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
_action_ft = self.config.action_feature
assert _action_ft is not None
self.action_head = nn.Linear(config.dim_model, _action_ft.shape[0])

self._reset_parameters()

Expand Down Expand Up @@ -755,12 +766,12 @@ def forward(self, x: Tensor) -> Tensor:
return pos_embed


def get_activation_fn(activation: str) -> Callable:
def get_activation_fn(activation: str) -> Callable[..., Any]:
"""Return an activation function given a string."""
if activation == "relu":
return F.relu
return cast(Callable[..., Any], F.relu)
if activation == "gelu":
return F.gelu
return cast(Callable[..., Any], F.gelu)
if activation == "glu":
return F.glu
return cast(Callable[..., Any], F.glu)
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
Loading