feat: use content defined chunking #7589

Draft: wants to merge 13 commits into base `main`
1 change: 1 addition & 0 deletions .github/workflows/build_documentation.yml
@@ -15,6 +15,7 @@ jobs:
commit_sha: ${{ github.sha }}
package: datasets
notebook_folder: datasets_doc
additional_args: --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
23 changes: 15 additions & 8 deletions .github/workflows/ci.yml
@@ -5,9 +5,9 @@ on:
branches:
- main
push:
branches:
- main
- ci-*
# branches:
# - main
# - ci-*

env:
CI_HEADERS: ${{ secrets.CI_HEADERS }}
@@ -25,6 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple "pyarrow>=21.0.0.dev"
pip install .[quality]
- name: Check quality
run: |
@@ -56,13 +57,15 @@ jobs:
- name: Install uv
run: pip install --upgrade uv
- name: Install dependencies
run: uv pip install --system "datasets[tests] @ ."
run: |
uv pip install --system --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple "pyarrow>=21.0.0.dev"
uv pip install --system "datasets[tests] @ ."
- name: Install dependencies (latest versions)
if: ${{ matrix.deps_versions == 'deps-latest' }}
run: uv pip install --system --upgrade pyarrow huggingface-hub "dill<0.3.9"
run: uv pip install --system --upgrade huggingface-hub "dill<0.3.9"
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'deps-latest' }}
run: uv pip install --system pyarrow==15.0.0 huggingface-hub==0.24.7 transformers dill==0.3.1.1
run: uv pip install --system huggingface-hub==0.24.7 transformers dill==0.3.1.1
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
@@ -89,7 +92,9 @@ jobs:
- name: Install uv
run: pip install --upgrade uv
- name: Install dependencies
run: uv pip install --system "datasets[tests] @ ."
run: |
uv pip install --system --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple "pyarrow>=21.0.0.dev"
uv pip install --system "datasets[tests] @ ."
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
@@ -116,7 +121,9 @@ jobs:
- name: Install uv
run: pip install --upgrade uv
- name: Install dependencies
run: uv pip install --system "datasets[tests_numpy2] @ ."
run: |
uv pip install --system --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple "pyarrow>=21.0.0.dev"
uv pip install --system "datasets[tests_numpy2] @ ."
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
2 changes: 1 addition & 1 deletion setup.py
@@ -111,7 +111,7 @@
"numpy>=1.17",
# Backend and serialization.
# Minimum 15.0.0 to be able to cast dictionary types to their underlying types
"pyarrow>=15.0.0",
"pyarrow>=21.0.0.dev",
# For smart caching dataset processing
"dill>=0.3.0,<0.3.9", # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
# For performance gains with apache arrow
35 changes: 26 additions & 9 deletions src/datasets/arrow_writer.py
@@ -341,8 +341,6 @@ def __init__(
class ArrowWriter:
"""Shuffles and writes Examples to Arrow files."""

_WRITER_CLASS = pa.RecordBatchStreamWriter

def __init__(
self,
schema: Optional[pa.Schema] = None,
@@ -430,8 +428,9 @@ def close(self):
if self._closable_stream and not self.stream.closed:
self.stream.close() # This also closes self.pa_writer if it is opened

def _build_writer(self, inferred_schema: pa.Schema):
def _build_schema(self, inferred_schema: pa.Schema):
schema = self.schema
features = self._features
inferred_features = Features.from_arrow_schema(inferred_schema)
if self._features is not None:
if self.update_features: # keep original features it they match, or update them
@@ -441,19 +440,24 @@ def _build_writer(self, inferred_schema: pa.Schema):
if name in fields:
if inferred_field == fields[name]:
inferred_features[name] = self._features[name]
self._features = inferred_features
features = inferred_features
schema: pa.Schema = inferred_schema
else:
self._features = inferred_features
features = inferred_features
schema: pa.Schema = inferred_features.arrow_schema

if self.disable_nullable:
schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in schema)
if self.with_metadata:
schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint))
schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=features), self.fingerprint))
else:
schema = schema.with_metadata({})
self._schema = schema
self.pa_writer = self._WRITER_CLASS(self.stream, schema)

return schema, features

def _build_writer(self, inferred_schema: pa.Schema):
self._schema, self._features = self._build_schema(inferred_schema)
self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)

@property
def schema(self):
@@ -674,4 +678,17 @@ def finalize(self, close_stream=True):


class ParquetWriter(ArrowWriter):
_WRITER_CLASS = pq.ParquetWriter
def __init__(self, *args, use_content_defined_chunking=None, **kwargs):
super().__init__(*args, **kwargs)
self.use_content_defined_chunking = (
config.DEFAULT_CDC_OPTIONS if use_content_defined_chunking is None else use_content_defined_chunking
)

def _build_writer(self, inferred_schema: pa.Schema):
self._schema, self._features = self._build_schema(inferred_schema)
self.pa_writer = pq.ParquetWriter(
self.stream, self._schema, use_content_defined_chunking=self.use_content_defined_chunking
)
self.pa_writer.add_key_value_metadata(
{"content_defined_chunking": json.dumps(self.use_content_defined_chunking)}
)
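
For context, a minimal usage sketch of the `ParquetWriter` option introduced above (not part of the diff); it mirrors the test added in `tests/test_arrow_writer.py` further down and assumes a nightly pyarrow (>= 21.0.0.dev) whose `pq.ParquetWriter` accepts `use_content_defined_chunking`:

```python
# Sketch only: exercising the new ParquetWriter option from this PR.
# Assumes a nightly pyarrow (>= 21.0.0.dev) build.
import pyarrow as pa

from datasets.arrow_writer import ParquetWriter

output = pa.BufferOutputStream()

# Passing None (the default) falls back to config.DEFAULT_CDC_OPTIONS;
# a dict overrides the options for this writer only.
cdc_options = {"min_chunk_size": 128 * 1024, "max_chunk_size": 512 * 1024, "norm_level": 1}

with ParquetWriter(stream=output, use_content_defined_chunking=cdc_options) as writer:
    writer.write({"col_1": "foo", "col_2": 1})
    writer.write({"col_1": "bar", "col_2": 2})
    writer.finalize()
```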
4 changes: 3 additions & 1 deletion src/datasets/config.py
@@ -182,7 +182,9 @@

# Batch size constants. For more info, see:
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
DEFAULT_MAX_BATCH_SIZE = 1000
DEFAULT_MAX_BATCH_SIZE = 1024 * 1024

DEFAULT_CDC_OPTIONS = {"min_chunk_size": 256 * 1024, "max_chunk_size": 1024 * 1024, "norm_level": 0}

# Size of the preloaded record batch in `Dataset.__iter__`
ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10
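
As a reading aid (not part of the diff), the new defaults next to the custom variant used by the tests below; the comments describe pyarrow's content-defined chunking options informally and are an interpretation, not the diff's wording:

```python
# Sketch only: min_chunk_size / max_chunk_size bound the variable-sized chunks
# produced by content-defined chunking; norm_level pulls chunk sizes closer to
# the average (informal description, see the pyarrow docs for the exact semantics).
DEFAULT_CDC_OPTIONS = {
    "min_chunk_size": 256 * 1024,   # 256 KiB
    "max_chunk_size": 1024 * 1024,  # 1 MiB
    "norm_level": 0,
}

# Per-writer override used in the new tests.
custom_cdc_options = {
    "min_chunk_size": 128 * 1024,  # 128 KiB
    "max_chunk_size": 512 * 1024,  # 512 KiB
    "norm_level": 1,
}
```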
31 changes: 27 additions & 4 deletions src/datasets/io/parquet.py
@@ -1,3 +1,4 @@
import json
import os
from typing import BinaryIO, Optional, Union

@@ -77,25 +78,42 @@ def __init__(
path_or_buf: Union[PathLike, BinaryIO],
batch_size: Optional[int] = None,
storage_options: Optional[dict] = None,
use_content_defined_chunking: Optional[dict] = None,
**parquet_writer_kwargs,
):
self.dataset = dataset
self.path_or_buf = path_or_buf
self.batch_size = batch_size or get_writer_batch_size(dataset.features)
self.storage_options = storage_options or {}
self.parquet_writer_kwargs = parquet_writer_kwargs
self.use_content_defined_chunking = use_content_defined_chunking

def write(self) -> int:
batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
use_content_defined_chunking = (
self.use_content_defined_chunking if self.use_content_defined_chunking else config.DEFAULT_CDC_OPTIONS
)

if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with fsspec.open(self.path_or_buf, "wb", **(self.storage_options or {})) as buffer:
written = self._write(file_obj=buffer, batch_size=batch_size, **self.parquet_writer_kwargs)
written = self._write(
file_obj=buffer,
batch_size=batch_size,
use_content_defined_chunking=use_content_defined_chunking,
**self.parquet_writer_kwargs,
)
else:
written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.parquet_writer_kwargs)
written = self._write(
file_obj=self.path_or_buf,
batch_size=batch_size,
use_content_defined_chunking=use_content_defined_chunking,
**self.parquet_writer_kwargs,
)
return written

def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -> int:
def _write(
self, file_obj: BinaryIO, batch_size: int, use_content_defined_chunking: bool | dict, **parquet_writer_kwargs
) -> int:
"""Writes the pyarrow table as Parquet to a binary file handle.

Caller is responsible for opening and closing the handle.
@@ -104,7 +122,9 @@ def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -
_ = parquet_writer_kwargs.pop("path_or_buf", None)
schema = self.dataset.features.arrow_schema

writer = pq.ParquetWriter(file_obj, schema=schema, **parquet_writer_kwargs)
writer = pq.ParquetWriter(
file_obj, schema=schema, use_content_defined_chunking=use_content_defined_chunking, **parquet_writer_kwargs
)

for offset in hf_tqdm(
range(0, len(self.dataset), batch_size),
@@ -118,5 +138,8 @@ def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -
)
writer.write_table(batch)
written += batch.nbytes

# TODO(kszucs): we may want to persist multiple parameters
writer.add_key_value_metadata({"content_defined_chunking": json.dumps(use_content_defined_chunking)})
writer.close()
return written
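
A short end-to-end sketch of the writer path above (not part of the diff), mirroring the tests that follow; the dataset and file name are illustrative:

```python
# Sketch only: write a dataset with custom CDC options and read back the
# key-value metadata persisted by ParquetDatasetWriter.
import json

import pyarrow.parquet as pq

from datasets import Dataset
from datasets.io.parquet import ParquetDatasetWriter

# A tiny in-memory dataset; any datasets.Dataset works the same way.
ds = Dataset.from_dict({"col_1": ["foo", "bar"], "col_2": [1, 2]})

cdc_options = {"min_chunk_size": 128 * 1024, "max_chunk_size": 512 * 1024, "norm_level": 1}

writer = ParquetDatasetWriter(ds, "data.parquet", use_content_defined_chunking=cdc_options)
assert writer.write() > 0  # returns the number of bytes written

# The options are stored as JSON under the "content_defined_chunking" key.
key_value_metadata = pq.read_metadata("data.parquet").metadata
assert json.loads(key_value_metadata[b"content_defined_chunking"].decode("utf-8")) == cdc_options
```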
45 changes: 45 additions & 0 deletions tests/io/test_parquet.py
@@ -1,3 +1,6 @@
import json
import unittest.mock

import fsspec
import pyarrow.parquet as pq
import pytest
@@ -199,6 +202,48 @@ def test_parquet_write(dataset, tmp_path):
assert dataset.data.table == output_table


def test_parquet_write_uses_content_defined_chunking(dataset, tmp_path):
assert config.DEFAULT_CDC_OPTIONS == {
"min_chunk_size": 256 * 1024, # 256 KiB
"max_chunk_size": 1024 * 1024, # 1 MiB
"norm_level": 0,
}

with unittest.mock.patch("pyarrow.parquet.ParquetWriter") as MockWriter:
writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet")
writer.write()
assert MockWriter.call_count == 1
_, kwargs = MockWriter.call_args
# Save or check the arguments as needed
assert "use_content_defined_chunking" in kwargs
assert kwargs["use_content_defined_chunking"] == config.DEFAULT_CDC_OPTIONS


custom_cdc_options = {
"min_chunk_size": 128 * 1024, # 128 KiB
"max_chunk_size": 512 * 1024, # 512 KiB
"norm_level": 1,
}


@pytest.mark.parametrize(
("cdc_options", "expected_options"), [(None, config.DEFAULT_CDC_OPTIONS), (custom_cdc_options, custom_cdc_options)]
)
def test_parquet_writer_persist_cdc_options_as_metadata(dataset, tmp_path, cdc_options, expected_options):
# write the dataset to parquet with the default CDC options
writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet", use_content_defined_chunking=cdc_options)
assert writer.write() > 0

# read the parquet KV metadata
metadata = pq.read_metadata(tmp_path / "foo.parquet")
key_value_metadata = metadata.metadata

# check that the content defined chunking options are persisted
assert b"content_defined_chunking" in key_value_metadata
json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
assert json.loads(json_encoded_options) == expected_options


def test_dataset_to_parquet_keeps_features(shared_datadir, tmp_path):
image_path = str(shared_datadir / "test_image_rgb.jpg")
data = {"image": [image_path]}
35 changes: 35 additions & 0 deletions tests/test_arrow_writer.py
@@ -1,4 +1,5 @@
import copy
import json
import os
import tempfile
from unittest import TestCase
@@ -9,6 +10,7 @@
import pyarrow.parquet as pq
import pytest

from datasets import config
from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence
from datasets.features import Array2D, ClassLabel, Features, Image, Value
from datasets.features.features import Array2DExtensionType, cast_to_python_objects
@@ -334,6 +336,39 @@ def test_parquet_writer_write():
assert pa_table.to_pydict() == {"col_1": ["foo", "bar"], "col_2": [1, 2]}


custom_cdc_options = {
"min_chunk_size": 128 * 1024, # 128 KiB
"max_chunk_size": 512 * 1024, # 512 KiB
"norm_level": 1,
}


@pytest.mark.parametrize(
("cdc_options", "expected_options"), [(None, config.DEFAULT_CDC_OPTIONS), (custom_cdc_options, custom_cdc_options)]
)
def test_parquet_write_uses_content_defined_chunking(cdc_options, expected_options):
output = pa.BufferOutputStream()
with patch("pyarrow.parquet.ParquetWriter", wraps=pq.ParquetWriter) as MockWriter:
with ParquetWriter(stream=output, use_content_defined_chunking=cdc_options) as writer:
writer.write({"col_1": "foo", "col_2": 1})
writer.write({"col_1": "bar", "col_2": 2})
writer.finalize()
assert MockWriter.call_count == 1
_, kwargs = MockWriter.call_args
assert "use_content_defined_chunking" in kwargs
assert kwargs["use_content_defined_chunking"] == expected_options

# read metadata from the output stream
with pa.input_stream(output.getvalue()) as stream:
metadata = pq.read_metadata(stream)
key_value_metadata = metadata.metadata

# check that the content defined chunking options are persisted
assert b"content_defined_chunking" in key_value_metadata
json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
assert json.loads(json_encoded_options) == expected_options


@require_pil
@pytest.mark.parametrize("embed_local_files", [False, True])
def test_writer_embed_local_files(tmp_path, embed_local_files):