Skip to content

Commit a024d33

Browse files
fix(bug): Fix policy renaming ValueError during training (#2278)
* fixes * style * Update src/lerobot/policies/factory.py Co-authored-by: Copilot <[email protected]> Signed-off-by: Jade Choghari <[email protected]> * style * add review fixes --------- Signed-off-by: Jade Choghari <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 63cd211 commit a024d33

File tree

4 files changed

+27
-0
lines changed

4 files changed

+27
-0
lines changed

src/lerobot/configs/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ class TrainPipelineConfig(HubMixin):
6565
eval: EvalConfig = field(default_factory=EvalConfig)
6666
wandb: WandBConfig = field(default_factory=WandBConfig)
6767
checkpoint_path: Path | None = field(init=False, default=None)
68+
# Rename map for the observation to override the image and state keys
69+
rename_map: dict[str, str] = field(default_factory=dict)
6870

6971
def validate(self) -> None:
7072
# HACK: We parse again the cli args here to get the pretrained paths if there was some.

src/lerobot/policies/factory.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def make_policy(
303303
cfg: PreTrainedConfig,
304304
ds_meta: LeRobotDatasetMetadata | None = None,
305305
env_cfg: EnvConfig | None = None,
306+
rename_map: dict[str, str] | None = None,
306307
) -> PreTrainedPolicy:
307308
"""
308309
Instantiate a policy model.
@@ -319,6 +320,8 @@ def make_policy(
319320
statistics for normalization layers.
320321
env_cfg: Environment configuration used to infer feature shapes and types.
321322
One of `ds_meta` or `env_cfg` must be provided.
323+
rename_map: Optional mapping of dataset or environment feature keys to match
324+
expected policy feature names (e.g., `"left"` → `"camera1"`).
322325
323326
Returns:
324327
An instantiated and device-placed policy model.
@@ -380,4 +383,21 @@ def make_policy(
380383

381384
# policy = torch.compile(policy, mode="reduce-overhead")
382385

386+
if not rename_map:
387+
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
388+
provided_features = set(features.keys())
389+
if expected_features and provided_features != expected_features:
390+
missing = expected_features - provided_features
391+
extra = provided_features - expected_features
392+
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
393+
raise ValueError(
394+
f"Feature mismatch between dataset/environment and policy config.\n"
395+
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
396+
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
397+
f"Please ensure your dataset and policy use consistent feature names.\n"
398+
f"If your dataset uses different observation keys (e.g., cameras named differently), "
399+
f"use the `--rename_map` argument, for example:\n"
400+
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
401+
f'"observation.images.top": "observation.images.camera2"}}\''
402+
)
383403
return policy

src/lerobot/scripts/lerobot_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def eval_main(cfg: EvalPipelineConfig):
501501
policy = make_policy(
502502
cfg=cfg.policy,
503503
env_cfg=cfg.env,
504+
rename_map=cfg.rename_map,
504505
)
505506

506507
policy.eval()

src/lerobot/scripts/lerobot_train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
203203
policy = make_policy(
204204
cfg=cfg.policy,
205205
ds_meta=dataset.meta,
206+
rename_map=cfg.rename_map,
206207
)
207208

208209
# Wait for all processes to finish policy creation before continuing
@@ -224,6 +225,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
224225
"norm_map": policy.config.normalization_mapping,
225226
},
226227
}
228+
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
229+
"rename_map": cfg.rename_map
230+
}
227231
postprocessor_kwargs["postprocessor_overrides"] = {
228232
"unnormalizer_processor": {
229233
"stats": dataset.meta.stats,

0 commit comments

Comments
 (0)