1
+ import hashlib
1
2
from pathlib import Path
2
3
3
4
import numpy as np
5
+ import pytest
4
6
from numpy .typing import NDArray
5
7
6
8
from modalities .dataloader .dataset import PackedMemMapDatasetBase
7
9
from modalities .dataloader .filter_packed_data import filter_dataset
8
10
9
11
10
- def test_creates_output_file (tmp_path : Path , dummy_packed_data_path : Path ):
12
+ def test_creates_output_file (tmp_path : Path , packed_data_paths : Path ):
11
13
output_path = Path (tmp_path , "output.pbin" )
12
14
filter_dataset (
13
- dst_path = output_path , src_path = dummy_packed_data_path , filter_func = accept_even_indices , sample_key = "input_ids"
15
+ src_path = packed_data_paths , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
14
16
)
15
17
assert output_path .exists ()
16
18
17
19
18
- def test_filtered_data_has_expected_length (tmp_path : Path , dummy_packed_data_path : Path ):
20
+ def test_filtered_data_has_expected_length (tmp_path : Path , packed_data_paths : Path ):
19
21
output_path = Path (tmp_path , "output.pbin" )
20
22
filter_dataset (
21
- dst_path = output_path , src_path = dummy_packed_data_path , filter_func = accept_even_indices , sample_key = "input_ids"
23
+ src_path = packed_data_paths , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
22
24
)
25
+ original_data = PackedMemMapDatasetBase (packed_data_paths , sample_key = "input_ids" )
23
26
filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
24
- assert len (filtered_data ) == 2
27
+ assert (
28
+ len (filtered_data ) == len (original_data ) // 2 + len (original_data ) % 2
29
+ ), "Filtered data length should be half of the original data length (rounded up)."
25
30
26
31
27
32
def test_filtered_data_has_expected_content (tmp_path : Path , dummy_packed_data_path : Path ):
28
33
output_path = Path (tmp_path , "output.pbin" )
29
34
filter_dataset (
30
- dst_path = output_path , src_path = dummy_packed_data_path , filter_func = accept_even_indices , sample_key = "input_ids"
35
+ src_path = dummy_packed_data_path , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
31
36
)
32
37
filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
33
38
assert filtered_data [0 ]["input_ids" ].tolist () == list (range (24 // 4 ))
34
39
assert filtered_data [1 ]["input_ids" ].tolist () == list (range (64 // 4 , (64 + 12 ) // 4 ))
35
40
36
41
42
+ def test_always_true_filtered_data_has_identical_file_hash (tmp_path : Path , packed_data_paths : Path ):
43
+ output_path = Path (tmp_path , "output.pbin" )
44
+ filter_dataset (src_path = packed_data_paths , dst_path = output_path , filter_func = lambda x : True , sample_key = "input_ids" )
45
+ with open (packed_data_paths , "rb" ) as f_in , open (output_path , "rb" ) as f_out :
46
+ original_hash = hashlib .sha256 (f_in .read ()).hexdigest ()
47
+ filtered_hash = hashlib .sha256 (f_out .read ()).hexdigest ()
48
+ assert (
49
+ original_hash == filtered_hash
50
+ ), "Filtered data should have the same hash as the original data when no filtering is applied."
51
+
52
+
53
+ def test_always_false_filtered_data_produces_valid_file (tmp_path : Path , packed_data_paths : Path ):
54
+ output_path = Path (tmp_path , "output.pbin" )
55
+ filter_dataset (
56
+ src_path = packed_data_paths , dst_path = output_path , filter_func = lambda x : False , sample_key = "input_ids"
57
+ )
58
+ filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
59
+ assert len (filtered_data ) == 0 , "Filtered data should be empty when all samples are filtered out."
60
+ assert output_path .stat ().st_size > 0 , "Output file should not be empty even if no samples are included."
61
+
62
+
37
63
def accept_even_indices (idx_content : tuple [int , dict [str , NDArray [np .int_ ]]]) -> bool :
38
64
"""
39
65
Filter function that accepts only even indices.
@@ -45,3 +71,9 @@ def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) ->
45
71
"""
46
72
idx , _ = idx_content
47
73
return idx % 2 == 0
74
+
75
+
76
+ @pytest .fixture (params = [0 , 1 ])
77
+ def packed_data_paths (dummy_packed_data_path : Path , request : pytest .FixtureRequest ) -> Path :
78
+ path_options = [dummy_packed_data_path , Path ("tests" , "data" , "datasets" , "lorem_ipsum_long.pbin" )]
79
+ return path_options [request .param ]
0 commit comments