Skip to content

Commit 1e53e0c

Browse files
authored
fix parallel push_to_hub in dataset_dict (#7613)
* fix parallel push_to_hub in dataset_dict
* style
* fix
* fix
* fix
* fix
* last one
1 parent 89bd1f9 commit 1e53e0c

File tree

6 files changed, with 36 additions (+36) and 20 deletions (-20).

6 files changed

+36
-20
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010
- ci-*
1111

1212
env:
13-
HF_ALLOW_CODE_EVAL: 1
13+
CI_HEADERS: ${{ secrets.CI_HEADERS }}
1414

1515
jobs:
1616

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@
152152
]
153153

154154
TESTS_REQUIRE = [
155+
# fix pip install issues for windows
156+
"numba>=0.56.4", # to get recent versions of llvmlite for windows ci
155157
# test dependencies
156158
"absl-py",
157159
"decorator",

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5517,7 +5517,7 @@ def _push_parquet_shards_to_hub(
55175517
total=num_shards,
55185518
desc=desc,
55195519
)
5520-
with contextlib.nullcontext() if num_proc is None and num_proc > 1 else Pool(num_proc) as pool:
5520+
with contextlib.nullcontext() if num_proc is None or num_proc <= 1 else Pool(num_proc) as pool:
55215521
update_stream = (
55225522
Dataset._push_parquet_shards_to_hub_single(**kwargs_iterable[0])
55235523
if pool is None

tests/fixtures/hub.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import requests
99
from huggingface_hub.hf_api import HfApi, RepositoryNotFoundError
1010
from huggingface_hub.utils import hf_raise_for_status
11+
from huggingface_hub.utils._headers import _http_user_agent
1112

1213

1314
CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__"
@@ -19,17 +20,18 @@
1920
CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE = CI_HUB_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
2021

2122

22-
@pytest.fixture
23-
def ci_hfh_hf_hub_url(monkeypatch):
24-
monkeypatch.setattr(
25-
"huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE
26-
)
27-
28-
2923
@pytest.fixture
3024
def ci_hub_config(monkeypatch):
3125
monkeypatch.setattr("datasets.config.HF_ENDPOINT", CI_HUB_ENDPOINT)
3226
monkeypatch.setattr("datasets.config.HUB_DATASETS_URL", CI_HUB_DATASETS_URL)
27+
monkeypatch.setattr(
28+
"huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE
29+
)
30+
old_environ = dict(os.environ)
31+
os.environ["HF_ENDPOINT"] = CI_HUB_ENDPOINT
32+
yield
33+
os.environ.clear()
34+
os.environ.update(old_environ)
3335

3436

3537
@pytest.fixture
@@ -38,6 +40,22 @@ def set_ci_hub_access_token(ci_hub_config, monkeypatch):
3840
monkeypatch.setattr("huggingface_hub.constants.HF_HUB_DISABLE_IMPLICIT_TOKEN", False)
3941
old_environ = dict(os.environ)
4042
os.environ["HF_TOKEN"] = CI_HUB_USER_TOKEN
43+
os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "0"
44+
yield
45+
os.environ.clear()
46+
os.environ.update(old_environ)
47+
48+
49+
def _http_ci_user_agent(*args, **kwargs):
50+
ua = _http_user_agent(*args, **kwargs)
51+
return ua + os.environ.get("CI_HEADERS", "")
52+
53+
54+
@pytest.fixture(autouse=True)
55+
def set_hf_ci_headers(monkeypatch):
56+
old_environ = dict(os.environ)
57+
os.environ["TRANSFORMERS_IS_CI"] = "1"
58+
monkeypatch.setattr("huggingface_hub.utils._headers._http_user_agent", _http_ci_user_agent)
4159
yield
4260
os.environ.clear()
4361
os.environ.update(old_environ)
@@ -105,7 +123,7 @@ def _hf_gated_dataset_repo_txt_data(hf_api: HfApi, hf_token, text_file_content):
105123

106124

107125
@pytest.fixture()
108-
def hf_gated_dataset_repo_txt_data(_hf_gated_dataset_repo_txt_data, ci_hub_config, ci_hfh_hf_hub_url):
126+
def hf_gated_dataset_repo_txt_data(_hf_gated_dataset_repo_txt_data, ci_hub_config):
109127
return _hf_gated_dataset_repo_txt_data
110128

111129

@@ -129,7 +147,7 @@ def hf_private_dataset_repo_txt_data_(hf_api: HfApi, hf_token, text_file_content
129147

130148

131149
@pytest.fixture()
132-
def hf_private_dataset_repo_txt_data(hf_private_dataset_repo_txt_data_, ci_hub_config, ci_hfh_hf_hub_url):
150+
def hf_private_dataset_repo_txt_data(hf_private_dataset_repo_txt_data_, ci_hub_config):
133151
return hf_private_dataset_repo_txt_data_
134152

135153

@@ -153,9 +171,7 @@ def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_wi
153171

154172

155173
@pytest.fixture()
156-
def hf_private_dataset_repo_zipped_txt_data(
157-
hf_private_dataset_repo_zipped_txt_data_, ci_hub_config, ci_hfh_hf_hub_url
158-
):
174+
def hf_private_dataset_repo_zipped_txt_data(hf_private_dataset_repo_zipped_txt_data_, ci_hub_config):
159175
return hf_private_dataset_repo_zipped_txt_data_
160176

161177

@@ -179,7 +195,5 @@ def hf_private_dataset_repo_zipped_img_data_(hf_api: HfApi, hf_token, zip_image_
179195

180196

181197
@pytest.fixture()
182-
def hf_private_dataset_repo_zipped_img_data(
183-
hf_private_dataset_repo_zipped_img_data_, ci_hub_config, ci_hfh_hf_hub_url
184-
):
198+
def hf_private_dataset_repo_zipped_img_data(hf_private_dataset_repo_zipped_img_data_, ci_hub_config):
185199
return hf_private_dataset_repo_zipped_img_data_

tests/test_hub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_dataset_url(repo_id, filename, revision):
4545
assert url == f"https://huggingface.co/datasets/{repo_id}/resolve/{revision or 'main'}/{quote(filename)}"
4646

4747

48-
def test_delete_from_hub(temporary_repo, hf_api, hf_token, csv_path, ci_hub_config, ci_hfh_hf_hub_url) -> None:
48+
def test_delete_from_hub(temporary_repo, hf_api, hf_token, csv_path, ci_hub_config) -> None:
4949
with temporary_repo() as repo_id:
5050
hf_api.create_repo(repo_id, token=hf_token, repo_type="dataset")
5151
hf_api.upload_file(

tests/test_upstream_hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646

4747
@for_all_test_methods(xfail_if_500_502_http_error)
48-
@pytest.mark.usefixtures("ci_hub_config", "ci_hfh_hf_hub_url")
48+
@pytest.mark.usefixtures("ci_hub_config")
4949
class TestPushToHub:
5050
_api = HfApi(endpoint=CI_HUB_ENDPOINT)
5151
_token = CI_HUB_USER_TOKEN
@@ -969,7 +969,7 @@ def text_file_with_metadata(request, tmp_path, text_file):
969969

970970

971971
@for_all_test_methods(xfail_if_500_502_http_error)
972-
@pytest.mark.usefixtures("ci_hub_config", "ci_hfh_hf_hub_url")
972+
@pytest.mark.usefixtures("ci_hub_config")
973973
class TestLoadFromHub:
974974
_api = HfApi(endpoint=CI_HUB_ENDPOINT)
975975
_token = CI_HUB_USER_TOKEN

0 commit comments

Comments
 (0)