From d97542517bd5362d6b55ac2aff91f2a51ecbf52e Mon Sep 17 00:00:00 2001 From: Pierre Marcenac Date: Mon, 17 Mar 2025 11:25:54 -0700 Subject: [PATCH] Read the length of the datasource from the FileInstructions to limit I/O. PiperOrigin-RevId: 737687954 --- .../core/data_sources/array_record.py | 2 +- tensorflow_datasets/core/data_sources/base.py | 19 ++++++++----------- .../core/data_sources/base_test.py | 6 ++++++ .../core/data_sources/parquet.py | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tensorflow_datasets/core/data_sources/array_record.py b/tensorflow_datasets/core/data_sources/array_record.py index ee2028bb5cb..60ffb15141c 100644 --- a/tensorflow_datasets/core/data_sources/array_record.py +++ b/tensorflow_datasets/core/data_sources/array_record.py @@ -56,7 +56,7 @@ class ArrayRecordDataSource(base.BaseDataSource): length: int = dataclasses.field(init=False) def __post_init__(self): - file_instructions = base.file_instructions(self.dataset_info, self.split) + file_instructions = self.split_info.file_instructions self.data_source = array_record_data_source.ArrayRecordDataSource( file_instructions ) diff --git a/tensorflow_datasets/core/data_sources/base.py b/tensorflow_datasets/core/data_sources/base.py index 09ece9410be..35fd1514da8 100644 --- a/tensorflow_datasets/core/data_sources/base.py +++ b/tensorflow_datasets/core/data_sources/base.py @@ -45,16 +45,6 @@ def __getitems__(self, keys: Iterable[int]) -> T: """Returns the value for the given `keys`.""" -def file_instructions( - dataset_info: dataset_info_lib.DatasetInfo, - split: splits_lib.Split | None = None, -) -> list[shard_utils.FileInstruction]: - """Retrieves the file instructions from the DatasetInfo.""" - split_infos = dataset_info.splits.values() - split_dict = splits_lib.SplitDict(split_infos=split_infos) - return split_dict[split].file_instructions - - @dataclasses.dataclass class BaseDataSource(MappingView, Sequence): """Base DataSource to override all dunder methods with the deserialization. @@ -94,6 +84,13 @@ def _deserialize(self, record: Any) -> Any: return features.deserialize_example_np(record, decoders=self.decoders) # pylint: disable=attribute-error raise ValueError('No features set, cannot decode example!') + @property + def split_info(self) -> splits_lib.SplitInfo | splits_lib.SubSplitInfo: + """Returns the SplitInfo for the split.""" + split_infos = self.dataset_info.splits.values() + splits_dict = splits_lib.SplitDict(split_infos=split_infos) + return splits_dict[self.split] # will raise an error if split is not found + def __getitem__(self, key: SupportsIndex) -> Any: record = self.data_source[key.__index__()] return self._deserialize(record) @@ -133,7 +130,7 @@ def __repr__(self) -> str: ) def __len__(self) -> int: - return self.data_source.__len__() + return sum(fi.take for fi in self.split_info.file_instructions) def __iter__(self): for i in range(self.__len__()): diff --git a/tensorflow_datasets/core/data_sources/base_test.py b/tensorflow_datasets/core/data_sources/base_test.py index e4d40b36760..59de6f46cf5 100644 --- a/tensorflow_datasets/core/data_sources/base_test.py +++ b/tensorflow_datasets/core/data_sources/base_test.py @@ -94,6 +94,12 @@ def test_read_write( for i, element in enumerate(data_source): assert element == {'id': i} + # Also works on sliced splits. + data_source = builder.as_data_source(split='train[0:2]') + assert len(data_source) == 2 + data_source = builder.as_data_source(split='train[:50%]') + assert len(data_source) == 2 + _FILE_INSTRUCTIONS = [ shard_utils.FileInstruction( diff --git a/tensorflow_datasets/core/data_sources/parquet.py b/tensorflow_datasets/core/data_sources/parquet.py index 7fe8b19b85e..41022a9546d 100644 --- a/tensorflow_datasets/core/data_sources/parquet.py +++ b/tensorflow_datasets/core/data_sources/parquet.py @@ -57,7 +57,7 @@ class ParquetDataSource(base.BaseDataSource): """ParquetDataSource to read from a ParquetDataset.""" def __post_init__(self): - file_instructions = base.file_instructions(self.dataset_info, self.split) + file_instructions = self.split_info.file_instructions filenames = [ file_instruction.filename for file_instruction in file_instructions ]