Skip to content

Commit 3b4d846

Browse files
committed
refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
1 parent 121ad28 commit 3b4d846

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

src/lerobot/processor/pipeline.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -416,12 +416,13 @@ def save_pretrained(self, destination_path: str, config_filename: str | None = N
416416
"""
417417
os.makedirs(destination_path, exist_ok=True)
418418

419-
# Determine config filename - sanitize the processor name for filesystem
420-
if config_filename is None:
421-
# Sanitize name - replace any character that's not alphanumeric or underscore
422-
import re
419+
# Sanitize processor name for use in filenames
420+
import re
421+
422+
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
423423

424-
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
424+
# Use sanitized name for config if not provided
425+
if config_filename is None:
425426
config_filename = f"{sanitized_name}.json"
426427

427428
config: dict[str, Any] = {
@@ -463,12 +464,12 @@ def save_pretrained(self, destination_path: str, config_filename: str | None = N
463464
for key, tensor in state.items():
464465
cloned_state[key] = tensor.clone()
465466

466-
# Always include step index to ensure unique filenames
467-
# This prevents conflicts when the same processor type is used multiple times
467+
# Include pipeline name and step index to ensure unique filenames
468+
# This prevents conflicts when multiple processors are saved in the same directory
468469
if registry_name:
469-
state_filename = f"step_{step_index}_{registry_name}.safetensors"
470+
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
470471
else:
471-
state_filename = f"step_{step_index}.safetensors"
472+
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
472473

473474
save_file(cloned_state, os.path.join(destination_path, state_filename))
474475
step_entry["state_file"] = state_filename

tests/processor/test_pipeline.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def test_mixed_json_and_tensor_state():
630630

631631
# Check that both config and state files were created
632632
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
633-
state_path = Path(tmp_dir) / "step_0.safetensors"
633+
state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors"
634634
assert config_path.exists()
635635
assert state_path.exists()
636636

@@ -1735,7 +1735,7 @@ def test_error_multiple_configs_no_filename():
17351735

17361736

17371737
def test_state_file_naming_with_indices():
1738-
"""Test that state files include step indices to avoid conflicts."""
1738+
"""Test that state files include pipeline name and step indices to avoid conflicts."""
17391739
# Create multiple steps of same type with state
17401740
step1 = MockStepWithTensorState(name="norm1", window_size=5)
17411741
step2 = MockStepWithTensorState(name="norm2", window_size=10)
@@ -1755,14 +1755,18 @@ def test_state_file_naming_with_indices():
17551755
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
17561756
assert len(state_files) == 3
17571757

1758-
# Files should be named with indices
1759-
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
1758+
# Files should be named with pipeline name prefix and indices
1759+
expected_names = [
1760+
"robotprocessor_step_0.safetensors",
1761+
"robotprocessor_step_1.safetensors",
1762+
"robotprocessor_step_2.safetensors",
1763+
]
17601764
actual_names = [f.name for f in state_files]
17611765
assert actual_names == expected_names
17621766

17631767

17641768
def test_state_file_naming_with_registry():
1765-
"""Test state file naming for registered steps includes both index and name."""
1769+
"""Test state file naming for registered steps includes pipeline name, index and registry name."""
17661770

17671771
# Register a test step
17681772
@ProcessorStepRegistry.register("test_stateful_step")
@@ -1799,10 +1803,10 @@ def load_state_dict(self, state):
17991803
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
18001804
assert len(state_files) == 2
18011805

1802-
# Should include both index and registry name
1806+
# Should include pipeline name, index and registry name
18031807
expected_names = [
1804-
"step_0_test_stateful_step.safetensors",
1805-
"step_1_test_stateful_step.safetensors",
1808+
"robotprocessor_step_0_test_stateful_step.safetensors",
1809+
"robotprocessor_step_1_test_stateful_step.safetensors",
18061810
]
18071811
actual_names = [f.name for f in state_files]
18081812
assert actual_names == expected_names
@@ -1995,6 +1999,42 @@ def test_config_filename_special_characters():
19951999
assert json_files[0].name == expected_name
19962000

19972001

2002+
def test_state_file_naming_with_multiple_processors():
2003+
"""Test that state files are properly prefixed with pipeline names to avoid conflicts."""
2004+
# Create two processors with state
2005+
step1 = MockStepWithTensorState(name="norm", window_size=5)
2006+
preprocessor = RobotProcessor([step1], name="PreProcessor")
2007+
2008+
step2 = MockStepWithTensorState(name="norm", window_size=10)
2009+
postprocessor = RobotProcessor([step2], name="PostProcessor")
2010+
2011+
# Process some data to create state
2012+
for i in range(3):
2013+
transition = create_transition(reward=float(i))
2014+
preprocessor(transition)
2015+
postprocessor(transition)
2016+
2017+
with tempfile.TemporaryDirectory() as tmp_dir:
2018+
# Save both processors to the same directory
2019+
preprocessor.save_pretrained(tmp_dir)
2020+
postprocessor.save_pretrained(tmp_dir)
2021+
2022+
# Check that all files exist and are distinct
2023+
assert (Path(tmp_dir) / "preprocessor.json").exists()
2024+
assert (Path(tmp_dir) / "postprocessor.json").exists()
2025+
assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists()
2026+
assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists()
2027+
2028+
# Load both back and verify they work correctly
2029+
loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
2030+
loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")
2031+
2032+
assert loaded_pre.name == "PreProcessor"
2033+
assert loaded_post.name == "PostProcessor"
2034+
assert loaded_pre.steps[0].window_size == 5
2035+
assert loaded_post.steps[0].window_size == 10
2036+
2037+
19982038
def test_override_with_device_strings():
19992039
"""Test overriding device parameters with string values."""
20002040

0 commit comments

Comments
 (0)