mypy: Make Policies module MyPy‑compliant (#1719, #1720) #1805
base: main
Conversation
- --config-file=pyproject.toml
- --install-types
- --non-interactive
- --ignore-missing-imports
`--ignore-missing-imports` is true here instead of false, so we should remove this flag, right?
Thx for the review. :) Yes, we should remove --ignore-missing-imports.
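For context, the same behavior can live in the config file instead of the CLI flag; a minimal sketch of the relevant pyproject.toml section (the repo's actual config may differ):

```toml
[tool.mypy]
ignore_missing_imports = true
```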
Hey thank you for your contribution, I started the review for Act policy 😄
I will do the next policy later 🤗
@@ -89,7 +94,7 @@ def __post_init__(self):

     @property
     def type(self) -> str:
-        return self.get_choice_name(self.__class__)
+        return cast(str, self.get_choice_name(self.__class__))
`cast` should always be justified in the code, since it is unsafe. In this case it should be ok.
@@ -204,4 +210,4 @@ def from_pretrained(

     cli_overrides = policy_kwargs.pop("cli_overrides", [])
     with draccus.config_type("json"):
-        return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
+        return cast(T, draccus.parse(orig_config.__class__, config_file, args=cli_overrides))
Ditto: we should always justify `cast`. Question: is this cast necessary?
Yes, I think so; this cast is necessary for MyPy compliance.
Without it, MyPy reports: Incompatible return value type (got "PreTrainedConfig", expected "T").
The issue is that while `orig_config.__class__` and `cls` are the same type at runtime, MyPy can't statically infer that `draccus.parse(orig_config.__class__, ...)` will return the specific generic type `T` rather than the base `PreTrainedConfig` type.
This is a use of `cast()` to help the type checker understand a type relationship that's guaranteed at runtime but not statically provable.
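The situation can be reproduced in miniature; a hedged sketch where `Base`, `Child`, and `parse` are hypothetical stand-ins for `PreTrainedConfig` and `draccus.parse`:

```python
from typing import TypeVar, cast


class Base:  # stand-in for PreTrainedConfig
    pass


class Child(Base):  # stand-in for a concrete config subclass
    pass


T = TypeVar("T", bound=Base)


def parse(cls: type[Base]) -> Base:
    # Stand-in for draccus.parse: annotated as returning the base type,
    # even though it instantiates whatever class it was given.
    return cls()


def clone_config(orig: T) -> T:
    # At runtime orig.__class__ is exactly type(orig), but parse() is
    # declared to return Base, so MyPy needs cast() to accept T here.
    return cast(T, parse(orig.__class__))
```

Running `clone_config(Child())` returns a `Child` instance, which is why the cast is safe at runtime even though it is not statically provable.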
@@ -79,7 +81,7 @@ def __init__(

         self.reset()

-    def get_optim_params(self) -> dict:
+    def get_optim_params(self) -> object:
Using object here feels too generic and makes the typing less helpful. A stricter approach would be:
Option A (inline type): `Sequence[dict[str, list[nn.Parameter] | float]]`
Option B (preferred, clearer): define a dedicated `TypedDict` for optimizer kwargs and then use:

from typing import TypedDict, NotRequired
from collections.abc import Sequence
from torch import nn

class OptimizerKwargs(TypedDict, total=False):
    params: list[nn.Parameter]
    lr: float
    weight_decay: float
    momentum: float
    betas: tuple[float, float]
    eps: float

BatchKwargs = Sequence[OptimizerKwargs]

Option B has the advantage of being more self-documenting and makes it easier to see what fields are expected in the optimizer kwargs.
You're right. I'll go with Option B (TypedDict) as you suggested; it makes it clear what the optimizer expects. I'll define the `OptimizerKwargs` TypedDict and use `Sequence[OptimizerKwargs]` as the return type.
Thanks for the detailed suggestion.
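At runtime the TypedDict return type is just a sequence of plain dicts; a small torch-free sketch (the parameter-group values here are hypothetical, and the real method would also carry a `params` key holding `nn.Parameter` lists):

```python
from collections.abc import Sequence
from typing import TypedDict


class OptimizerKwargs(TypedDict, total=False):
    # Trimmed to two fields for illustration; the full version would
    # also declare params, momentum, betas, eps, etc.
    lr: float
    weight_decay: float


def get_optim_params() -> Sequence[OptimizerKwargs]:
    # Hypothetical groups: a head with its own learning rate and a
    # backbone with weight decay.
    return [{"lr": 1e-3}, {"lr": 1e-4, "weight_decay": 0.01}]
```

MyPy will now flag a misspelled key or a non-float value in any of these dicts, which `-> object` could not do.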
# Ensure batch tensors are on the same device as the module
device = get_device_from_parameters(self)
batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}
This addition is not related to MyPy typing, is it?
Yes, you are right, sorry for that. This device-handling code is not related to MyPy typing; it was accidentally included from another change I was working on.
I'll remove this addition and keep this PR focused solely on type annotations and MyPy compliance.
Thanks for catching this.
@@ -146,8 +150,11 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
         return actions

-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
+    def forward(self, batch: dict[str, Tensor], **kwargs: object) -> tuple[Tensor, dict]:
`**kwargs` should be typed using `typing.TypedDict` and `typing_extensions.Unpack`.
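A hedged sketch of that pattern (the `return_attn` option is hypothetical; the real `ForwardKwargs` would mirror whatever callers actually pass, and on Python 3.11+ `Unpack` lives in `typing` itself):

```python
from typing import TypedDict

try:  # Unpack moved into typing in Python 3.11
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack


class ForwardKwargs(TypedDict, total=False):
    # Hypothetical extra option forwarded to the model.
    return_attn: bool


def forward(batch: dict[str, object], **kwargs: Unpack[ForwardKwargs]) -> bool:
    # MyPy now checks kwargs keys and value types at every call site,
    # instead of accepting arbitrary objects.
    return kwargs.get("return_attn", False)
```

With `**kwargs: object`, a call like `forward(batch, retrun_attn=True)` would type-check; with `Unpack[ForwardKwargs]`, MyPy rejects the misspelled key.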
# Ensure batch tensors are on the same device as the module
device = get_device_from_parameters(self)
batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}
ditto
 def get_activation_fn(activation: str) -> Callable[..., Any]:
     """Return an activation function given a string."""
     if activation == "relu":
-        return F.relu
+        return cast(Callable[..., Any], F.relu)
     if activation == "gelu":
-        return F.gelu
+        return cast(Callable[..., Any], F.gelu)
     if activation == "glu":
-        return F.glu
+        return cast(Callable[..., Any], F.glu)
Are the `cast`s necessary here?
These casts are unnecessary: `F.relu`, `F.gelu`, and `F.glu` are already properly typed as callables in PyTorch, so MyPy doesn't complain without them. I was being overly cautious here.
We can remove these casts and keep the function clean:
def get_activation_fn(activation: str) -> Callable[..., Any]:
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
Thanks for catching this unnecessary complexity. :)
What this does
device semantics.
imports in the hook venv to avoid heavy deps). No global behavior change.
How it was tested
backward_compatibility'
How to checkout & try? (for the reviewer)
# From repo root
export LEROBOT_TEST_DEVICE=cpu
mkdir -p .cache/hf_datasets .cache/hf_home
export HF_DATASETS_CACHE=$(pwd)/.cache/hf_datasets
export HF_HOME=$(pwd)/.cache/hf_home

# Static checks
pre-commit run -a
mypy

# SAC tests
PYTHONPATH=src pytest -q tests/policies -k sac

# ACT back-compat tests
PYTHONPATH=src pytest -q tests/policies -k 'act and backward_compatibility'
If you’d like me to extend MyPy coverage to the next module (e.g., Datasets) in a follow‑up PR, I can prepare the scoped config and initial fixes similarly.
cc @CadeRemi @AdilZouitine for review