Commit 5ae2833

refactor(data): Minor refactorings to address PR comments.
1 parent: 1dfcd08

File tree: 2 files changed (+20 lines, −15 lines)


src/modalities/dataloader/create_packed_data.py

Lines changed: 8 additions & 1 deletion
@@ -327,7 +327,14 @@ def _process_line(self, line: str, process_id: int) -> bytes:
 def update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]):
     # Update the length of the data section in the pre-allocated header of the destination file.
     # The data segment length is sum of the starting position and the length of the last document.
-    length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] if len(index_list) > 0 else 0
+    if len(index_list) > 0:
+        length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1]
+    else:
+        length_of_byte_encoded_data_section = 0
+        logger.warning(
+            f'No data was written to the file "{dst_path}". '
+            "This can happen if the input file is empty or all samples were filtered out."
+        )
     data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes(
         EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
     )
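The refactoring above replaces the one-line conditional with an explicit if/else so that an empty index_list additionally logs a warning. For context, here is a minimal sketch of writing such a fixed-width, little-endian length field into a pre-allocated header; the offset and width constants (HEADER_FIELD_OFFSET, HEADER_FIELD_WIDTH) are hypothetical stand-ins for the project's header layout and EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, so this is not the actual implementation.

import logging
from pathlib import Path

logger = logging.getLogger(__name__)

HEADER_FIELD_OFFSET = 0  # hypothetical byte offset of the length field in the header
HEADER_FIELD_WIDTH = 8   # hypothetical width, standing in for DATA_SECTION_LENGTH_IN_BYTES


def write_data_section_length(dst_path: Path, index_list: list[tuple[int, int]]) -> None:
    # The data section ends where the last document ends: start offset + length.
    if len(index_list) > 0:
        length = index_list[-1][0] + index_list[-1][1]
    else:
        length = 0
        logger.warning(f'No data was written to the file "{dst_path}".')
    # Overwrite the pre-allocated header slot in place with the little-endian length.
    with dst_path.open("r+b") as f:
        f.seek(HEADER_FIELD_OFFSET)
        f.write(length.to_bytes(HEADER_FIELD_WIDTH, byteorder="little"))


def read_data_section_length(dst_path: Path) -> int:
    # Round-trip check: read the length field back from the header.
    with dst_path.open("rb") as f:
        f.seek(HEADER_FIELD_OFFSET)
        return int.from_bytes(f.read(HEADER_FIELD_WIDTH), byteorder="little")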

tests/dataloader/test_filter_packed_data.py

Lines changed: 12 additions & 14 deletions
@@ -9,20 +9,20 @@
 from modalities.dataloader.filter_packed_data import filter_dataset


-def test_creates_output_file(tmp_path: Path, packed_data_paths: Path):
+def test_creates_output_file(tmp_path: Path, packed_data_path: Path):
     output_path = Path(tmp_path, "output.pbin")
     filter_dataset(
-        src_path=packed_data_paths, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
+        src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
     )
     assert output_path.exists()


-def test_filtered_data_has_expected_length(tmp_path: Path, packed_data_paths: Path):
+def test_filtered_data_has_expected_length(tmp_path: Path, packed_data_path: Path):
     output_path = Path(tmp_path, "output.pbin")
     filter_dataset(
-        src_path=packed_data_paths, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
+        src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
     )
-    original_data = PackedMemMapDatasetBase(packed_data_paths, sample_key="input_ids")
+    original_data = PackedMemMapDatasetBase(packed_data_path, sample_key="input_ids")
     filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
     assert (
         len(filtered_data) == len(original_data) // 2 + len(original_data) % 2
@@ -39,22 +39,20 @@ def test_filtered_data_has_expected_content(tmp_path: Path, dummy_packed_data_pa
     assert filtered_data[1]["input_ids"].tolist() == list(range(64 // 4, (64 + 12) // 4))


-def test_always_true_filtered_data_has_identical_file_hash(tmp_path: Path, packed_data_paths: Path):
+def test_always_true_filtered_data_has_identical_file_hash(tmp_path: Path, packed_data_path: Path):
     output_path = Path(tmp_path, "output.pbin")
-    filter_dataset(src_path=packed_data_paths, dst_path=output_path, filter_func=lambda x: True, sample_key="input_ids")
-    with open(packed_data_paths, "rb") as f_in, open(output_path, "rb") as f_out:
+    filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: True, sample_key="input_ids")
+    with open(packed_data_path, "rb") as f_in, open(output_path, "rb") as f_out:
         original_hash = hashlib.sha256(f_in.read()).hexdigest()
         filtered_hash = hashlib.sha256(f_out.read()).hexdigest()
     assert (
         original_hash == filtered_hash
     ), "Filtered data should have the same hash as the original data when no filtering is applied."


-def test_always_false_filtered_data_produces_valid_file(tmp_path: Path, packed_data_paths: Path):
+def test_always_false_filtered_data_produces_valid_file(tmp_path: Path, packed_data_path: Path):
     output_path = Path(tmp_path, "output.pbin")
-    filter_dataset(
-        src_path=packed_data_paths, dst_path=output_path, filter_func=lambda x: False, sample_key="input_ids"
-    )
+    filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: False, sample_key="input_ids")
     filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
     assert len(filtered_data) == 0, "Filtered data should be empty when all samples are filtered out."
     assert output_path.stat().st_size > 0, "Output file should not be empty even if no samples are included."
@@ -74,6 +72,6 @@ def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) ->


 @pytest.fixture(params=[0, 1])
-def packed_data_paths(dummy_packed_data_path: Path, request: pytest.FixtureRequest) -> Path:
-    path_options = [dummy_packed_data_path, Path("tests", "data", "datasets", "lorem_ipsum_long.pbin")]
+def packed_data_path(dummy_packed_data_path: Path, request: pytest.FixtureRequest) -> Path:
+    path_options = [dummy_packed_data_path, Path("tests/data/datasets/lorem_ipsum_long.pbin")]
     return path_options[request.param]
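The last hunk header above shows the signature of accept_even_indices; its body is not touched by this commit. A plausible sketch, assuming the filter receives an (index, sample) tuple and keeps samples at even positions, which is consistent with the expected length len(original_data) // 2 + len(original_data) % 2 asserted in the tests:

import numpy as np
from numpy.typing import NDArray


def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) -> bool:
    # Keep every sample whose dataset position is even (0, 2, 4, ...).
    idx, _content = idx_content
    return idx % 2 == 0

Passed as filter_func to filter_dataset, such a predicate keeps indices 0, 2, 4, ..., so the filtered dataset ends up with ceil(len/2) samples.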
