@@ -789,3 +789,281 @@ def test_update_chunk_settings_video_dataset(tmp_path):
     dataset.meta.update_chunk_settings(video_files_size_in_mb=new_video_size)
     assert dataset.meta.get_chunk_settings()["video_files_size_in_mb"] == new_video_size
     assert dataset.meta.video_files_size_in_mb == new_video_size
+
+
+def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
+    """Test that all frames have correct episode indices across multiple episodes."""
+    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+    # Create 3 episodes with different lengths
+    num_episodes = 3
+    frames_per_episode = [10, 15, 8]
+
+    for episode_idx in range(num_episodes):
+        for _ in range(frames_per_episode[episode_idx]):
+            dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
+        dataset.save_episode()
+
+    dataset.finalize()
+
+    # Load the dataset and check episode indices
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+
+    # Check specific frames across episode boundaries
+    cumulative = 0
+    for ep_idx, ep_length in enumerate(frames_per_episode):
+        # Check start, middle, and end of each episode
+        start_frame = cumulative
+        middle_frame = cumulative + ep_length // 2
+        end_frame = cumulative + ep_length - 1
+
+        for frame_idx in [start_frame, middle_frame, end_frame]:
+            frame_data = loaded_dataset[frame_idx]
+            actual_ep_idx = frame_data["episode_index"].item()
+            assert actual_ep_idx == ep_idx, (
+                f"Frame {frame_idx} has episode_index {actual_ep_idx}, should be {ep_idx}"
+            )
+
+        cumulative += ep_length
+
+    # Check episode index distribution
+    all_episode_indices = [loaded_dataset[i]["episode_index"].item() for i in range(len(loaded_dataset))]
+    from collections import Counter
+
+    distribution = Counter(all_episode_indices)
+    expected_dist = {i: frames_per_episode[i] for i in range(num_episodes)}
+
+    assert dict(distribution) == expected_dist, (
+        f"Episode distribution {dict(distribution)} != expected {expected_dist}"
+    )
+
+
+def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory):
+    """Test episode metadata consistency across multiple episodes."""
+    features = {
+        "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
+        "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
+    }
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+    num_episodes = 4
+    frames_per_episode = [20, 35, 10, 25]
+    tasks = ["pick", "place", "pick", "place"]
+
+    for episode_idx in range(num_episodes):
+        for _ in range(frames_per_episode[episode_idx]):
+            dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]})
+        dataset.save_episode()
+
+    dataset.finalize()
+
+    # Load and validate episode metadata
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+
+    assert loaded_dataset.meta.total_episodes == num_episodes
+    assert loaded_dataset.meta.total_frames == sum(frames_per_episode)
+
+    cumulative_frames = 0
+    for episode_idx in range(num_episodes):
+        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
+
+        # Check basic episode properties
+        assert episode_metadata["episode_index"] == episode_idx
+        assert episode_metadata["length"] == frames_per_episode[episode_idx]
+        assert episode_metadata["tasks"] == [tasks[episode_idx]]
+
+        # Check dataset indices
+        expected_from = cumulative_frames
+        expected_to = cumulative_frames + frames_per_episode[episode_idx]
+
+        assert episode_metadata["dataset_from_index"] == expected_from
+        assert episode_metadata["dataset_to_index"] == expected_to
+
+        cumulative_frames += frames_per_episode[episode_idx]
+
+
+def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory):
+    """Test that episodes have no gaps or overlaps in their data indices."""
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+    num_episodes = 5
+    frames_per_episode = [12, 8, 20, 15, 5]
+
+    for episode_idx in range(num_episodes):
+        for _ in range(frames_per_episode[episode_idx]):
+            dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
+        dataset.save_episode()
+
+    dataset.finalize()
+
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+
+    # Check data consistency - no gaps or overlaps
+    cumulative_check = 0
+    for episode_idx in range(num_episodes):
+        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
+        from_idx = episode_metadata["dataset_from_index"]
+        to_idx = episode_metadata["dataset_to_index"]
+
+        # Check that the episode starts exactly where the previous one ended
+        assert from_idx == cumulative_check, (
+            f"Episode {episode_idx} starts at {from_idx}, expected {cumulative_check}"
+        )
+
+        # Check that the episode length matches the expected length
+        actual_length = to_idx - from_idx
+        expected_length = frames_per_episode[episode_idx]
+        assert actual_length == expected_length, (
+            f"Episode {episode_idx} length {actual_length} != expected {expected_length}"
+        )
+
+        cumulative_check = to_idx
+
+    # Final check: the last episode should end at the total frame count
+    expected_total_frames = sum(frames_per_episode)
+    assert cumulative_check == expected_total_frames, (
+        f"Final frame count {cumulative_check} != expected {expected_total_frames}"
+    )
+
+
+def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory):
+    """Test that statistics are properly computed and stored for all features."""
+    features = {
+        "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
+        "action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
+    }
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+    # Create controlled data to verify statistics
+    num_episodes = 2
+    frames_per_episode = [10, 10]
+
+    # Use deterministic data for predictable statistics
+    torch.manual_seed(42)
+    for episode_idx in range(num_episodes):
+        for frame_idx in range(frames_per_episode[episode_idx]):
+            state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32)
+            action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32)
+            dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"})
+        dataset.save_episode()
+
+    dataset.finalize()
+
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+
+    # Check that statistics exist for all features
+    assert loaded_dataset.meta.stats is not None, "No statistics found"
+
+    for feature_name in features.keys():
+        assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
+
+        feature_stats = loaded_dataset.meta.stats[feature_name]
+        expected_stats = ["min", "max", "mean", "std", "count"]
+
+        for stat_key in expected_stats:
+            assert stat_key in feature_stats, f"Missing '{stat_key}' statistic for '{feature_name}'"
+
+            stat_value = feature_stats[stat_key]
+            # Basic sanity checks
+            if stat_key == "count":
+                assert stat_value == sum(frames_per_episode), f"Wrong count for '{feature_name}'"
+            elif stat_key in ["min", "max", "mean", "std"]:
+                # Check that statistics are reasonable (not NaN, proper shapes)
+                if hasattr(stat_value, "shape"):
+                    expected_shape = features[feature_name]["shape"]
+                    assert stat_value.shape == expected_shape or len(stat_value) == expected_shape[0], (
+                        f"Wrong shape for {stat_key} of '{feature_name}'"
+                    )
+                # Check no NaN values
+                if hasattr(stat_value, "__iter__"):
+                    assert not any(np.isnan(v) for v in stat_value), f"NaN in {stat_key} for '{feature_name}'"
+                else:
+                    assert not np.isnan(stat_value), f"NaN in {stat_key} for '{feature_name}'"
+
+
+def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
+    """Test frame indices and episode transitions at episode boundaries."""
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+    num_episodes = 3
+    frames_per_episode = [7, 12, 5]
+
+    for episode_idx in range(num_episodes):
+        for frame_idx in range(frames_per_episode[episode_idx]):
+            dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
+        dataset.save_episode()
+
+    dataset.finalize()
+
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+
+    # Test episode boundaries
+    cumulative = 0
+    for ep_idx, ep_length in enumerate(frames_per_episode):
+        if ep_idx > 0:
+            # Check last frame of previous episode
+            prev_frame = loaded_dataset[cumulative - 1]
+            assert prev_frame["episode_index"].item() == ep_idx - 1
+
+        # Check first frame of current episode
+        if cumulative < len(loaded_dataset):
+            curr_frame = loaded_dataset[cumulative]
+            assert curr_frame["episode_index"].item() == ep_idx
+
+        # Check frame_index within episode
+        for i in range(ep_length):
+            if cumulative + i < len(loaded_dataset):
+                frame = loaded_dataset[cumulative + i]
+                assert frame["frame_index"].item() == i, f"Frame {cumulative + i} has wrong frame_index"
+                assert frame["episode_index"].item() == ep_idx, (
+                    f"Frame {cumulative + i} has wrong episode_index"
+                )
+
+        cumulative += ep_length
+
+
+def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
+    """Test that tasks are properly indexed and retrievable."""
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+
+    # Use multiple tasks, including repeated ones
+    tasks = ["pick", "place", "pick", "navigate", "place"]
+    unique_tasks = list(set(tasks))  # "pick", "place", "navigate" (order not guaranteed)
+    frames_per_episode = [5, 8, 3, 10, 6]
+
+    for episode_idx, task in enumerate(tasks):
+        for _ in range(frames_per_episode[episode_idx]):
+            dataset.add_frame({"state": torch.randn(1), "task": task})
+        dataset.save_episode()
+
+    dataset.finalize()
+
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+
+    # Check that all unique tasks are in the tasks metadata
+    stored_tasks = set(loaded_dataset.meta.tasks.index)
+    assert stored_tasks == set(unique_tasks), f"Stored tasks {stored_tasks} != expected {set(unique_tasks)}"
+
+    # Check that task indices are consistent
+    cumulative = 0
+    for episode_idx, expected_task in enumerate(tasks):
+        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
+        assert episode_metadata["tasks"] == [expected_task]
+
+        # Check that frames in this episode have the correct task
+        for i in range(frames_per_episode[episode_idx]):
+            frame = loaded_dataset[cumulative + i]
+            assert frame["task"] == expected_task, f"Frame {cumulative + i} has wrong task"
+
+            # Check task_index consistency
+            expected_task_index = loaded_dataset.meta.get_task_index(expected_task)
+            assert frame["task_index"].item() == expected_task_index
+
+        cumulative += frames_per_episode[episode_idx]
+
+    # Check total number of tasks
+    assert loaded_dataset.meta.total_tasks == len(unique_tasks)