Skip to content

Commit 48b6520

Browse files
authored
unit test for download_hf_assets script (#1556)
1 parent 0c51d92 commit 48b6520

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import tempfile
8+
import unittest
9+
from unittest.mock import Mock, patch
10+
11+
from scripts.download_hf_assets import download_hf_assets
12+
13+
14+
class TestDownloadHfAssets(unittest.TestCase):
    """Tests for the download_hf_assets script.

    `huggingface_hub.list_repo_files` and `huggingface_hub.hf_hub_download`
    are mocked to simulate the meta-llama/Llama-3.1-8B repo. The one
    exception is `test_list_files`, which deliberately leaves
    `list_repo_files` unpatched.
    """

    # Complete file list from the actual meta-llama/Llama-3.1-8B repository
    COMPLETE_REPO_FILES = [
        "config.json",
        "generation_config.json",
        "model.safetensors.index.json",
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors",
        "original/consolidated.00.pth",
        "original/params.json",
        "original/tokenizer.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "LICENSE",
        "README.md",
        "USE_POLICY.md",
    ]

    # Expected files for each asset type
    EXPECTED_FILES = {
        "tokenizer": [
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "original/tokenizer.model",
        ],
        "config": ["config.json", "generation_config.json"],
        "safetensors": [
            "model-00001-of-00004.safetensors",
            "model-00002-of-00004.safetensors",
            "model-00003-of-00004.safetensors",
            "model-00004-of-00004.safetensors",
            "model.safetensors.index.json",
        ],
        "index": ["model.safetensors.index.json"],
    }

    def setUp(self):
        """Create a scratch download directory and set the default repo id."""
        self.temp_dir = tempfile.mkdtemp()
        self.repo_id = "meta-llama/Llama-3.1-8B"

    def tearDown(self):
        """Remove the scratch directory created in setUp."""
        import shutil

        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _setup_mocks(self, mock_download, mock_list_files, repo_files=None):
        """Configure the two huggingface_hub mocks.

        Args:
            mock_download: mock standing in for `hf_hub_download`.
            mock_list_files: mock standing in for `list_repo_files`.
            repo_files: file listing the fake repo should report. Only `None`
                selects COMPLETE_REPO_FILES; an empty list is honoured as an
                empty repository.
        """
        # Compare against None instead of relying on truthiness, so that
        # `repo_files=[]` is not silently replaced by the full listing.
        if repo_files is None:
            repo_files = self.COMPLETE_REPO_FILES
        # Hand the mock its own copy so the code under test cannot mutate the
        # shared class-level constant.
        mock_list_files.return_value = list(repo_files)
        mock_download.return_value = None

    def _get_downloaded_files(self, mock_download):
        """Return the `filename` keyword argument of every recorded download call.

        Raises KeyError if a call did not pass `filename` as a keyword.
        """
        return [kwargs["filename"] for _args, kwargs in mock_download.call_args_list]

    def _assert_files_downloaded(self, mock_download, expected_files):
        """Assert that exactly `expected_files` were downloaded, in any order."""
        downloaded_files = self._get_downloaded_files(mock_download)
        # assertCountEqual checks both membership and multiplicity, so a file
        # downloaded twice (or a missing one) fails with a readable diff.
        self.assertCountEqual(downloaded_files, expected_files)

    def _call_download_hf_assets(self, **kwargs):
        """Call download_hf_assets with the test's default repo and directory.

        Any keyword given by the caller overrides the corresponding default.
        """
        call_kwargs = {
            "repo_id": self.repo_id,
            "local_dir": self.temp_dir,
            **kwargs,
        }
        return download_hf_assets(**call_kwargs)

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_download_single_asset_types(self, mock_download, mock_list_files):
        """Test downloading individual asset types"""
        self._setup_mocks(mock_download, mock_list_files)

        # Test each asset type individually
        for asset_type, expected_files in self.EXPECTED_FILES.items():
            with self.subTest(asset_type=asset_type):
                # Forget calls recorded by the previous asset type.
                mock_download.reset_mock()
                self._call_download_hf_assets(asset_types=[asset_type])
                self._assert_files_downloaded(mock_download, expected_files)

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_download_multiple_asset_types(self, mock_download, mock_list_files):
        """Test downloading multiple asset types together"""
        self._setup_mocks(mock_download, mock_list_files)

        # Union of all expected files; the set removes files shared between
        # asset types (e.g. the safetensors index).
        all_expected_files = set().union(*self.EXPECTED_FILES.values())

        self._call_download_hf_assets(asset_types=list(self.EXPECTED_FILES))
        self._assert_files_downloaded(mock_download, all_expected_files)

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_download_all_files(self, mock_download, mock_list_files):
        """Test downloading all files with --all option"""
        self._setup_mocks(mock_download, mock_list_files)

        self._call_download_hf_assets(asset_types=[], download_all=True)
        self._assert_files_downloaded(mock_download, self.COMPLETE_REPO_FILES)

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_additional_patterns(self, mock_download, mock_list_files):
        """Test downloading with additional file patterns"""
        test_files = ["tokenizer.json", "custom_file.txt", "README.md"]
        self._setup_mocks(mock_download, mock_list_files, repo_files=test_files)

        self._call_download_hf_assets(
            asset_types=["tokenizer"], additional_patterns=["*.txt"]
        )

        # Only tokenizer.json and custom_file.txt should be downloaded
        expected_files = ["tokenizer.json", "custom_file.txt"]
        self._assert_files_downloaded(mock_download, expected_files)

    @patch("huggingface_hub.hf_hub_download")
    def test_list_files(self, mock_download):
        """Test the safetensors file listing against the real huggingface_hub.

        Only `hf_hub_download` is mocked; `list_repo_files` is left unpatched.
        The larger deepseek-ai/DeepSeek-V3 repo is used for a more thorough
        test.

        NOTE(review): presumably this needs network access to the Hugging
        Face Hub and will fail offline -- confirm and consider a skip guard.
        """
        # Setup mock download
        mock_download.return_value = None

        # Test downloading safetensors asset type
        self._call_download_hf_assets(
            repo_id="deepseek-ai/DeepSeek-V3",
            asset_types=["safetensors"],
        )

        # Verify all 163 safetensors files plus index file are downloaded
        expected_safetensors_files = [
            f"model-{i:05d}-of-000163.safetensors" for i in range(1, 164)
        ]
        expected_files = expected_safetensors_files + [
            "model.safetensors.index.json",
        ]

        self._assert_files_downloaded(mock_download, expected_files)

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_nested_directory_handling(self, mock_download, mock_list_files):
        """Tests that files in nested directories are detected and downloaded correctly"""
        test_files = [
            "tokenizer.json",
            "original/tokenizer.model",
            "original/consolidated.00.pth",  # Should NOT be downloaded (no .pth pattern)
            "config.json",
        ]
        self._setup_mocks(mock_download, mock_list_files, repo_files=test_files)

        self._call_download_hf_assets(asset_types=["tokenizer", "config"])

        # Verify nested tokenizer file is downloaded but .pth file is not
        expected_files = ["tokenizer.json", "original/tokenizer.model", "config.json"]
        self._assert_files_downloaded(mock_download, expected_files)

        # Verify .pth file was NOT downloaded
        downloaded_files = self._get_downloaded_files(mock_download)
        self.assertNotIn("original/consolidated.00.pth", downloaded_files)

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_missing_files_warning(self, mock_download, mock_list_files):
        """Test warning when requested files are not found"""
        mock_list_files.return_value = ["config.json", "README.md"]

        with patch("builtins.print") as mock_print:
            self._call_download_hf_assets(asset_types=["tokenizer"])

        # Check that warning was printed
        expected_warning = "Warning: No matching files found for asset_type 'tokenizer'"
        printed = [str(call) for call in mock_print.call_args_list]
        self.assertTrue(
            any(expected_warning in line for line in printed),
            f"expected warning not printed; print calls were: {printed}",
        )

        # hf_hub_download is patched so a matching regression cannot trigger
        # a real download; nothing in the fake repo matches "tokenizer".
        mock_download.assert_not_called()

    @patch("huggingface_hub.list_repo_files")
    @patch("huggingface_hub.hf_hub_download")
    def test_download_failure_handling(self, mock_download, mock_list_files):
        """Test handling of download failures"""
        from requests.exceptions import HTTPError

        self._setup_mocks(
            mock_download,
            mock_list_files,
            repo_files=["tokenizer.json", "missing_file.json"],
        )

        # Mock 404 error for missing file
        def download_side_effect(*args, **kwargs):
            if kwargs["filename"] == "missing_file.json":
                response = Mock()
                response.status_code = 404
                raise HTTPError(response=response)
            return None

        mock_download.side_effect = download_side_effect

        with patch("builtins.print") as mock_print:
            self._call_download_hf_assets(
                asset_types=["tokenizer"], additional_patterns=["missing_file.json"]
            )

        # Check that 404 error was handled gracefully
        expected_message = "File missing_file.json not found, skipping..."
        printed = [str(call) for call in mock_print.call_args_list]
        self.assertTrue(
            any(expected_message in line for line in printed),
            f"expected skip message not printed; print calls were: {printed}",
        )

    def test_invalid_repo_id_format(self):
        """Test error handling for invalid repo_id format"""
        with self.assertRaises(ValueError) as context:
            self._call_download_hf_assets(
                repo_id="invalid-repo-id", asset_types=["tokenizer"]
            )
        self.assertIn("Invalid repo_id format", str(context.exception))

    def test_unknown_asset_type(self):
        """Test error handling for unknown asset type"""
        with self.assertRaises(ValueError) as context:
            self._call_download_hf_assets(asset_types=["unknown_type"])
        self.assertIn("Unknown asset type unknown_type", str(context.exception))
255+
256+
257+
if __name__ == "__main__":
    # Run this module's tests with unittest's default runner when the file
    # is executed directly rather than imported.
    unittest.main()

0 commit comments

Comments
 (0)