# Reviewer summary. These '#' lines sit before the first "diff --git" header and belong to no hunk.
#
# src/lerobot/datasets/lerobot_dataset.py
#   - Adds VideoEncodingManager to the names imported from lerobot.datasets.video_utils.
#   - In save_episode, both video-encoding call sites now run inside `with VideoEncodingManager(self):`
#     the per-episode loop over self.meta.video_keys calling _save_episode_video, and the
#     _batch_save_episode_video(start_ep, end_ep) call made once batch_encoding_size episodes are pending.
#
# tests/datasets/test_datasets.py
#   - Adds the `video_dataset` fixture: an empty dataset with a single "image" feature of dtype "video".
#   - Adds test_add_frame_video: adds one frame, saves the episode, then asserts that dataset[0]["image"]
#     has shape DUMMY_CHW, that exactly one .mp4 exists under <root>/videos, and that <root>/images
#     does not exist.
#
# NOTE(review): VideoEncodingManager's __enter__/__exit__ are not visible in this patch. In the batched
#   path the `with` block exits BEFORE `self.episodes_since_last_encoding = 0` runs; if the manager's
#   exit handler encodes pending episodes, the same batch could be encoded twice - confirm.
# NOTE(review): the new test's final assertion presumably relies on the manager removing <root>/images
#   on exit - verify against video_utils.
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index ec419d557f..3565527337 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -66,6 +66,7 @@
     write_tasks,
 )
 from lerobot.datasets.video_utils import (
+    VideoEncodingManager,
     VideoFrame,
     concatenate_video_files,
     decode_video_frames,
@@ -1136,8 +1137,9 @@ def save_episode(self, episode_data: dict | None = None) -> None:
         use_batched_encoding = self.batch_encoding_size > 1
 
         if has_video_keys and not use_batched_encoding:
-            for video_key in self.meta.video_keys:
-                ep_metadata.update(self._save_episode_video(video_key, episode_index))
+            with VideoEncodingManager(self):
+                for video_key in self.meta.video_keys:
+                    ep_metadata.update(self._save_episode_video(video_key, episode_index))
 
         # `meta.save_episode` need to be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
@@ -1148,7 +1150,8 @@ def save_episode(self, episode_data: dict | None = None) -> None:
             if self.episodes_since_last_encoding == self.batch_encoding_size:
                 start_ep = self.num_episodes - self.batch_encoding_size
                 end_ep = self.num_episodes
-                self._batch_save_episode_video(start_ep, end_ep)
+                with VideoEncodingManager(self):
+                    self._batch_save_episode_video(start_ep, end_ep)
                 self.episodes_since_last_encoding = 0
 
         if not episode_data:
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index e174c57896..a71ec5e6a7 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -68,6 +68,19 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
     return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
 
 
+@pytest.fixture
+def video_dataset(tmp_path, empty_lerobot_dataset_factory):
+    features = {
+        "image": {
+            "dtype": "video",
+            "shape": DUMMY_HWC,
+            "names": ["height", "width", "channels"],
+            "info": None,
+        }
+    }
+    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+
+
 def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@@ -344,6 +357,17 @@ def test_add_frame_image_pil(image_dataset):
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 
 
+def test_add_frame_video(video_dataset):
+    dataset = video_dataset
+    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
+    dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
+    dataset.save_episode()
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+    assert len([str(f) for f in (dataset.root / "videos").glob("**/*.mp4")]) == 1
+    assert not (dataset.root / "images").exists()
+
+
 def test_image_array_to_pil_image_wrong_range_float_0_255():
     image = np.random.rand(*DUMMY_HWC) * 255
     with pytest.raises(ValueError):