Commit 89bd1f9

Tests typing and fixes for push_to_hub (#7608)
* tests typing and fixes for push_to_hub
* fix
1 parent 38d4d0e commit 89bd1f9

File tree

3 files changed: +170 -21 lines

src/datasets/dataset_dict.py
src/datasets/iterable_dataset.py
tests/test_upstream_hub.py


src/datasets/dataset_dict.py

Lines changed: 70 additions & 4 deletions
@@ -31,6 +31,7 @@
 from .features import Features
 from .features.features import FeatureType
 from .info import DatasetInfo, DatasetInfosDict
+from .iterable_dataset import IterableDataset
 from .naming import _split_re
 from .splits import NamedSplit, Split, SplitDict, SplitInfo
 from .table import Table
@@ -49,7 +50,7 @@ def __call__(self, *fn_args, **fn_kwargs):
         return self.func(*fn_args, *self.args, **fn_kwargs)
 
 
-class DatasetDict(dict):
+class DatasetDict(dict[Union[str, NamedSplit], "Dataset"]):
     """A dictionary (dict of str: datasets.Dataset) with dataset transforms methods (map, filter, etc.)"""
 
     def _check_values_type(self):
@@ -1616,6 +1617,7 @@ def push_to_hub(
         max_shard_size: Optional[Union[int, str]] = None,
         num_shards: Optional[dict[str, int]] = None,
         embed_external_files: bool = True,
+        num_proc: Optional[int] = None,
     ) -> CommitInfo:
         """Pushes the [`DatasetDict`] to the hub as a Parquet dataset.
         The [`DatasetDict`] is pushed using HTTP requests and does not need to have neither git or git-lfs installed.
@@ -1676,6 +1678,12 @@ def push_to_hub(
                 In particular, this will do the following before the push for the fields of type:
 
                 - [`Audio`] and [`Image`] removes local path information and embed file content in the Parquet files.
+            num_proc (`int`, *optional*, defaults to `None`):
+                Number of processes when preparing and uploading the dataset.
+                This is helpful if the dataset is made of many samples or media files to embed.
+                Multiprocessing is disabled by default.
+
+                <Added version="4.0.0"/>
 
         Return:
             huggingface_hub.CommitInfo
@@ -1756,6 +1764,7 @@ def push_to_hub(
                 max_shard_size=max_shard_size,
                 num_shards=num_shards.get(split),
                 embed_external_files=embed_external_files,
+                num_proc=num_proc,
             )
             additions += split_additions
             total_uploaded_size += uploaded_size
@@ -1910,12 +1919,61 @@ def push_to_hub(
         return commit_info
 
 
-class IterableDatasetDict(dict):
+class IterableDatasetDict(dict[Union[str, NamedSplit], IterableDataset]):
+    def _check_values_type(self):
+        for dataset in self.values():
+            if not isinstance(dataset, IterableDataset):
+                raise TypeError(f"Values in `DatasetDict` should be of type `Dataset` but got type '{type(dataset)}'")
+
+    def _check_values_features(self):
+        items = [(key, dataset._resolve_features()) for key, dataset in self.items()]
+        for item_a, item_b in zip(items[:-1], items[1:]):
+            if item_a[1].features != item_b[1].features:
+                raise ValueError(
+                    f"All datasets in `DatasetDict` should have the same features but features for '{item_a[0]}' and '{item_b[0]}' don't match: {item_a[1].features} != {item_b[1].features}"
+                )
+
     def __repr__(self):
         repr = "\n".join([f"{k}: {v}" for k, v in self.items()])
         repr = re.sub(r"^", " " * 4, repr, count=0, flags=re.M)
         return f"IterableDatasetDict({{\n{repr}\n}})"
 
+    @property
+    def num_columns(self) -> dict[str, Optional[int]]:
+        """Number of columns in each split of the dataset.
+        This can contain None values if some splits have unknown features (e.g. after a map() operation).
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes")
+        >>> ds.num_columns
+        {'test': 2, 'train': 2, 'validation': 2}
+        ```
+        """
+        self._check_values_type()
+        return {k: dataset.num_columns for k, dataset in self.items()}
+
+    @property
+    def column_names(self) -> dict[str, Optional[list[str]]]:
+        """Names of the columns in each split of the dataset.
+        This can contain None values if some splits have unknown features (e.g. after a map() operation).
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes")
+        >>> ds.column_names
+        {'test': ['text', 'label'],
+         'train': ['text', 'label'],
+         'validation': ['text', 'label']}
+        ```
+        """
+        self._check_values_type()
+        return {k: dataset.column_names for k, dataset in self.items()}
+
     def with_format(
         self,
         type: Optional[str] = None,
@@ -2385,6 +2443,7 @@ def push_to_hub(
         # max_shard_size: Optional[Union[int, str]] = None, # TODO(QL): add arg
         num_shards: Optional[dict[str, int]] = None,
         embed_external_files: bool = True,
+        num_proc: Optional[int] = None,
     ) -> CommitInfo:
         """Pushes the [`DatasetDict`] to the hub as a Parquet dataset.
         The [`DatasetDict`] is pushed using HTTP requests and does not need to have neither git or git-lfs installed.
@@ -2436,6 +2495,12 @@ def push_to_hub(
                 In particular, this will do the following before the push for the fields of type:
 
                 - [`Audio`] and [`Image`] removes local path information and embed file content in the Parquet files.
+            num_proc (`int`, *optional*, defaults to `None`):
+                Number of processes when preparing and uploading the dataset.
+                This is helpful if the dataset is made of many samples or media files to embed.
+                Multiprocessing is disabled by default.
+
+                <Added version="4.0.0"/>
 
         Return:
             huggingface_hub.CommitInfo
@@ -2505,7 +2570,7 @@ def push_to_hub(
         for split in self.keys():
             logger.info(f"Pushing split {split} to the Hub.")
             # The split=key needs to be removed before merging
-            split_additions, uploaded_size, dataset_nbytes = self[split]._push_parquet_shards_to_hub(
+            split_additions, uploaded_size, dataset_nbytes, num_examples = self[split]._push_parquet_shards_to_hub(
                 repo_id,
                 data_dir=data_dir,
                 split=split,
@@ -2515,11 +2580,12 @@ def push_to_hub(
                 # max_shard_size=max_shard_size, # TODO(QL): add arg
                 num_shards=num_shards.get(split),
                 embed_external_files=embed_external_files,
+                num_proc=num_proc,
             )
             additions += split_additions
             total_uploaded_size += uploaded_size
             total_dataset_nbytes += dataset_nbytes
-            info_to_dump.splits[split] = SplitInfo(str(split), num_bytes=dataset_nbytes, num_examples=len(self[split]))
+            info_to_dump.splits[split] = SplitInfo(str(split), num_bytes=dataset_nbytes, num_examples=num_examples)
         info_to_dump.download_checksums = None
         info_to_dump.download_size = total_uploaded_size
         info_to_dump.dataset_size = total_dataset_nbytes
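Note: a minimal usage sketch of the new `num_proc` argument added above; the repository id is a placeholder, not part of the commit:

```python
from datasets import Dataset, DatasetDict

# Sketch only: push a small DatasetDict using 2 worker processes to prepare and
# upload the Parquet shards. "username/my-dataset" is a placeholder repo id.
ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
DatasetDict({"train": ds}).push_to_hub("username/my-dataset", num_proc=2)
```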

src/datasets/iterable_dataset.py

Lines changed: 37 additions & 17 deletions
@@ -2015,6 +2015,38 @@ def __init__(
         self._prepare_ex_iterable_for_iteration()  # set state_dict
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)  # subclass of torch IterableDataset
 
+    @property
+    def num_columns(self) -> Optional[int]:
+        """Number of columns in the dataset.
+        This can be None if the dataset has unknown features (e.g. after a map() operation).
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="validation")
+        >>> ds.num_columns
+        2
+        ```
+        """
+        return None if self.features is None else len(self.features)
+
+    @property
+    def column_names(self) -> Optional[list[str]]:
+        """Names of the columns in the dataset.
+        This can be None if the dataset has unknown features (e.g. after a map() operation).
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="validation", streaming=True)
+        >>> ds.column_names
+        ['text', 'label']
+        ```
+        """
+        return None if self.features is None else list(self.features)
+
     def state_dict(self) -> dict:
         """Get the current state_dict of the dataset.
         It corresponds to the state at the latest example it yielded.
@@ -3007,21 +3039,6 @@ def shard(
             token_per_repo_id=self._token_per_repo_id,
         )
 
-    @property
-    def column_names(self) -> Optional[list[str]]:
-        """Names of the columns in the dataset.
-
-        Example:
-
-        ```py
-        >>> from datasets import load_dataset
-        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="validation", streaming=True)
-        >>> ds.column_names
-        ['text', 'label']
-        ```
-        """
-        return list(self._info.features.keys()) if self._info.features is not None else None
-
     def add_column(self, name: str, column: Union[list, np.array]) -> "IterableDataset":
         """Add column to Dataset.
 
@@ -3791,7 +3808,7 @@ def _push_parquet_shards_to_hub(
         num_shards: Optional[int],
         embed_external_files: bool,
         num_proc: Optional[int],
-    ) -> tuple[list[CommitOperationAdd], int, int]:
+    ) -> tuple[list[CommitOperationAdd], int, int, int]:
         """Pushes the dataset shards as Parquet files to the hub.
 
         Returns:
@@ -3841,7 +3858,7 @@ def _push_parquet_shards_to_hub(
             total=num_shards,
             desc=desc,
         )
-        with contextlib.nullcontext() if num_proc is None and num_proc > 1 else Pool(num_proc) as pool:
+        with contextlib.nullcontext() if num_proc is None or num_proc <= 1 else Pool(num_proc) as pool:
             update_stream = (
                 IterableDataset._push_parquet_shards_to_hub_single(**kwargs_iterable[0])
                 if pool is None
@@ -3858,6 +3875,9 @@ def _push_parquet_shards_to_hub(
                 additions += content[0]
                 dataset_nbytes += content[1]
                 num_examples += content[2]
+            if pool is not None:
+                pool.close()
+                pool.join()
 
         uploaded_size = sum(addition.upload_info.size for addition in additions)
         return additions, uploaded_size, dataset_nbytes, num_examples
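Note on the pool-guard fix above: the old condition `num_proc is None and num_proc > 1` raised a `TypeError` when `num_proc` was `None` and evaluated to `False` for every integer, so a `Pool` was created even when multiprocessing was not requested. A standalone sketch of the corrected pattern (an illustration, not the library function itself):

```python
import contextlib
from multiprocessing import Pool

def open_pool(num_proc=None):
    # Corrected guard: fall back to nullcontext() when multiprocessing is not
    # requested. nullcontext() yields None, so `pool is None` selects the
    # single-process path, mirroring the fixed line in _push_parquet_shards_to_hub.
    return contextlib.nullcontext() if num_proc is None or num_proc <= 1 else Pool(num_proc)

if __name__ == "__main__":
    with open_pool() as pool:
        print(pool)   # None -> single-process path
    with open_pool(2) as pool:
        print(pool)   # <multiprocessing.pool.Pool ...> -> work dispatched to 2 processes
```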

tests/test_upstream_hub.py

Lines changed: 63 additions & 0 deletions
@@ -22,6 +22,7 @@
     DownloadManager,
     Features,
     Image,
+    IterableDatasetDict,
     Value,
     load_dataset,
     load_dataset_builder,
@@ -873,6 +874,68 @@ def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporar
                 "*/another_config/random-00000-of-00001.parquet",
             )
 
+    def test_push_dataset_dict_to_hub_num_proc(self, temporary_repo, set_ci_hub_access_token):
+        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
+
+        local_ds = DatasetDict({"train": ds})
+
+        with temporary_repo() as ds_name:
+            local_ds.push_to_hub(ds_name, num_proc=2)
+            hub_ds = load_dataset(ds_name, download_mode="force_redownload")
+
+            assert local_ds.column_names == hub_ds.column_names
+            assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys())
+            assert local_ds["train"].features == hub_ds["train"].features
+
+            # Ensure that the files on the repository have the expected names
+            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
+            assert files == [
+                ".gitattributes",
+                "README.md",
+                "data/train-00000-of-00002.parquet",
+                "data/train-00001-of-00002.parquet",
+            ]
+
+    def test_push_dataset_dict_to_hub_iterable(self, temporary_repo, set_ci_hub_access_token):
+        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}).to_iterable_dataset()
+
+        local_ds = IterableDatasetDict({"train": ds})
+
+        with temporary_repo() as ds_name:
+            local_ds.push_to_hub(ds_name)
+            hub_ds = load_dataset(ds_name, download_mode="force_redownload")
+
+            assert local_ds.column_names == hub_ds.column_names
+            assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys())
+            assert local_ds["train"].features == hub_ds["train"].features
+
+            # Ensure that there is a single file on the repository that has the correct name
+            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
+            assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"]
+
+    def test_push_dataset_dict_to_hub_iterable_num_proc(self, temporary_repo, set_ci_hub_access_token):
+        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}).to_iterable_dataset(num_shards=3)
+
+        local_ds = IterableDatasetDict({"train": ds})
+
+        with temporary_repo() as ds_name:
+            local_ds.push_to_hub(ds_name, num_proc=2)
+            hub_ds = load_dataset(ds_name, download_mode="force_redownload")
+
+            assert local_ds.column_names == hub_ds.column_names
+            assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys())
+            assert local_ds["train"].features == hub_ds["train"].features
+
+            # Ensure that the files on the repository have the expected names
+            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
+            assert files == [
+                ".gitattributes",
+                "README.md",
+                "data/train-00000-of-00003.parquet",
+                "data/train-00001-of-00003.parquet",
+                "data/train-00002-of-00003.parquet",
+            ]
+
 
 class DummyFolderBasedBuilder(FolderBasedBuilder):
     BASE_FEATURE = dict
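Note: the iterable tests rely on the dataset's shard layout; the 3-shard `IterableDataset` above is uploaded as three Parquet files, which gives `num_proc=2` shards to process in parallel. A minimal sketch along the same lines (the repository id is a placeholder):

```python
from datasets import Dataset, IterableDatasetDict

# Sketch only: "username/streaming-demo" is a placeholder repo id.
stream = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}).to_iterable_dataset(num_shards=3)
print(stream.num_columns)   # 2 -- property added in this commit
print(stream.column_names)  # ['x', 'y']
IterableDatasetDict({"train": stream}).push_to_hub("username/streaming-demo", num_proc=2)
```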
