diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index e7ea59ed03..870c9571e8 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -31,8 +31,8 @@ DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, + get_file_size_in_mb, get_parquet_file_size_in_mb, - get_video_size_in_mb, to_parquet_with_hf_images, update_chunk_file_indices, write_info, @@ -217,6 +217,7 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + use_videos=len(video_keys) > 0, chunks_size=chunk_size, data_files_size_in_mb=data_files_size_in_mb, video_files_size_in_mb=video_files_size_in_mb, @@ -307,8 +308,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu current_offset += src_duration continue - src_size = get_video_size_in_mb(src_path) - dst_size = get_video_size_in_mb(dst_path) + # Check file sizes before appending + src_size = get_file_size_in_mb(src_path) + dst_size = get_file_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: # Rotate to a new file, this source becomes start of new destination diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index fdeb24a729..8ebc4a59de 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -42,6 +42,7 @@ DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, get_parquet_file_size_in_mb, + load_episodes, to_parquet_with_hf_images, update_chunk_file_indices, write_info, @@ -436,6 +437,9 @@ def _copy_and_reindex_data( Returns: dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + file_to_episodes: dict[Path, set[int]] = {} for old_idx in episode_mapping: file_path = src_dataset.meta.get_data_file_path(old_idx) @@ -645,6 +649,8 @@ def _copy_and_reindex_videos( Returns: dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} @@ -770,6 +776,9 @@ def _copy_and_reindex_episodes_metadata( """ from lerobot.datasets.utils import flatten_dict + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + all_stats = [] total_frames = 0 @@ -831,6 +840,8 @@ def _copy_and_reindex_episodes_metadata( total_frames += src_episode["length"] + dst_meta._close_writer() + dst_meta.info.update( { "total_episodes": len(episode_mapping), diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 229d376413..ae142c1e8f 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
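For reference, the size-based rotation check that `aggregate_videos` performs before appending a source video to the current destination file can be sketched in isolation as below. `get_file_size_in_mb` is the new generic helper introduced by this patch; `should_rotate` is an illustrative name, not part of the codebase.

```python
from pathlib import Path


def get_file_size_in_mb(file_path: Path) -> float:
    """Size of any file on disk in megabytes (generic replacement for the video-only helper)."""
    return file_path.stat().st_size / (1024**2)


def should_rotate(dst_path: Path, src_path: Path, max_size_in_mb: float) -> bool:
    """True when appending the source video would push the destination file over its size budget,
    in which case the source becomes the start of a new destination file."""
    return get_file_size_in_mb(dst_path) + get_file_size_in_mb(src_path) >= max_size_in_mb
```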
import contextlib -import gc import logging import shutil import tempfile @@ -26,6 +25,8 @@ import packaging.version import pandas as pd import PIL.Image +import pyarrow as pa +import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download @@ -46,13 +47,9 @@ embed_images, flatten_dict, get_delta_indices, - get_hf_dataset_cache_dir, - get_hf_dataset_size_in_mb, + get_file_size_in_mb, get_hf_features_from_features, - get_parquet_file_size_in_mb, - get_parquet_num_frames, get_safe_version, - get_video_size_in_mb, hf_transform_to_torch, is_valid_version, load_episodes, @@ -60,7 +57,6 @@ load_nested_dataset, load_stats, load_tasks, - to_parquet_with_hf_images, update_chunk_file_indices, validate_episode_buffer, validate_frame, @@ -90,10 +86,15 @@ def __init__( root: str | Path | None = None, revision: str | None = None, force_cache_sync: bool = False, + metadata_buffer_size: int = 10, ): self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size try: if force_cache_sync: @@ -107,6 +108,54 @@ def __init__( self.pull_from_repo(allow_patterns="meta/") self.load_metadata() + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and clean up the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize(), but as an added safety net also close the parquet writer when the destructor runs. + """ + self._close_writer() + def load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) @@ -138,6 +187,12 @@ def _version(self) -> packaging.version.Version: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range.
Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep["data/chunk_index"] file_idx = ep["data/file_index"] @@ -145,6 +200,12 @@ def get_data_file_path(self, ep_index: int) -> Path: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep[f"videos/{vid_key}/chunk_index"] file_idx = ep[f"videos/{vid_key}/file_index"] @@ -260,72 +321,75 @@ def save_episode_tasks(self, tasks: list[str]): write_tasks(self.tasks, self.root) def _save_episode_metadata(self, episode_dict: dict) -> None: - """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata. + """Buffer episode metadata and write to parquet in batches for efficiency. - This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset, - and saves it as a parquet file. It handles both the creation of new parquet files and the - updating of existing ones based on size constraints. After saving the metadata, it reloads - the Hugging Face dataset to ensure it is up-to-date. + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. Notes: We both need to update parquet files and HF dataset: - `pandas` loads parquet file in RAM - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, or loads directly from pyarrow cache. 
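The buffered writing introduced in `_flush_metadata_buffer` above boils down to appending row groups to a single open `pyarrow.parquet.ParquetWriter` instead of re-reading and re-writing the whole parquet file for every episode. A minimal, self-contained sketch of that pattern, assuming an illustrative `BufferedParquetAppender` class (the class and attribute names below are not from the patch):

```python
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq


class BufferedParquetAppender:
    """Illustrative only: batch row dicts and append them to one parquet file as row groups."""

    def __init__(self, path: str, buffer_size: int = 10):
        self.path = path
        self.buffer_size = buffer_size
        self.buffer: list[dict] = []
        self.writer: pq.ParquetWriter | None = None

    def append(self, row: dict) -> None:
        self.buffer.append(row)
        if len(self.buffer) >= self.buffer_size:
            self.flush()

    def flush(self) -> None:
        if not self.buffer:
            return
        # Build a column-major dict; numpy arrays are converted to lists for from_pydict.
        columns = {key: [] for key in self.buffer[0]}
        for row in self.buffer:
            for key, value in row.items():
                columns[key].append(value.tolist() if isinstance(value, np.ndarray) else value)
        table = pa.Table.from_pydict(columns)
        if self.writer is None:
            # The schema is fixed by the first flush; later flushes append new row groups.
            self.writer = pq.ParquetWriter(self.path, schema=table.schema, compression="snappy")
        self.writer.write_table(table)
        self.buffer.clear()

    def close(self) -> None:
        # Flush any pending rows; the parquet footer is only written when the writer is closed.
        self.flush()
        if self.writer is not None:
            self.writer.close()
            self.writer = None
```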
""" - # Convert buffer into HF Dataset + # Convert to list format for each value episode_dict = {key: [value] for key, value in episode_dict.items()} - ep_dataset = datasets.Dataset.from_dict(episode_dict) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) - df = pd.DataFrame(ep_dataset) num_frames = episode_dict["length"][0] - if self.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [0] - df["dataset_to_index"] = [num_frames] - else: - # Retrieve information from the latest parquet file - latest_ep = self.episodes[-1] - chunk_idx = latest_ep["meta/episodes/chunk_index"] - file_idx = latest_ep["meta/episodes/file_index"] + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] - latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] - if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) - # Update the existing pandas dataframe with new row - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [latest_ep["dataset_to_index"]] - df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames] + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] - if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb: - # Size limit wasnt reached, concatenate latest dataframe with new one - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 - # Memort optimization - del latest_df - gc.collect() + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() - # Write the resulting 
dataframe from RAM to disk - path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(path, index=False) + # Update the existing pandas dataframe with new row + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - if self.episodes is not None: - # Remove the episodes cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.episodes) - if cached_dir is not None: - shutil.rmtree(cached_dir) + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict - self.episodes = load_episodes(self.root) + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() def save_episode( self, @@ -438,6 +502,7 @@ def create( robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + metadata_buffer_size: int = 10, chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, @@ -469,6 +534,10 @@ def create( raise ValueError() write_json(obj.info, obj.root / INFO_PATH) obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size return obj @@ -615,6 +684,8 @@ def __init__( # Unused attributes self.image_writer = None self.episode_buffer = None + self.writer = None + self.latest_episode = None self.root.mkdir(exist_ok=True, parents=True) @@ -623,6 +694,11 @@ def __init__( self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) + # Track dataset state for efficient incremental writing + self._lazy_loading = False + self._recorded_frames = self.meta.total_frames + self._writer_closed_for_reading = False + # Load actual data try: if force_cache_sync: @@ -641,6 +717,19 @@ def __init__( check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + def _close_writer(self) -> None: + """Close and clean up the parquet writer if it exists.""" + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize(), but as an added safety net also close the parquet writer when the destructor runs. + """ + self._close_writer() + def push_to_hub( self, branch: str | None = None, @@ -781,8 +870,15 @@ def fps(self) -> int: @property def num_frames(self) -> int: - """Number of frames in selected episodes.""" - return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + """Number of frames in selected episodes. + + Note: When a subset of episodes is requested, we must return the + actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. + self.meta.total_frames is the total number of frames in the full dataset.
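Both rewritten save paths (episode metadata above, episode data further down) replace the exact size measurement of the incoming episode with an estimate derived from the file currently being appended to (`latest_size_in_mb / frames_in_current_file`), which avoids materializing the episode in RAM just to measure its serialized size, at the cost of some imprecision when per-frame size varies. That heuristic, isolated in a hypothetical helper for clarity:

```python
def should_start_new_file(
    current_file_size_in_mb: float,
    frames_in_current_file: int,
    new_episode_num_frames: int,
    max_file_size_in_mb: float,
) -> bool:
    """Project the incoming episode's on-disk size from the average size per frame
    already written to the current file, and rotate when the budget would be exceeded."""
    avg_size_per_frame_in_mb = (
        current_file_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0.0
    )
    projected_size_in_mb = current_file_size_in_mb + avg_size_per_frame_in_mb * new_episode_num_frames
    return projected_size_in_mb >= max_file_size_in_mb
```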
+ """ + if self.episodes is not None and self.hf_dataset is not None: + return len(self.hf_dataset) + return self.meta.total_frames @property def num_episodes(self) -> int: @@ -860,10 +956,22 @@ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) - return item + def _ensure_hf_dataset_loaded(self): + """Lazy load the HF dataset only when needed for reading.""" + if self._lazy_loading or self.hf_dataset is None: + # Close the writer before loading to ensure parquet file is properly finalized + if self.writer is not None: + self._close_writer() + self._writer_closed_for_reading = True + self.hf_dataset = self.load_hf_dataset() + self._lazy_loading = False + def __len__(self): return self.num_frames def __getitem__(self, idx) -> dict: + # Ensure dataset is loaded when we actually need to read from it + self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() @@ -902,6 +1010,14 @@ def __repr__(self): "})',\n" ) + def finalize(self): + """ + Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. + The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) + """ + self._close_writer() + self.meta._close_writer() + def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index ep_buffer = {} @@ -1109,74 +1225,101 @@ def _save_episode_data(self, episode_buffer: dict) -> dict: ep_dict = {key: episode_buffer[key] for key in self.hf_features} ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") ep_dataset = embed_images(ep_dataset) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) ep_num_frames = len(ep_dataset) - df = pd.DataFrame(ep_dataset) - if self.meta.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - latest_num_frames = 0 + global_frame_index = 0 + # However, if the episodes already exists + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + latest_ep = self.meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) else: # Retrieve information from the latest parquet file - latest_ep = self.meta.episodes[-1] + latest_ep = self.latest_episode chunk_idx = latest_ep["data/chunk_index"] file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - latest_num_frames = get_parquet_num_frames(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"] + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) # Determine if a new parquet file is needed - if latest_size_in_mb + ep_size_in_mb >= 
self.meta.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file + if ( + latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb + or self._writer_closed_for_reading + ): + # Size limit is reached or writer was closed for reading, prepare new parquet file chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - latest_num_frames = 0 - else: - # Update the existing parquet file with new rows - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + self._close_writer() + self._writer_closed_for_reading = False - # Memort optimization - del latest_df - gc.collect() + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx # Write the resulting dataframe from RAM to disk path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) path.parent.mkdir(parents=True, exist_ok=True) - if len(self.meta.image_keys) > 0: - to_parquet_with_hf_images(df, path) - else: - df.to_parquet(path) - if self.hf_dataset is not None: - # Remove hf dataset cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.hf_dataset) - if cached_dir is not None: - shutil.rmtree(cached_dir) - - self.hf_dataset = self.load_hf_dataset() + table = ep_dataset.with_format("arrow")[:] + if not self.writer: + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self.writer.write_table(table) metadata = { "data/chunk_index": chunk_idx, "data/file_index": file_idx, - "dataset_from_index": latest_num_frames, - "dataset_to_index": latest_num_frames + ep_num_frames, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, } + + # Store metadata with episode data for next episode + self.latest_episode = {**ep_dict, **metadata} + + # Mark that the HF dataset needs reloading (lazy loading approach) + # This avoids expensive reloading during sequential recording + self._lazy_loading = True + # Update recorded frames count for efficient length tracking + self._recorded_frames += ep_num_frames + return metadata def _save_episode_video(self, video_key: str, episode_index: int) -> dict: # Encode episode frames into a temporary video ep_path = self._encode_temporary_episode_video(video_key, episode_index) - ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_size_in_mb = get_file_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) - if self.meta.episodes is None or ( - f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names - or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names + if ( + episode_index == 0 + or self.meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode ): # Initialize indices for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + old_chunk_idx, old_file_idx, self.meta.chunks_size + ) latest_duration_in_s = 0.0 new_path = self.root / self.meta.video_path.format( video_key=video_key, 
chunk_index=chunk_idx, file_index=file_idx @@ -1184,16 +1327,16 @@ def _save_episode_video(self, video_key: str, episode_index: int) -> dict: new_path.parent.mkdir(parents=True, exist_ok=True) shutil.move(str(ep_path), str(new_path)) else: - # Retrieve information from the latest updated video file (possibly several episodes ago) - latest_ep = self.meta.episodes[episode_index - 1] - chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"] - file_idx = latest_ep[f"videos/{video_key}/file_index"] + # Retrieve information from the latest updated video file using latest_episode + latest_ep = self.meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] latest_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, file_index=file_idx ) - latest_size_in_mb = get_video_size_in_mb(latest_path) - latest_duration_in_s = get_video_duration_in_s(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: # Move temporary episode video to a new video file in the dataset @@ -1327,6 +1470,12 @@ def create( obj.delta_timestamps = None obj.delta_indices = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj.writer = None + obj.latest_episode = None + # Initialize tracking for incremental recording + obj._lazy_loading = False + obj._recorded_frames = 0 + obj._writer_closed_for_reading = False return obj diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 422a7010a6..37d8432b2b 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -94,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: return hf_ds.data.nbytes // (1024**2) -def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None: - if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0: - return None - return Path(hf_ds.cache_files[0]["filename"]).parents[2] - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -133,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int: return metadata.num_rows -def get_video_size_in_mb(mp4_path: Path) -> float: - file_size_bytes = mp4_path.stat().st_size - file_size_mb = file_size_bytes / (1024**2) - return file_size_mb +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. 
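In `_save_episode_video`, the duration of the current destination file is now read from the buffered episode metadata (`videos/{video_key}/to_timestamp`) instead of being re-probed with `get_video_duration_in_s`, while sizes are still measured on disk. A sketch of the resulting append-or-rotate decision; the helper name and return convention below are illustrative only:

```python
from pathlib import Path


def video_append_decision(
    latest_video_path: Path,
    new_clip_size_in_mb: float,
    latest_episode_metadata: dict,
    video_key: str,
    max_video_file_size_in_mb: float,
) -> tuple[bool, float]:
    """Return (start_new_file, timestamp_offset) for the episode video being saved."""
    latest_size_in_mb = latest_video_path.stat().st_size / (1024**2)
    # Duration of the destination file comes from the buffered metadata, not from decoding it.
    latest_duration_in_s = latest_episode_metadata[f"videos/{video_key}/to_timestamp"][0]
    if latest_size_in_mb + new_clip_size_in_mb >= max_video_file_size_in_mb:
        # Rotate: the new clip starts a fresh video file and its timestamps start at zero.
        return True, 0.0
    # Append: the new clip's timestamps are shifted by the existing file's duration.
    return False, latest_duration_in_s
```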
+ """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 620ba863ac..740cdb6020 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -642,6 +642,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) self.dataset._batch_save_episode_video(start_ep, end_ep) + # Finalize the dataset to properly close all writers + self.dataset.finalize() + # Clean up episode images if recording was interrupted if exc_type is not None: interrupted_episode_index = self.dataset.num_episodes diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 917e4e2cc9..81aa29c480 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -607,6 +607,7 @@ def to_lerobot_dataset( lerobot_dataset.save_episode() lerobot_dataset.stop_image_writer() + lerobot_dataset.finalize() return lerobot_dataset diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index fe117b35b8..a9c04d6f24 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -55,6 +55,7 @@ def sample_dataset(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame(frame) dataset.save_episode() + dataset.finalize() return dataset @@ -263,6 +264,7 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact } dataset2.add_frame(frame) dataset2.save_episode() + dataset2.finalize() with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, @@ -685,6 +687,7 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa } dataset.add_frame(frame) dataset.save_episode() + dataset.finalize() datasets.append(dataset) @@ -728,6 +731,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f } dataset2.add_frame(frame) dataset2.save_episode() + dataset2.finalize() with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 2bc3bea43b..e174c57896 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + # Load the dataset and check episode indices loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]}) dataset.save_episode() + dataset.finalize() + # Load and validate episode metadata loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factor dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check data consistency - no gaps or overlaps @@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"}) 
dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that statistics exist for all features @@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Test episode boundaries @@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(1), "task": task}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that all unique tasks are in the tasks metadata @@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): # Check total number of tasks assert loaded_dataset.meta.total_tasks == len(unique_tasks) + + +def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory): + """Test that resuming dataset recording preserves previously recorded episodes. + + This test validates the critical resume functionality by: + 1. Recording initial episodes and finalizing + 2. Reopening the dataset + 3. Recording additional episodes + 4. Verifying all data (old + new) is intact + + This specifically tests the bug fix where parquet files were being overwritten + instead of appended to during resume. + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + initial_episodes = 2 + frames_per_episode = 3 + + for ep_idx in range(initial_episodes): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + + assert dataset.meta.total_episodes == initial_episodes + assert dataset.meta.total_frames == initial_episodes * frames_per_episode + + dataset.finalize() + initial_root = dataset.root + initial_repo_id = dataset.repo_id + del dataset + + dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + assert dataset_verify.meta.total_episodes == initial_episodes + assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode + assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode + + for idx in range(len(dataset_verify.hf_dataset)): + item = dataset_verify[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + assert item["episode_index"].item() == expected_ep + assert item["frame_index"].item() == expected_frame + assert item["index"].item() == idx + assert item["observation.state"][0].item() == float(expected_ep) + assert item["observation.state"][1].item() == float(expected_frame) + + del dataset_verify + + # Phase 3: Resume recording - add more episodes + dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_resumed.meta.total_episodes == initial_episodes + assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode + assert dataset_resumed.latest_episode is None # Not recording yet + assert 
dataset_resumed.writer is None + assert dataset_resumed.meta.writer is None + + additional_episodes = 2 + for ep_idx in range(initial_episodes, initial_episodes + additional_episodes): + for frame_idx in range(frames_per_episode): + dataset_resumed.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset_resumed.save_episode() + + total_episodes = initial_episodes + additional_episodes + total_frames = total_episodes * frames_per_episode + assert dataset_resumed.meta.total_episodes == total_episodes + assert dataset_resumed.meta.total_frames == total_frames + + dataset_resumed.finalize() + del dataset_resumed + + dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_final.meta.total_episodes == total_episodes + assert dataset_final.meta.total_frames == total_frames + assert len(dataset_final.hf_dataset) == total_frames + + for idx in range(total_frames): + item = dataset_final[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + + assert item["episode_index"].item() == expected_ep, ( + f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}" + ) + assert item["frame_index"].item() == expected_frame, ( + f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}" + ) + assert item["index"].item() == idx, ( + f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}" + ) + + # Verify data integrity + assert item["observation.state"][0].item() == float(expected_ep), ( + f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, " + f"got {item['observation.state'][0].item()}" + ) + assert item["observation.state"][1].item() == float(expected_frame), ( + f"Frame {idx}: wrong observation.state[1]. Expected {float(expected_frame)}, " + f"got {item['observation.state'][1].item()}" + ) + + assert len(dataset_final.meta.episodes) == total_episodes + for ep_idx in range(total_episodes): + ep_metadata = dataset_final.meta.episodes[ep_idx] + assert ep_metadata["episode_index"] == ep_idx + assert ep_metadata["length"] == frames_per_episode + assert ep_metadata["tasks"] == [f"task_{ep_idx}"] + + expected_from = ep_idx * frames_per_episode + expected_to = (ep_idx + 1) * frames_per_episode + assert ep_metadata["dataset_from_index"] == expected_from + assert ep_metadata["dataset_to_index"] == expected_to
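The resume test also documents why `finalize()` (and the `__del__` safety net) matter: a file produced through `pyarrow.parquet.ParquetWriter` only becomes readable once the writer is closed and the footer is written. A self-contained illustration with a made-up path and schema:

```python
import pyarrow as pa
import pyarrow.parquet as pq

path = "episodes_demo.parquet"
table = pa.table({"episode_index": [0, 1], "length": [3, 3]})

writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy")
writer.write_table(table)
# At this point the file has data pages but no footer; readers would fail to open it.

writer.close()
# After close() the footer exists and the file is a valid parquet dataset.
print(pq.ParquetFile(path).metadata.num_rows)  # -> 2
```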