
Commit 2e5908c

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ceec6f1 commit 2e5908c

5 files changed: 66 additions, 72 deletions
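All of the diffs below are mechanical typing-modernization fixes: PEP 585 builtin generics (`Dict` → `dict`, `Tuple` → `tuple`), PEP 604 unions (`Optional[X]` → `X | None`), and ABC imports such as `Sequence` and `Mapping` moved from `typing` to `collections.abc`. As a minimal sketch, not taken from the commit, here is the style the hooks migrate toward; it runs as-is on Python 3.10+:

```python
# Minimal sketch of the target annotation style (illustrative function,
# not from the commit): builtin generics, PEP 604 unions, collections.abc.
from collections.abc import Sequence


def first_or_default(items: Sequence[int], default: int | None = None) -> int | None:
    """Return the first item of `items`, or `default` when it is empty."""
    return items[0] if items else default


print(first_or_default([3, 1, 2]))  # 3
print(first_or_default([]))         # None
```

The old `typing` spellings still work, but they are deprecated aliases, which is why automated hooks rewrite them wholesale.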

src/lerobot/policies/octo/base.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -15,7 +15,9 @@
 from dataclasses import dataclass
 from enum import Enum
 from fnmatch import fnmatch
-from typing import Any, Dict, List, Mapping
+from typing import Any, Dict, List
+from collections.abc import Mapping
+
 
 import torch
 
@@ -38,7 +40,7 @@ class AttentionRule(Enum):
 }
 
 
-def find_match(pattern_dict: Dict[str, Any], name: str, default: Any) -> Any:
+def find_match(pattern_dict: dict[str, Any], name: str, default: Any) -> Any:
     """Find the first matching pattern in the dictionary, or return the default value."""
     for pattern, value in pattern_dict.items():
         if fnmatch(name, pattern):
@@ -61,7 +63,7 @@ def __post_init__(self):
         )
 
     @classmethod
-    def concatenate(cls, group_list: List["TokenGroup"], axis: int = -2) -> "TokenGroup":
+    def concatenate(cls, group_list: list["TokenGroup"], axis: int = -2) -> "TokenGroup":
         """Concatenates a list of TokenGroups along a specified axis."""
         if not group_list:
             raise ValueError("Cannot concatenate an empty list of TokenGroups")
```
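For orientation, the `find_match` touched above is a small glob-based lookup. A runnable reconstruction from the visible hunk; the two `return` statements are inferred from the docstring rather than shown in the diff, and the `rules` dict is purely illustrative:

```python
from fnmatch import fnmatch
from typing import Any


def find_match(pattern_dict: dict[str, Any], name: str, default: Any) -> Any:
    """Find the first matching pattern in the dictionary, or return the default value."""
    for pattern, value in pattern_dict.items():
        if fnmatch(name, pattern):
            return value  # inferred from the docstring; not shown in the hunk
    return default  # likewise inferred


rules = {"obs_*": "attend", "task_*": "mask"}  # illustrative patterns
print(find_match(rules, "obs_image_primary", "ignore"))  # attend
print(find_match(rules, "readout_action", "ignore"))     # ignore
```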

src/lerobot/policies/octo/diffusion.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -82,7 +82,7 @@ class MLPResNetBlock(nn.Module):
     """Implementation of MLPResNetBlock."""
 
     def __init__(
-        self, features: int, activation, dropout_rate: Optional[float] = None, use_layer_norm: bool = False
+        self, features: int, activation, dropout_rate: float | None = None, use_layer_norm: bool = False
     ):
         super().__init__()
         self.features = features
@@ -173,7 +173,7 @@ def __init__(
         num_blocks: int,
         out_dim: int,
         in_dim: int,
-        dropout_rate: Optional[float] = None,
+        dropout_rate: float | None = None,
         use_layer_norm: bool = False,
         hidden_dim: int = 256,
         activation=nn.SiLU,
@@ -280,9 +280,9 @@ def _cosine_beta_schedule(self, timesteps: int, s: float = 0.008) -> torch.Tensor:
 
     def forward(
         self,
-        transformer_outputs: Dict[str, TokenGroup],
-        time: Optional[torch.Tensor] = None,
-        noisy_actions: Optional[torch.Tensor] = None,
+        transformer_outputs: dict[str, TokenGroup],
+        time: torch.Tensor | None = None,
+        noisy_actions: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Performs a single forward pass through the diffusion model."""
 
@@ -336,8 +336,8 @@ def loss(self, transformer_outputs, actions, timestep_pad_mask, action_pad_mask):
 
     def predict_action(
         self,
-        transformer_outputs: Dict[str, TokenGroup],
-        embodiment_action_dim: Optional[int] = None,
+        transformer_outputs: dict[str, TokenGroup],
+        embodiment_action_dim: int | None = None,
         sample_shape: tuple = (),
     ) -> torch.Tensor:
         """Convenience method for predicting actions for the final timestep."""
```

src/lerobot/policies/octo/modeling_octo.py

Lines changed: 10 additions & 9 deletions
```diff
@@ -39,7 +39,8 @@
 """
 
 from collections import deque
-from typing import Dict, Optional, Sequence
+from typing import Dict, Optional
+from collections.abc import Sequence
 
 import torch
 import torch.nn as nn
@@ -216,7 +217,7 @@ def from_pretrained(cls, *args, **kwargs):
         return super().from_pretrained(*args, **kwargs)
 
     def _prepare_batch(
-        self, batch: dict[str, Tensor], raw_tasks: Optional[Sequence[str]] = None
+        self, batch: dict[str, Tensor], raw_tasks: Sequence[str] | None = None
     ) -> dict[str, Tensor]:
         """
         Prepare batch for model input.
@@ -330,9 +331,9 @@ def _prepare_batch(
 
     def create_tasks(
         self,
-        goals: Optional[Dict[str, torch.Tensor]] = None,
-        texts: Optional[Sequence[str]] = None,
-        device: Optional[torch.device] = None,
+        goals: dict[str, torch.Tensor] | None = None,
+        texts: Sequence[str] | None = None,
+        device: torch.device | None = None,
     ):
         """Creates tasks dict from goals and texts."""
         assert goals is not None or texts is not None
@@ -401,7 +402,7 @@ def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         return actions
 
     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor], tasks: Optional[Sequence[str]] = None) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], tasks: Sequence[str] | None = None) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()
 
@@ -499,10 +500,10 @@ def __init__(self, config: OctoConfig):
 
     def forward(
         self,
-        observations: Dict[str, torch.Tensor],
-        tasks: Dict[str, torch.Tensor],
+        observations: dict[str, torch.Tensor],
+        tasks: dict[str, torch.Tensor],
         timestep_pad_mask: torch.Tensor,
-        embodiment_action_dim: Optional[int] = None,
+        embodiment_action_dim: int | None = None,
     ) -> torch.Tensor:
         transformer_outputs = self.octo_transformer(observations, tasks, timestep_pad_mask)
         actions = self.head.predict_action(
```
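Here `Sequence` moves from `typing` to `collections.abc`; the ABC generics have been directly usable in annotations since Python 3.9, and they combine with PEP 604 unions exactly as in `predict_action_chunk` above. A small sketch with illustrative task strings:

```python
from collections.abc import Sequence


def describe_tasks(tasks: Sequence[str] | None = None) -> str:
    # Mirrors the optional-sequence parameter pattern used above.
    if tasks is None:
        return "no tasks"
    return ", ".join(tasks)


print(describe_tasks(["pick up the cube", "close the drawer"]))
print(describe_tasks())  # no tasks
```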

src/lerobot/policies/octo/tokenizers.py

Lines changed: 17 additions & 16 deletions
```diff
@@ -15,7 +15,8 @@
 # limitations under the License.
 
 import re
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Tuple
+from collections.abc import Sequence
 
 import numpy as np
 import torch
@@ -112,10 +113,10 @@ def __init__(
         self,
         use_film: bool = False,
         patch_size: int = 32,
-        kernel_sizes: Tuple[int, ...] = (3, 3, 3, 3),
-        strides: Tuple[int, ...] = (2, 2, 2, 2),
-        features: Tuple[int, ...] = (32, 96, 192, 384),
-        padding: Tuple[int, ...] = (1, 1, 1, 1),
+        kernel_sizes: tuple[int, ...] = (3, 3, 3, 3),
+        strides: tuple[int, ...] = (2, 2, 2, 2),
+        features: tuple[int, ...] = (32, 96, 192, 384),
+        padding: tuple[int, ...] = (1, 1, 1, 1),
         num_features: int = 512,
         img_norm_type: str = "default",
     ):
@@ -167,7 +168,7 @@ def __init__(
         self.film = FilmConditioning() if use_film else None
 
     def forward(
-        self, observations: torch.Tensor, train: bool = True, cond_var: Optional[torch.Tensor] = None
+        self, observations: torch.Tensor, train: bool = True, cond_var: torch.Tensor | None = None
     ):
         """
         Args:
@@ -212,10 +213,10 @@ class SmallStem16(SmallStem):
     def __init__(
         self,
         use_film: bool = False,
-        kernel_sizes: Tuple[int, ...] = (3, 3, 3, 3),
-        strides: Tuple[int, ...] = (2, 2, 2, 2),
-        features: Tuple[int, ...] = (32, 96, 192, 384),
-        padding: Tuple[int, ...] = (1, 1, 1, 1),
+        kernel_sizes: tuple[int, ...] = (3, 3, 3, 3),
+        strides: tuple[int, ...] = (2, 2, 2, 2),
+        features: tuple[int, ...] = (32, 96, 192, 384),
+        padding: tuple[int, ...] = (1, 1, 1, 1),
         num_features: int = 512,
         img_norm_type: str = "default",
     ):
@@ -243,7 +244,7 @@ def regex_filter(regex_keys, xs):
 
 def generate_proper_pad_mask(
     tokens: torch.Tensor,
-    pad_mask_dict: Optional[Dict[str, torch.Tensor]],
+    pad_mask_dict: dict[str, torch.Tensor] | None,
    keys: Sequence[str],
 ) -> torch.Tensor:
     """Generate proper padding mask for tokens."""
@@ -286,8 +287,8 @@ def __init__(
 
     def forward(
         self,
-        observations: Dict[str, torch.Tensor],
-        tasks: Optional[Dict[str, torch.Tensor]] = None,
+        observations: dict[str, torch.Tensor],
+        tasks: dict[str, torch.Tensor] | None = None,
     ):
         """Forward pass through image tokenizer."""
 
@@ -382,7 +383,7 @@ def __init__(self, finetune_encoder: bool = False, proper_pad_mask: bool = True):
         for param in self.t5_encoder.parameters():
             param.requires_grad = False
 
-    def forward(self, language_input: Dict[str, torch.Tensor], tasks=None) -> TokenGroup:
+    def forward(self, language_input: dict[str, torch.Tensor], tasks=None) -> TokenGroup:
         outputs = self.t5_encoder(
             input_ids=language_input["input_ids"], attention_mask=language_input["attention_mask"]
         )
@@ -411,7 +412,7 @@ def forward(self, language_input: Dict[str, torch.Tensor], tasks=None) -> TokenGroup:
 class TextProcessor:
     """HF Tokenizer wrapper."""
 
-    def __init__(self, tokenizer_name: str = "t5-base", tokenizer_kwargs: Optional[Dict] = None):
+    def __init__(self, tokenizer_name: str = "t5-base", tokenizer_kwargs: dict | None = None):
         if tokenizer_kwargs is None:
             tokenizer_kwargs = {
                 "max_length": 16,
@@ -423,6 +424,6 @@ def __init__(self, tokenizer_name: str = "t5-base", tokenizer_kwargs: Optional[Dict] = None):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         self.tokenizer_kwargs = tokenizer_kwargs
 
-    def encode(self, strings: List[str]) -> Dict[str, torch.Tensor]:
+    def encode(self, strings: list[str]) -> dict[str, torch.Tensor]:
         """Encode strings to token IDs and attention masks."""
         return self.tokenizer(strings, **self.tokenizer_kwargs)
```
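`TextProcessor.encode` simply forwards the stored kwargs to the HF tokenizer. A hedged usage sketch: only `"max_length": 16` is visible before the hunk truncates, so the `padding`, `truncation`, and `return_tensors` entries are assumptions, and `transformers` (plus `sentencepiece` for t5-base) must be installed:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
tokenizer_kwargs = {
    "max_length": 16,          # visible in the diff
    "padding": "max_length",   # assumed; the hunk is truncated here
    "truncation": True,        # assumed
    "return_tensors": "pt",    # assumed
}

# Roughly what TextProcessor.encode(["pick up the red block"]) would do.
encoded = tokenizer(["pick up the red block"], **tokenizer_kwargs)
print(encoded["input_ids"].shape)       # torch.Size([1, 16])
print(encoded["attention_mask"].shape)  # torch.Size([1, 16])
```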
