Skip to content

Commit 121ad28

Browse files
committed
feat(pipeline): Enhance configuration filename handling and state file naming
- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection of configuration files when loading from a directory, with an error raised if multiple files are found.
- Updated tests to validate the new features, including custom filenames and automatic config detection.
1 parent cc7cf72 commit 121ad28

File tree

3 files changed

+535
-16
lines changed

3 files changed

+535
-16
lines changed

src/lerobot/processor/pipeline.py

Lines changed: 80 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626

2727
import torch
2828
from huggingface_hub import ModelHubMixin, hf_hub_download
29+
from huggingface_hub.errors import HfHubHTTPError
2930
from safetensors.torch import load_file, save_file
3031

3132
from lerobot.utils.utils import get_safe_torch_device
@@ -293,8 +294,6 @@ class RobotProcessor(ModelHubMixin):
293294
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
294295
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
295296

296-
_CFG_NAME = "processor.json"
297-
298297
def __call__(self, data: EnvTransition | dict[str, Any]):
299298
"""Process data through all steps.
300299
@@ -386,7 +385,9 @@ def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTran
386385

387386
def _save_pretrained(self, destination_path: str, **kwargs):
388387
"""Internal save method for ModelHubMixin compatibility."""
389-
self.save_pretrained(destination_path)
388+
# Extract config_filename from kwargs if provided
389+
config_filename = kwargs.pop("config_filename", None)
390+
self.save_pretrained(destination_path, config_filename=config_filename)
390391

391392
def _generate_model_card(self, destination_path: str) -> None:
392393
"""Generate README.md from the RobotProcessor model card template."""
@@ -405,10 +406,24 @@ def _generate_model_card(self, destination_path: str) -> None:
405406
with open(readme_path, "w") as f:
406407
f.write(model_card_content)
407408

408-
def save_pretrained(self, destination_path: str, **kwargs):
409-
"""Serialize the processor definition and parameters to *destination_path*."""
409+
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
410+
"""Serialize the processor definition and parameters to *destination_path*.
411+
412+
Args:
413+
destination_path: Directory where the processor will be saved.
414+
config_filename: Optional custom config filename. If not provided, defaults to
415+
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
416+
"""
410417
os.makedirs(destination_path, exist_ok=True)
411418

419+
# Determine config filename - sanitize the processor name for filesystem
420+
if config_filename is None:
421+
# Sanitize name - replace any character that's not alphanumeric or underscore
422+
import re
423+
424+
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
425+
config_filename = f"{sanitized_name}.json"
426+
412427
config: dict[str, Any] = {
413428
"name": self.name,
414429
"seed": self.seed,
@@ -448,9 +463,10 @@ def save_pretrained(self, destination_path: str, **kwargs):
448463
for key, tensor in state.items():
449464
cloned_state[key] = tensor.clone()
450465

451-
# Use registry name for more meaningful filenames when available
466+
# Always include step index to ensure unique filenames
467+
# This prevents conflicts when the same processor type is used multiple times
452468
if registry_name:
453-
state_filename = f"{registry_name}.safetensors"
469+
state_filename = f"step_{step_index}_{registry_name}.safetensors"
454470
else:
455471
state_filename = f"step_{step_index}.safetensors"
456472

@@ -459,7 +475,7 @@ def save_pretrained(self, destination_path: str, **kwargs):
459475

460476
config["steps"].append(step_entry)
461477

462-
with open(os.path.join(destination_path, self._CFG_NAME), "w") as file_pointer:
478+
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
463479
json.dump(config, file_pointer, indent=2)
464480

465481
# Generate README.md from template
@@ -484,12 +500,17 @@ def to(self, device: str | torch.device):
484500
return self
485501

486502
@classmethod
487-
def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor:
503+
def from_pretrained(
504+
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
505+
) -> RobotProcessor:
488506
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
489507
490508
Args:
491509
source: Local path to a saved processor directory or Hugging Face Hub identifier
492510
(e.g., "username/processor-name").
511+
config_filename: Optional specific config filename to load. If not provided, will:
512+
- For local paths: look for any .json file in the directory (error if multiple found)
513+
- For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json", "robotprocessor.json")
493514
overrides: Optional dictionary mapping step names to configuration overrides.
494515
Keys must match exact step class names (for unregistered steps) or registry names
495516
(for registered steps). Values are dictionaries containing parameter overrides
@@ -510,6 +531,13 @@ def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None
510531
processor = RobotProcessor.from_pretrained("path/to/processor")
511532
```
512533
534+
Loading specific config file:
535+
```python
536+
processor = RobotProcessor.from_pretrained(
537+
"username/multi-processor-repo", config_filename="preprocessor.json"
538+
)
539+
```
540+
513541
Loading with overrides for non-serializable objects:
514542
```python
515543
import gym
@@ -534,12 +562,52 @@ def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None
534562
if Path(source).is_dir():
535563
# Local path - use it directly
536564
base_path = Path(source)
537-
with open(base_path / cls._CFG_NAME) as file_pointer:
565+
566+
if config_filename is None:
567+
# Look for any .json file in the directory
568+
json_files = list(base_path.glob("*.json"))
569+
if len(json_files) == 0:
570+
raise FileNotFoundError(f"No .json configuration files found in {source}")
571+
elif len(json_files) > 1:
572+
raise ValueError(
573+
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
574+
f"Please specify which one to load using the config_filename parameter."
575+
)
576+
config_filename = json_files[0].name
577+
578+
with open(base_path / config_filename) as file_pointer:
538579
config: dict[str, Any] = json.load(file_pointer)
539580
else:
540581
# Hugging Face Hub - download all required files
541-
# First download the config file
542-
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
582+
if config_filename is None:
583+
# Try common config names
584+
common_names = [
585+
"processor.json",
586+
"preprocessor.json",
587+
"postprocessor.json",
588+
"robotprocessor.json",
589+
]
590+
config_path = None
591+
for name in common_names:
592+
try:
593+
config_path = hf_hub_download(source, name, repo_type="model")
594+
config_filename = name
595+
break
596+
except (FileNotFoundError, OSError, HfHubHTTPError):
597+
# FileNotFoundError: local file issues
598+
# OSError: network/system errors
599+
# HfHubHTTPError: file not found on Hub (404) or other HTTP errors
600+
continue
601+
602+
if config_path is None:
603+
raise FileNotFoundError(
604+
f"No processor configuration file found in {source}. "
605+
f"Tried: {common_names}. Please specify the config_filename parameter."
606+
)
607+
else:
608+
# Download specific config file
609+
config_path = hf_hub_download(source, config_filename, repo_type="model")
610+
543611
with open(config_path) as file_pointer:
544612
config: dict[str, Any] = json.load(file_pointer)
545613

0 commit comments

Comments
 (0)