
Commit 1dfcd08 (parent: ef71346)

refactor(data): Some refactorings and additional tests for packed data filtering.

3 files changed (+46, -13 lines)

src/modalities/dataloader/create_packed_data.py

Lines changed: 3 additions & 3 deletions
@@ -258,7 +258,7 @@ def _write_batch(
         # write index
         f.write(pickle.dumps(index_list))

-        _update_data_length_in_pre_allocated_header(dst_path, index_list)
+        update_data_length_in_pre_allocated_header(dst_path, index_list)

     return writer

@@ -324,10 +324,10 @@ def _process_line(self, line: str, process_id: int) -> bytes:
         return token_byte_string


-def _update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]):
+def update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]):
     # Update the length of the data section in the pre-allocated header of the destination file.
     # The data segment length is sum of the starting position and the length of the last document.
-    length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1]
+    length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] if len(index_list) > 0 else 0
     data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes(
         EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
     )
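Worth noting on the new guard: `index_list[-1]` raises an IndexError on an empty index, while a zero-length data section still encodes cleanly. A minimal sketch of the header arithmetic, assuming an 8-byte field width and a hypothetical helper name in place of the library's internals:

# Minimal sketch; the 8-byte width and the helper name are illustrative
# stand-ins for EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES.
DATA_SECTION_LENGTH_IN_BYTES = 8

def encode_data_section_length(index_list: list[tuple[int, int]]) -> bytes:
    # Each index entry is (start_offset, length_in_bytes); the data section
    # ends where the last document ends. An empty index encodes as zero.
    length = index_list[-1][0] + index_list[-1][1] if len(index_list) > 0 else 0
    return length.to_bytes(DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")

assert encode_data_section_length([]) == bytes(8)
assert encode_data_section_length([(0, 10), (10, 5)]) == (15).to_bytes(8, "little")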

src/modalities/dataloader/filter_packed_data.py

Lines changed: 5 additions & 4 deletions
@@ -6,13 +6,13 @@
 from numpy.typing import NDArray
 from tqdm import tqdm

-from modalities.dataloader.create_packed_data import EmbeddedStreamData, _update_data_length_in_pre_allocated_header
+from modalities.dataloader.create_packed_data import EmbeddedStreamData, update_data_length_in_pre_allocated_header
 from modalities.dataloader.dataset import PackedMemMapDatasetBase


 def filter_dataset(
-    dst_path: Path,
     src_path: Path,
+    dst_path: Path,
     filter_func: Callable[[tuple[int, dict[str, NDArray[np.int_]]]], bool],
     sample_key: str = "input_ids",
 ) -> None:
@@ -41,6 +41,7 @@ def filter_dataset(
     # When we load the file, we add the header size to the offset
     curr_offset = 0

+    # Provide sample and its index (via enumerate) to the filter function.
     for _, entry in filter(filter_func, enumerate(tqdm(source_data, desc="Filtering samples"))):
         tokens: NDArray[np.int_] = entry[sample_key].astype(tok_type)
         tokens = tokens.astype(tokens.dtype.newbyteorder("<"))
@@ -49,7 +50,7 @@ def filter_dataset(
         segment_length = len(tokens_as_bytes)
         index_list.append((curr_offset, segment_length))
         curr_offset += segment_length
-    # write index
+    # Write index at end of the file.
     f_out.write(pickle.dumps(index_list))

-    _update_data_length_in_pre_allocated_header(dst_path, index_list)
+    update_data_length_in_pre_allocated_header(dst_path, index_list)
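With the argument reorder, callers now pass the source before the destination, matching the data flow. A usage sketch, with illustrative paths and a hypothetical length-based filter function:

from pathlib import Path

from modalities.dataloader.filter_packed_data import filter_dataset

def accept_long_documents(idx_content) -> bool:
    # The filter receives (index, sample_dict), matching the enumerate() output.
    _, entry = idx_content
    return len(entry["input_ids"]) >= 16

filter_dataset(
    src_path=Path("data/train.pbin"),       # illustrative input path
    dst_path=Path("data/train_long.pbin"),  # illustrative output path
    filter_func=accept_long_documents,
    sample_key="input_ids",
)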
Lines changed: 38 additions & 6 deletions
@@ -1,39 +1,65 @@
+import hashlib
 from pathlib import Path

 import numpy as np
+import pytest
 from numpy.typing import NDArray

 from modalities.dataloader.dataset import PackedMemMapDatasetBase
 from modalities.dataloader.filter_packed_data import filter_dataset


-def test_creates_output_file(tmp_path: Path, dummy_packed_data_path: Path):
+def test_creates_output_file(tmp_path: Path, packed_data_paths: Path):
     output_path = Path(tmp_path, "output.pbin")
     filter_dataset(
-        dst_path=output_path, src_path=dummy_packed_data_path, filter_func=accept_even_indices, sample_key="input_ids"
+        src_path=packed_data_paths, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
     )
     assert output_path.exists()


-def test_filtered_data_has_expected_length(tmp_path: Path, dummy_packed_data_path: Path):
+def test_filtered_data_has_expected_length(tmp_path: Path, packed_data_paths: Path):
     output_path = Path(tmp_path, "output.pbin")
     filter_dataset(
-        dst_path=output_path, src_path=dummy_packed_data_path, filter_func=accept_even_indices, sample_key="input_ids"
+        src_path=packed_data_paths, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
     )
+    original_data = PackedMemMapDatasetBase(packed_data_paths, sample_key="input_ids")
     filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
-    assert len(filtered_data) == 2
+    assert (
+        len(filtered_data) == len(original_data) // 2 + len(original_data) % 2
+    ), "Filtered data length should be half of the original data length (rounded up)."


 def test_filtered_data_has_expected_content(tmp_path: Path, dummy_packed_data_path: Path):
     output_path = Path(tmp_path, "output.pbin")
     filter_dataset(
-        dst_path=output_path, src_path=dummy_packed_data_path, filter_func=accept_even_indices, sample_key="input_ids"
+        src_path=dummy_packed_data_path, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
     )
     filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
     assert filtered_data[0]["input_ids"].tolist() == list(range(24 // 4))
     assert filtered_data[1]["input_ids"].tolist() == list(range(64 // 4, (64 + 12) // 4))


+def test_always_true_filtered_data_has_identical_file_hash(tmp_path: Path, packed_data_paths: Path):
+    output_path = Path(tmp_path, "output.pbin")
+    filter_dataset(src_path=packed_data_paths, dst_path=output_path, filter_func=lambda x: True, sample_key="input_ids")
+    with open(packed_data_paths, "rb") as f_in, open(output_path, "rb") as f_out:
+        original_hash = hashlib.sha256(f_in.read()).hexdigest()
+        filtered_hash = hashlib.sha256(f_out.read()).hexdigest()
+    assert (
+        original_hash == filtered_hash
+    ), "Filtered data should have the same hash as the original data when no filtering is applied."
+
+
+def test_always_false_filtered_data_produces_valid_file(tmp_path: Path, packed_data_paths: Path):
+    output_path = Path(tmp_path, "output.pbin")
+    filter_dataset(
+        src_path=packed_data_paths, dst_path=output_path, filter_func=lambda x: False, sample_key="input_ids"
+    )
+    filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
+    assert len(filtered_data) == 0, "Filtered data should be empty when all samples are filtered out."
+    assert output_path.stat().st_size > 0, "Output file should not be empty even if no samples are included."
+
+
 def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) -> bool:
     """
     Filter function that accepts only even indices.
@@ -45,3 +71,9 @@ def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) ->
     """
     idx, _ = idx_content
     return idx % 2 == 0
+
+
+@pytest.fixture(params=[0, 1])
+def packed_data_paths(dummy_packed_data_path: Path, request: pytest.FixtureRequest) -> Path:
+    path_options = [dummy_packed_data_path, Path("tests", "data", "datasets", "lorem_ipsum_long.pbin")]
+    return path_options[request.param]
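A quick check on the reworked length assertion above: accepting only even indices keeps ceil(n / 2) of n samples, which is exactly n // 2 + n % 2:

# Even indices 0, 2, 4, ... of n samples leave ceil(n / 2) survivors.
for n in range(10):
    kept = sum(1 for i in range(n) if i % 2 == 0)
    assert kept == n // 2 + n % 2 == -(-n // 2)  # three ways to write ceil(n / 2)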
