Commit bdf87d0

Author: The TensorFlow Datasets Authors (committed)

Merge pull request #11070 from lgeiger:fix-no-shuffle-beam-writer

PiperOrigin-RevId: 777974077

2 parents: 1fb2e3d + 5de1d20

4 files changed, +42 -30 lines


tensorflow_datasets/core/naming.py

Lines changed: 2 additions & 2 deletions

@@ -666,7 +666,7 @@ def sharded_filepaths_pattern(
     `/path/dataset_name-split.fileformat@num_shards` or
     `/path/dataset_name-split@num_shards.fileformat` depending on the format.
     If `num_shards` is not given, then it returns
-    `/path/dataset_name-split.fileformat*`.
+    `/path/dataset_name-split.fileformat-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]`.

     Args:
       num_shards: optional specification of the number of shards.
@@ -681,7 +681,7 @@ def sharded_filepaths_pattern(
     elif use_at_notation:
       replacement = '@*'
     else:
-      replacement = '*'
+      replacement = '-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
     return _replace_shard_pattern(os.fspath(a_filepath), replacement)

   def glob_pattern(self, num_shards: int | None = None) -> str:
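For orientation, a minimal sketch of what the changed default pattern yields; the template values mirror test_glob_pattern below, and the expected strings are copied from this diff rather than from run output:

    from etils import epath
    from tensorflow_datasets.core import naming

    template = naming.ShardedFileTemplate(
        dataset_name='ds',
        split='train',
        filetype_suffix='tfrecord',
        data_dir=epath.Path('/data'),
    )

    # Before this change the default pattern ended in a bare '*':
    #   /data/ds-train.tfrecord*
    # Now the shard suffix is spelled out, so only names carrying a full
    # '-NNNNN-of-NNNNN' shard suffix for this split and file format match.
    print(template.glob_pattern())
    # /data/ds-train.tfrecord-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]
    print(template.glob_pattern(num_shards=42))
    # /data/ds-train.tfrecord-*-of-00042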

tensorflow_datasets/core/naming_test.py

Lines changed: 5 additions & 2 deletions

@@ -459,7 +459,7 @@ def test_sharded_file_template_shard_index():
   )
   assert (
       os.fspath(template.sharded_filepaths_pattern())
-      == '/my/path/data/mnist-train.tfrecord*'
+      == '/my/path/data/mnist-train.tfrecord-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
   )
   assert (
       os.fspath(template.sharded_filepaths_pattern(num_shards=100))
@@ -474,7 +474,10 @@ def test_glob_pattern():
       filetype_suffix='tfrecord',
       data_dir=epath.Path('/data'),
   )
-  assert '/data/ds-train.tfrecord*' == template.glob_pattern()
+  assert (
+      '/data/ds-train.tfrecord-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]'
+      == template.glob_pattern()
+  )
   assert '/data/ds-train.tfrecord-*-of-00042' == template.glob_pattern(
       num_shards=42
   )

tensorflow_datasets/core/writer.py

Lines changed: 3 additions & 2 deletions

@@ -816,8 +816,9 @@ def finalize(self) -> tuple[list[int], int]:
     logging.info("Finalizing writer for %s", self._filename_template.split)
     # We don't know the number of shards, the length of each shard, nor the
     # total size, so we compute them here.
-    prefix = epath.Path(self._filename_template.filepath_prefix())
-    shards = self._filename_template.data_dir.glob(f"{prefix.name}*")
+    shards = self._filename_template.data_dir.glob(
+        self._filename_template.glob_pattern()
+    )

     def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
       length = self._file_adapter.num_examples(shard)
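Why this matters for NoShuffleBeamWriter.finalize: the writer_test.py diff below adds a second split, 'train-b', that shares the same data_dir. The sketch here only illustrates glob semantics; the actual value of filepath_prefix() is not shown in this diff, so the "old" pattern is an assumption that the prefix stops at the split name:

    import fnmatch

    shard_files = [
        'foo-train.tfrecord-00000-of-00001',    # written by the 'train' writer
        'foo-train-b.tfrecord-00000-of-00001',  # written by the 'train-b' writer
    ]

    # Hypothetical old behaviour: a bare prefix plus '*' can also pick up
    # shards of another split whose name merely starts with 'train'.
    print(fnmatch.filter(shard_files, 'foo-train*'))
    # ['foo-train.tfrecord-00000-of-00001', 'foo-train-b.tfrecord-00000-of-00001']

    # New behaviour: the full glob pattern pins the file suffix and the
    # '-NNNNN-of-NNNNN' shard suffix, so only the 'train' shards are counted.
    new_pattern = (
        'foo-train.tfrecord-[0-9][0-9][0-9][0-9][0-9]'
        '-of-[0-9][0-9][0-9][0-9][0-9]'
    )
    print(fnmatch.filter(shard_files, new_pattern))
    # ['foo-train.tfrecord-00000-of-00001']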

tensorflow_datasets/core/writer_test.py

Lines changed: 32 additions & 24 deletions

@@ -592,39 +592,47 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):

     with tempfile.TemporaryDirectory() as tmp_dir:
       tmp_dir = epath.Path(tmp_dir)
-      filename_template = naming.ShardedFileTemplate(
-          dataset_name='foo',
-          split='train',
-          filetype_suffix=file_format.file_suffix,
-          data_dir=tmp_dir,
-      )
-      writer = writer_lib.NoShuffleBeamWriter(
-          serializer=testing.DummySerializer('dummy specs'),
-          filename_template=filename_template,
-          file_format=file_format,
-      )
+
+      def get_writer(split):
+        filename_template = naming.ShardedFileTemplate(
+            dataset_name='foo',
+            split=split,
+            filetype_suffix=file_format.file_suffix,
+            data_dir=tmp_dir,
+        )
+        return writer_lib.NoShuffleBeamWriter(
+            serializer=testing.DummySerializer('dummy specs'),
+            filename_template=filename_template,
+            file_format=file_format,
+        )
+
       to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
       # Here we need to disable type check as `beam.Create` is not capable of
       # inferring the type of the PCollection elements.
       options = beam.options.pipeline_options.PipelineOptions(
           pipeline_type_check=False
       )
-      with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
-
-        @beam.ptransform_fn
-        def _build_pcollection(pipeline):
-          pcollection = pipeline | 'Start' >> beam.Create(to_write)
-          return writer.write_from_pcollection(pcollection)
-
-        _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
-        shard_lengths, total_size = writer.finalize()
-        self.assertNotEmpty(shard_lengths)
-        self.assertEqual(sum(shard_lengths), 10)
-        self.assertGreater(total_size, 10)
+      writers = [get_writer(split) for split in ('train-b', 'train')]
+
+      for writer in writers:
+        with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
+
+          @beam.ptransform_fn
+          def _build_pcollection(pipeline, writer):
+            pcollection = pipeline | 'Start' >> beam.Create(to_write)
+            return writer.write_from_pcollection(pcollection)
+
+          _ = pipeline | 'test' >> _build_pcollection(writer)
+
       files = list(tmp_dir.iterdir())
-      self.assertGreaterEqual(len(files), 1)
+      self.assertGreaterEqual(len(files), 2)
       for f in files:
         self.assertIn(file_format.file_suffix, f.name)
+      for writer in writers:
+        shard_lengths, total_size = writer.finalize()
+        self.assertNotEmpty(shard_lengths)
+        self.assertEqual(sum(shard_lengths), 10)
+        self.assertGreater(total_size, 10)


 class CustomExampleWriter(writer_lib.ExampleWriter):
