
Commit fd13439

Added rigorous testing to validate the consistency of the metadata after creation of a new dataset
1 parent d399526 commit fd13439

File tree

1 file changed: +278 -0


tests/datasets/test_datasets.py

Lines changed: 278 additions & 0 deletions
@@ -789,3 +789,281 @@ def test_update_chunk_settings_video_dataset(tmp_path):
    dataset.meta.update_chunk_settings(video_files_size_in_mb=new_video_size)
    assert dataset.meta.get_chunk_settings()["video_files_size_in_mb"] == new_video_size
    assert dataset.meta.video_files_size_in_mb == new_video_size


def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
    """Test that all frames have correct episode indices across multiple episodes."""
    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    # Create 3 episodes with different lengths
    num_episodes = 3
    frames_per_episode = [10, 15, 8]

    for episode_idx in range(num_episodes):
        for _ in range(frames_per_episode[episode_idx]):
            dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
        dataset.save_episode()

    dataset.finalize()

    # Load the dataset and check episode indices
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check specific frames across episode boundaries
    cumulative = 0
    for ep_idx, ep_length in enumerate(frames_per_episode):
        # Check start, middle, and end of each episode
        start_frame = cumulative
        middle_frame = cumulative + ep_length // 2
        end_frame = cumulative + ep_length - 1

        for frame_idx in [start_frame, middle_frame, end_frame]:
            frame_data = loaded_dataset[frame_idx]
            actual_ep_idx = frame_data["episode_index"].item()
            assert actual_ep_idx == ep_idx, (
                f"Frame {frame_idx} has episode_index {actual_ep_idx}, should be {ep_idx}"
            )

        cumulative += ep_length

    # Check episode index distribution
    all_episode_indices = [loaded_dataset[i]["episode_index"].item() for i in range(len(loaded_dataset))]
    from collections import Counter

    distribution = Counter(all_episode_indices)
    expected_dist = {i: frames_per_episode[i] for i in range(num_episodes)}

    assert dict(distribution) == expected_dist, (
        f"Episode distribution {dict(distribution)} != expected {expected_dist}"
    )
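
These tests all follow the same construction pattern: add frames one episode at a time, call save_episode() after each episode, then finalize() and reload with LeRobotDataset. A possible follow-up refactor (not part of this commit; _record_episodes and make_frame are hypothetical names) would factor that loop into a shared helper:

def _record_episodes(dataset, frames_per_episode, make_frame, tasks=None):
    # Hypothetical helper, not in the commit: the episode-building loop
    # that each test repeats inline.
    for episode_idx, num_frames in enumerate(frames_per_episode):
        task = tasks[episode_idx] if tasks is not None else f"task_{episode_idx}"
        for _ in range(num_frames):
            dataset.add_frame({**make_frame(), "task": task})
        dataset.save_episode()
    dataset.finalize()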


def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory):
    """Test episode metadata consistency across multiple episodes."""
    features = {
        "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
        "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 4
    frames_per_episode = [20, 35, 10, 25]
    tasks = ["pick", "place", "pick", "place"]

    for episode_idx in range(num_episodes):
        for _ in range(frames_per_episode[episode_idx]):
            dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]})
        dataset.save_episode()

    dataset.finalize()

    # Load and validate episode metadata
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    assert loaded_dataset.meta.total_episodes == num_episodes
    assert loaded_dataset.meta.total_frames == sum(frames_per_episode)

    cumulative_frames = 0
    for episode_idx in range(num_episodes):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]

        # Check basic episode properties
        assert episode_metadata["episode_index"] == episode_idx
        assert episode_metadata["length"] == frames_per_episode[episode_idx]
        assert episode_metadata["tasks"] == [tasks[episode_idx]]

        # Check dataset indices
        expected_from = cumulative_frames
        expected_to = cumulative_frames + frames_per_episode[episode_idx]

        assert episode_metadata["dataset_from_index"] == expected_from
        assert episode_metadata["dataset_to_index"] == expected_to

        cumulative_frames += frames_per_episode[episode_idx]


def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory):
    """Test that episodes have no gaps or overlaps in their data indices."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 5
    frames_per_episode = [12, 8, 20, 15, 5]

    for episode_idx in range(num_episodes):
        for _ in range(frames_per_episode[episode_idx]):
            dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check data consistency - no gaps or overlaps
    cumulative_check = 0
    for episode_idx in range(num_episodes):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
        from_idx = episode_metadata["dataset_from_index"]
        to_idx = episode_metadata["dataset_to_index"]

        # Check that episode starts exactly where previous ended
        assert from_idx == cumulative_check, (
            f"Episode {episode_idx} starts at {from_idx}, expected {cumulative_check}"
        )

        # Check that episode length matches expected
        actual_length = to_idx - from_idx
        expected_length = frames_per_episode[episode_idx]
        assert actual_length == expected_length, (
            f"Episode {episode_idx} length {actual_length} != expected {expected_length}"
        )

        cumulative_check = to_idx

    # Final check: last episode should end at total frames
    expected_total_frames = sum(frames_per_episode)
    assert cumulative_check == expected_total_frames, (
        f"Final frame count {cumulative_check} != expected {expected_total_frames}"
    )
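
The arithmetic in this test pins down the index convention: dataset_from_index is inclusive, dataset_to_index is exclusive, so to_idx - from_idx is the episode length and consecutive episodes tile [0, total_frames) exactly. The same invariant, stated compactly (a sketch; episodes is assumed to be the sequence of per-episode metadata dicts used above):

def episodes_tile_frame_range(episodes, total_frames):
    # Half-open [from, to) intervals must abut with no gaps or overlaps
    # and end at the total frame count.
    prev_end = 0
    for ep in episodes:
        if ep["dataset_from_index"] != prev_end:
            return False
        prev_end = ep["dataset_to_index"]
    return prev_end == total_frames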


def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory):
    """Test that statistics are properly computed and stored for all features."""
    features = {
        "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
        "action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    # Create controlled data to verify statistics
    num_episodes = 2
    frames_per_episode = [10, 10]

    # Use deterministic data for predictable statistics
    torch.manual_seed(42)
    for episode_idx in range(num_episodes):
        for frame_idx in range(frames_per_episode[episode_idx]):
            state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32)
            action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32)
            dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check that statistics exist for all features
    assert loaded_dataset.meta.stats is not None, "No statistics found"

    for feature_name in features.keys():
        assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"

        feature_stats = loaded_dataset.meta.stats[feature_name]
        expected_stats = ["min", "max", "mean", "std", "count"]

        for stat_key in expected_stats:
            assert stat_key in feature_stats, f"Missing '{stat_key}' statistic for '{feature_name}'"

            stat_value = feature_stats[stat_key]
            # Basic sanity checks
            if stat_key == "count":
                assert stat_value == sum(frames_per_episode), f"Wrong count for '{feature_name}'"
            elif stat_key in ["min", "max", "mean", "std"]:
                # Check that statistics are reasonable (not NaN, proper shapes)
                if hasattr(stat_value, "shape"):
                    expected_shape = features[feature_name]["shape"]
                    assert stat_value.shape == expected_shape or len(stat_value) == expected_shape[0], (
                        f"Wrong shape for {stat_key} of '{feature_name}'"
                    )
                # Check no NaN values
                if hasattr(stat_value, "__iter__"):
                    assert not any(np.isnan(v) for v in stat_value), f"NaN in {stat_key} for '{feature_name}'"
                else:
                    assert not np.isnan(stat_value), f"NaN in {stat_key} for '{feature_name}'"
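
Since every frame here is a pure function of frame_idx (the manual_seed call is defensive; no random data is actually drawn), the expected statistics are known in closed form. A stricter check one could add (a sketch, assuming the stats are numpy arrays as the shape checks above suggest):

# state repeats [i * 0.1, i * 0.2] for i in 0..9 across both episodes,
# so exact per-dimension statistics can be compared directly.
expected_state = np.array([[i * 0.1, i * 0.2] for i in range(10)] * 2, dtype=np.float32)
assert np.allclose(loaded_dataset.meta.stats["state"]["min"], expected_state.min(axis=0))
assert np.allclose(loaded_dataset.meta.stats["state"]["max"], expected_state.max(axis=0))
assert np.allclose(loaded_dataset.meta.stats["state"]["mean"], expected_state.mean(axis=0))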


def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
    """Test frame indices and episode transitions at episode boundaries."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 3
    frames_per_episode = [7, 12, 5]

    for episode_idx in range(num_episodes):
        for frame_idx in range(frames_per_episode[episode_idx]):
            dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Test episode boundaries
    cumulative = 0
    for ep_idx, ep_length in enumerate(frames_per_episode):
        if ep_idx > 0:
            # Check last frame of previous episode
            prev_frame = loaded_dataset[cumulative - 1]
            assert prev_frame["episode_index"].item() == ep_idx - 1

            # Check first frame of current episode
            if cumulative < len(loaded_dataset):
                curr_frame = loaded_dataset[cumulative]
                assert curr_frame["episode_index"].item() == ep_idx

        # Check frame_index within episode
        for i in range(ep_length):
            if cumulative + i < len(loaded_dataset):
                frame = loaded_dataset[cumulative + i]
                assert frame["frame_index"].item() == i, f"Frame {cumulative + i} has wrong frame_index"
                assert frame["episode_index"].item() == ep_idx, (
                    f"Frame {cumulative + i} has wrong episode_index"
                )

        cumulative += ep_length


def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
    """Test that tasks are properly indexed and retrievable."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    # Use multiple tasks, including repeated ones
    tasks = ["pick", "place", "pick", "navigate", "place"]
    unique_tasks = list(set(tasks))  # {"pick", "place", "navigate"}; order not guaranteed, compared as a set below
    frames_per_episode = [5, 8, 3, 10, 6]

    for episode_idx, task in enumerate(tasks):
        for _ in range(frames_per_episode[episode_idx]):
            dataset.add_frame({"state": torch.randn(1), "task": task})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check that all unique tasks are in the tasks metadata
    stored_tasks = set(loaded_dataset.meta.tasks.index)
    assert stored_tasks == set(unique_tasks), f"Stored tasks {stored_tasks} != expected {set(unique_tasks)}"

    # Check that task indices are consistent
    cumulative = 0
    for episode_idx, expected_task in enumerate(tasks):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
        assert episode_metadata["tasks"] == [expected_task]

        # Check frames in this episode have correct task
        for i in range(frames_per_episode[episode_idx]):
            frame = loaded_dataset[cumulative + i]
            assert frame["task"] == expected_task, f"Frame {cumulative + i} has wrong task"

            # Check task_index consistency
            expected_task_index = loaded_dataset.meta.get_task_index(expected_task)
            assert frame["task_index"].item() == expected_task_index

        cumulative += frames_per_episode[episode_idx]

    # Check total number of tasks
    assert loaded_dataset.meta.total_tasks == len(unique_tasks)
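
The .index access on loaded_dataset.meta.tasks suggests the task table is keyed by task string, with get_task_index returning the corresponding integer id. One more assertion that could be bolted on (a sketch under that assumption):

# Hypothetical extra check: the task_index ids stored on frames cover exactly
# the ids of the unique tasks, with no stray ids.
stored_ids = {loaded_dataset[i]["task_index"].item() for i in range(len(loaded_dataset))}
expected_ids = {loaded_dataset.meta.get_task_index(t) for t in unique_tasks}
assert stored_ids == expected_ids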
