Commit 38d4d0e

Add num_proc= to .push_to_hub() (Dataset and IterableDataset) (#7606)
* parallel push_to_hub
* minor
* num_proc in IterableDataset.push_to_hub
1 parent 8e61377 commit 38d4d0e

2 files changed: +250, -89 lines changed

src/datasets/arrow_dataset.py

Lines changed: 118 additions & 52 deletions
```diff
@@ -5397,17 +5397,76 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
         ds = ds.with_format(self._format_type)
         return ds
 
+    def _push_parquet_shards_to_hub_single(
+        self,
+        job_id: int,
+        num_jobs: int,
+        repo_id: str,
+        data_dir: str,
+        split: str,
+        token: Optional[str],
+        revision: Optional[str],
+        create_pr: Optional[bool],
+        num_shards: int,
+        embed_external_files: bool,
+    ):
+        div = num_shards // num_jobs
+        mod = num_shards % num_jobs
+        start = div * job_id + min(job_id, mod)
+        end = start + div + (1 if job_id < mod else 0)
+
+        index_shards = (
+            (start + i, self.shard(num_shards=end - start, index=i, contiguous=True)) for i in range(end - start)
+        )
+
+        api = HfApi(endpoint=config.HF_ENDPOINT, token=token)
+
+        uploaded_size = 0
+        additions: list[CommitOperationAdd] = []
+        for index, shard in index_shards:
+            if embed_external_files:
+                from .io.parquet import get_writer_batch_size
+
+                format = shard.format
+                shard = shard.with_format("arrow")
+                shard = shard.map(
+                    embed_table_storage,
+                    batched=True,
+                    batch_size=get_writer_batch_size(shard.features),
+                    keep_in_memory=True,
+                )
+                shard = shard.with_format(**format)
+            shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet"
+            buffer = BytesIO()
+            shard.to_parquet(buffer)
+            parquet_content = buffer.getvalue()
+            uploaded_size += len(parquet_content)
+            del buffer
+            shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=parquet_content)
+            api.preupload_lfs_files(
+                repo_id=repo_id,
+                additions=[shard_addition],
+                repo_type="dataset",
+                revision=revision,
+                create_pr=create_pr,
+            )
+            additions.append(shard_addition)
+            yield job_id, False, 1
+
+        yield job_id, True, additions
+
     def _push_parquet_shards_to_hub(
         self,
         repo_id: str,
-        data_dir: str = "data",
-        split: Optional[str] = None,
-        token: Optional[str] = None,
-        revision: Optional[str] = None,
-        create_pr: Optional[bool] = False,
-        max_shard_size: Optional[Union[int, str]] = None,
-        num_shards: Optional[int] = None,
-        embed_external_files: bool = True,
+        data_dir: str,
+        split: str,
+        token: Optional[str],
+        revision: Optional[str],
+        create_pr: Optional[bool],
+        max_shard_size: Optional[Union[int, str]],
+        num_shards: Optional[int],
+        embed_external_files: bool,
+        num_proc: Optional[int],
     ) -> tuple[list[CommitOperationAdd], int, int]:
         """Pushes the dataset shards as Parquet files to the hub.
 
```
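The div/mod arithmetic at the top of `_push_parquet_shards_to_hub_single` gives each job a contiguous, near-equal slice of the global shard indices, with the first `num_shards % num_jobs` jobs taking one extra shard. A minimal standalone sketch of that arithmetic (the `shard_range` helper is illustrative, not part of the library):

```python
# Sketch of the shard-range arithmetic used in _push_parquet_shards_to_hub_single.
def shard_range(job_id: int, num_jobs: int, num_shards: int) -> range:
    div, mod = divmod(num_shards, num_jobs)
    start = div * job_id + min(job_id, mod)       # offset by one extra shard per earlier "large" job
    end = start + div + (1 if job_id < mod else 0)  # first `mod` jobs get div + 1 shards
    return range(start, end)

# 10 shards over 3 jobs -> indices [0..3], [4..6], [7..9]
for job_id in range(3):
    print(job_id, list(shard_range(job_id, num_jobs=3, num_shards=10)))
```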
```diff
@@ -5416,66 +5475,65 @@ def _push_parquet_shards_to_hub(
             uploaded_size (`int`): number of uploaded bytes to the repository
             dataset_nbytes (`int`): approximate size in bytes of the uploaded dataset after uncompression
         """
+        dataset_nbytes = self._estimate_nbytes()
+
         # Find decodable columns, because if there are any, we need to:
         # embed the bytes from the files in the shards
         decodable_columns = (
             [k for k, v in self._info.features.items() if require_decoding(v, ignore_decode_attribute=True)]
             if embed_external_files
             else []
         )
-
-        dataset_nbytes = self._estimate_nbytes()
+        embed_external_files = embed_external_files and bool(decodable_columns)
 
         if num_shards is None:
             max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
             num_shards = int(dataset_nbytes / max_shard_size) + 1
-            num_shards = max(num_shards, 1)
-
-        shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))
-
-        if decodable_columns:
-            from .io.parquet import get_writer_batch_size
-
-            def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]:
-                for shard in shards:
-                    format = shard.format
-                    shard = shard.with_format("arrow")
-                    shard = shard.map(
-                        embed_table_storage,
-                        batched=True,
-                        batch_size=get_writer_batch_size(shard.features),
-                        keep_in_memory=True,
-                    )
-                    shard = shard.with_format(**format)
-                    yield shard
-
-            shards = shards_with_embedded_external_files(shards)
-
-        api = HfApi(endpoint=config.HF_ENDPOINT, token=token)
+            num_shards = max(num_shards, num_proc or 1)
 
-        uploaded_size = 0
         additions: list[CommitOperationAdd] = []
-        for index, shard in hf_tqdm(
-            enumerate(shards),
-            desc="Uploading the dataset shards",
+
+        num_jobs = num_proc or 1
+        kwargs_iterable = [
+            {
+                "self": self.shard(num_shards=num_jobs, index=job_id, contiguous=True),
+                "job_id": job_id,
+                "num_jobs": num_jobs,
+                "repo_id": repo_id,
+                "data_dir": data_dir,
+                "split": split,
+                "token": token,
+                "revision": revision,
+                "create_pr": create_pr,
+                "num_shards": num_shards,
+                "embed_external_files": embed_external_files,
+            }
+            for job_id in range(num_jobs)
+        ]
+        desc = "Uploading the dataset shards"
+        desc += f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""
+        pbar = hf_tqdm(
+            unit=" shards",
             total=num_shards,
-        ):
-            shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet"
-            buffer = BytesIO()
-            shard.to_parquet(buffer)
-            parquet_content = buffer.getvalue()
-            uploaded_size += len(parquet_content)
-            del buffer
-            shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=parquet_content)
-            api.preupload_lfs_files(
-                repo_id=repo_id,
-                additions=[shard_addition],
-                repo_type="dataset",
-                revision=revision,
-                create_pr=create_pr,
+            desc=desc,
+        )
+        with contextlib.nullcontext() if num_proc is None or num_proc <= 1 else Pool(num_proc) as pool:
+            update_stream = (
+                Dataset._push_parquet_shards_to_hub_single(**kwargs_iterable[0])
+                if pool is None
+                else iflatmap_unordered(
+                    pool,
+                    Dataset._push_parquet_shards_to_hub_single,
+                    kwargs_iterable=kwargs_iterable,
+                )
             )
-            additions.append(shard_addition)
+            for job_id, done, content in update_stream:
+                if not done:
+                    pbar.update(content)
+                else:
+                    additions += content
 
+        uploaded_size = sum(addition.upload_info.size for addition in additions)
         return additions, uploaded_size, dataset_nbytes
 
     def push_to_hub(
```
```diff
@@ -5494,6 +5552,7 @@ def push_to_hub(
         max_shard_size: Optional[Union[int, str]] = None,
         num_shards: Optional[int] = None,
         embed_external_files: bool = True,
+        num_proc: Optional[int] = None,
     ) -> CommitInfo:
         """Pushes the dataset to the hub as a Parquet dataset.
         The dataset is pushed using HTTP requests and does not need to have neither git or git-lfs installed.
```
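A usage sketch of the new argument; the repo id `username/my_dataset` is a placeholder for any dataset repository you can write to, assuming you are already authenticated:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello", "world"] * 1_000})

# Default behaviour is unchanged: a single process prepares and uploads the shards.
ds.push_to_hub("username/my_dataset")

# With num_proc, the shards are prepared and uploaded by 4 worker processes.
ds.push_to_hub("username/my_dataset", num_proc=4)
```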
```diff
@@ -5553,6 +5612,12 @@ def push_to_hub(
             In particular, this will do the following before the push for the fields of type:
 
             - [`Audio`] and [`Image`]: remove local path information and embed file content in the Parquet files.
+            num_proc (`int`, *optional*, defaults to `None`):
+                Number of processes when preparing and uploading the dataset.
+                This is helpful if the dataset is made of many samples or media files to embed.
+                Multiprocessing is disabled by default.
+
+                <Added version="4.0.0"/>
 
         Return:
             huggingface_hub.CommitInfo
```
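As the docstring notes, `num_proc` pays off most when there are many samples or media files to embed. A sketch with an image column, where the file paths and repo id are placeholders for local images and a repository you can write to:

```python
from datasets import Dataset, Image

# Casting the string paths to Image() makes push_to_hub embed the file bytes
# into the Parquet shards before uploading.
ds = Dataset.from_dict(
    {"image": [f"img/{i:04d}.png" for i in range(10_000)]}
).cast_column("image", Image())

# Embedding and uploading the shards is spread over 8 processes.
ds.push_to_hub("username/my_image_dataset", num_proc=8)
```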
```diff
@@ -5636,6 +5701,7 @@ def push_to_hub(
             num_shards=num_shards,
             create_pr=create_pr,
             embed_external_files=embed_external_files,
+            num_proc=num_proc,
         )
 
         # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern)
```
