9
9
from modalities .dataloader .filter_packed_data import filter_dataset
10
10
11
11
12
- def test_creates_output_file (tmp_path : Path , packed_data_paths : Path ):
12
+ def test_creates_output_file (tmp_path : Path , packed_data_path : Path ):
13
13
output_path = Path (tmp_path , "output.pbin" )
14
14
filter_dataset (
15
- src_path = packed_data_paths , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
15
+ src_path = packed_data_path , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
16
16
)
17
17
assert output_path .exists ()
18
18
19
19
20
- def test_filtered_data_has_expected_length (tmp_path : Path , packed_data_paths : Path ):
20
+ def test_filtered_data_has_expected_length (tmp_path : Path , packed_data_path : Path ):
21
21
output_path = Path (tmp_path , "output.pbin" )
22
22
filter_dataset (
23
- src_path = packed_data_paths , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
23
+ src_path = packed_data_path , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
24
24
)
25
- original_data = PackedMemMapDatasetBase (packed_data_paths , sample_key = "input_ids" )
25
+ original_data = PackedMemMapDatasetBase (packed_data_path , sample_key = "input_ids" )
26
26
filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
27
27
assert (
28
28
len (filtered_data ) == len (original_data ) // 2 + len (original_data ) % 2
@@ -39,22 +39,20 @@ def test_filtered_data_has_expected_content(tmp_path: Path, dummy_packed_data_pa
39
39
assert filtered_data [1 ]["input_ids" ].tolist () == list (range (64 // 4 , (64 + 12 ) // 4 ))
40
40
41
41
42
- def test_always_true_filtered_data_has_identical_file_hash (tmp_path : Path , packed_data_paths : Path ):
42
+ def test_always_true_filtered_data_has_identical_file_hash (tmp_path : Path , packed_data_path : Path ):
43
43
output_path = Path (tmp_path , "output.pbin" )
44
- filter_dataset (src_path = packed_data_paths , dst_path = output_path , filter_func = lambda x : True , sample_key = "input_ids" )
45
- with open (packed_data_paths , "rb" ) as f_in , open (output_path , "rb" ) as f_out :
44
+ filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = lambda x : True , sample_key = "input_ids" )
45
+ with open (packed_data_path , "rb" ) as f_in , open (output_path , "rb" ) as f_out :
46
46
original_hash = hashlib .sha256 (f_in .read ()).hexdigest ()
47
47
filtered_hash = hashlib .sha256 (f_out .read ()).hexdigest ()
48
48
assert (
49
49
original_hash == filtered_hash
50
50
), "Filtered data should have the same hash as the original data when no filtering is applied."
51
51
52
52
53
- def test_always_false_filtered_data_produces_valid_file (tmp_path : Path , packed_data_paths : Path ):
53
+ def test_always_false_filtered_data_produces_valid_file (tmp_path : Path , packed_data_path : Path ):
54
54
output_path = Path (tmp_path , "output.pbin" )
55
- filter_dataset (
56
- src_path = packed_data_paths , dst_path = output_path , filter_func = lambda x : False , sample_key = "input_ids"
57
- )
55
+ filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = lambda x : False , sample_key = "input_ids" )
58
56
filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
59
57
assert len (filtered_data ) == 0 , "Filtered data should be empty when all samples are filtered out."
60
58
assert output_path .stat ().st_size > 0 , "Output file should not be empty even if no samples are included."
@@ -74,6 +72,6 @@ def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) ->
74
72
75
73
76
74
@pytest .fixture (params = [0 , 1 ])
77
- def packed_data_paths (dummy_packed_data_path : Path , request : pytest .FixtureRequest ) -> Path :
78
- path_options = [dummy_packed_data_path , Path ("tests" , " data" , " datasets" , " lorem_ipsum_long.pbin" )]
75
+ def packed_data_path (dummy_packed_data_path : Path , request : pytest .FixtureRequest ) -> Path :
76
+ path_options = [dummy_packed_data_path , Path ("tests/ data/ datasets/ lorem_ipsum_long.pbin" )]
79
77
return path_options [request .param ]
0 commit comments