|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import tempfile |
| 8 | +import unittest |
| 9 | +from unittest.mock import Mock, patch |
| 10 | + |
| 11 | +from scripts.download_hf_assets import download_hf_assets |
| 12 | + |
| 13 | + |
| 14 | +class TestDownloadHfAssets(unittest.TestCase): |
| 15 | + """Tests for the download_hf_assets script |
| 16 | +
|
| 17 | + We mock `huggingface_hub.list_repo_files` and `huggingface_hub.hf_hub_download` to simulate the meta-llama/Llama-3.1-8B repo |
| 18 | + """ |
| 19 | + |
| 20 | + # Complete file list from the actual meta-llama/Llama-3.1-8B repository |
| 21 | + COMPLETE_REPO_FILES = [ |
| 22 | + "config.json", |
| 23 | + "generation_config.json", |
| 24 | + "model.safetensors.index.json", |
| 25 | + "model-00001-of-00004.safetensors", |
| 26 | + "model-00002-of-00004.safetensors", |
| 27 | + "model-00003-of-00004.safetensors", |
| 28 | + "model-00004-of-00004.safetensors", |
| 29 | + "original/consolidated.00.pth", |
| 30 | + "original/params.json", |
| 31 | + "original/tokenizer.model", |
| 32 | + "special_tokens_map.json", |
| 33 | + "tokenizer.json", |
| 34 | + "tokenizer_config.json", |
| 35 | + "LICENSE", |
| 36 | + "README.md", |
| 37 | + "USE_POLICY.md", |
| 38 | + ] |
| 39 | + |
| 40 | + # Expected files for each asset type |
| 41 | + EXPECTED_FILES = { |
| 42 | + "tokenizer": [ |
| 43 | + "tokenizer.json", |
| 44 | + "tokenizer_config.json", |
| 45 | + "special_tokens_map.json", |
| 46 | + "original/tokenizer.model", |
| 47 | + ], |
| 48 | + "config": ["config.json", "generation_config.json"], |
| 49 | + "safetensors": [ |
| 50 | + "model-00001-of-00004.safetensors", |
| 51 | + "model-00002-of-00004.safetensors", |
| 52 | + "model-00003-of-00004.safetensors", |
| 53 | + "model-00004-of-00004.safetensors", |
| 54 | + "model.safetensors.index.json", |
| 55 | + ], |
| 56 | + "index": ["model.safetensors.index.json"], |
| 57 | + } |
| 58 | + |
| 59 | + def setUp(self): |
| 60 | + self.temp_dir = tempfile.mkdtemp() |
| 61 | + self.repo_id = "meta-llama/Llama-3.1-8B" |
| 62 | + |
| 63 | + def tearDown(self): |
| 64 | + import shutil |
| 65 | + |
| 66 | + shutil.rmtree(self.temp_dir, ignore_errors=True) |
| 67 | + |
| 68 | + def _setup_mocks(self, mock_download, mock_list_files, repo_files=None): |
| 69 | + """Helper to setup mock configurations""" |
| 70 | + mock_list_files.return_value = repo_files or self.COMPLETE_REPO_FILES |
| 71 | + mock_download.return_value = None |
| 72 | + |
| 73 | + def _get_downloaded_files(self, mock_download): |
| 74 | + """Helper to extract downloaded filenames from mock calls""" |
| 75 | + return [call[1]["filename"] for call in mock_download.call_args_list] |
| 76 | + |
| 77 | + def _assert_files_downloaded(self, mock_download, expected_files): |
| 78 | + """Helper to assert expected files were downloaded""" |
| 79 | + self.assertEqual(mock_download.call_count, len(expected_files)) |
| 80 | + downloaded_files = self._get_downloaded_files(mock_download) |
| 81 | + for expected_file in expected_files: |
| 82 | + self.assertIn(expected_file, downloaded_files) |
| 83 | + |
| 84 | + def _call_download_hf_assets(self, **kwargs): |
| 85 | + """Helper to call download_hf_assets with common defaults""" |
| 86 | + defaults = { |
| 87 | + "repo_id": self.repo_id, |
| 88 | + "local_dir": self.temp_dir, |
| 89 | + } |
| 90 | + defaults.update(kwargs) |
| 91 | + return download_hf_assets(**defaults) |
| 92 | + |
| 93 | + @patch("huggingface_hub.list_repo_files") |
| 94 | + @patch("huggingface_hub.hf_hub_download") |
| 95 | + def test_download_single_asset_types(self, mock_download, mock_list_files): |
| 96 | + """Test downloading individual asset types""" |
| 97 | + self._setup_mocks(mock_download, mock_list_files) |
| 98 | + |
| 99 | + # Test each asset type individually |
| 100 | + for asset_type, expected_files in self.EXPECTED_FILES.items(): |
| 101 | + with self.subTest(asset_type=asset_type): |
| 102 | + mock_download.reset_mock() |
| 103 | + self._call_download_hf_assets(asset_types=[asset_type]) |
| 104 | + self._assert_files_downloaded(mock_download, expected_files) |
| 105 | + |
| 106 | + @patch("huggingface_hub.list_repo_files") |
| 107 | + @patch("huggingface_hub.hf_hub_download") |
| 108 | + def test_download_multiple_asset_types(self, mock_download, mock_list_files): |
| 109 | + """Test downloading multiple asset types together""" |
| 110 | + self._setup_mocks(mock_download, mock_list_files) |
| 111 | + |
| 112 | + # Get all expected files (removing duplicates) |
| 113 | + all_expected_files = set() |
| 114 | + for files in self.EXPECTED_FILES.values(): |
| 115 | + all_expected_files.update(files) |
| 116 | + |
| 117 | + self._call_download_hf_assets(asset_types=list(self.EXPECTED_FILES.keys())) |
| 118 | + self._assert_files_downloaded(mock_download, all_expected_files) |
| 119 | + |
| 120 | + @patch("huggingface_hub.list_repo_files") |
| 121 | + @patch("huggingface_hub.hf_hub_download") |
| 122 | + def test_download_all_files(self, mock_download, mock_list_files): |
| 123 | + """Test downloading all files with --all option""" |
| 124 | + self._setup_mocks(mock_download, mock_list_files) |
| 125 | + |
| 126 | + self._call_download_hf_assets(asset_types=[], download_all=True) |
| 127 | + self._assert_files_downloaded(mock_download, self.COMPLETE_REPO_FILES) |
| 128 | + |
| 129 | + @patch("huggingface_hub.list_repo_files") |
| 130 | + @patch("huggingface_hub.hf_hub_download") |
| 131 | + def test_additional_patterns(self, mock_download, mock_list_files): |
| 132 | + """Test downloading with additional file patterns""" |
| 133 | + test_files = ["tokenizer.json", "custom_file.txt", "README.md"] |
| 134 | + self._setup_mocks(mock_download, mock_list_files, repo_files=test_files) |
| 135 | + |
| 136 | + self._call_download_hf_assets( |
| 137 | + asset_types=["tokenizer"], additional_patterns=["*.txt"] |
| 138 | + ) |
| 139 | + |
| 140 | + # Only tokenizer.json and custom_file.txt should be downloaded |
| 141 | + expected_files = ["tokenizer.json", "custom_file.txt"] |
| 142 | + self._assert_files_downloaded(mock_download, expected_files) |
| 143 | + |
| 144 | + @patch("huggingface_hub.hf_hub_download") |
| 145 | + def test_list_files(self, mock_download): |
| 146 | + """Tests that list files returns correct list of files by using real huggingface_hub.list_files""" |
| 147 | + """This test uses larger deepseek-ai/DeepSeek-V3 repo for more thorough test""" |
| 148 | + |
| 149 | + # Setup mock download |
| 150 | + mock_download.return_value = None |
| 151 | + |
| 152 | + # Test downloading safetensors asset type |
| 153 | + self._call_download_hf_assets( |
| 154 | + repo_id="deepseek-ai/DeepSeek-V3", |
| 155 | + asset_types=["safetensors"], |
| 156 | + ) |
| 157 | + |
| 158 | + # Verify all 163 safetensors files plus index file are downloaded |
| 159 | + expected_safetensors_files = [ |
| 160 | + f"model-{i:05d}-of-000163.safetensors" for i in range(1, 164) |
| 161 | + ] |
| 162 | + expected_files = expected_safetensors_files + [ |
| 163 | + "model.safetensors.index.json", |
| 164 | + ] |
| 165 | + |
| 166 | + self._assert_files_downloaded(mock_download, expected_files) |
| 167 | + |
| 168 | + @patch("huggingface_hub.list_repo_files") |
| 169 | + @patch("huggingface_hub.hf_hub_download") |
| 170 | + def test_nested_directory_handling(self, mock_download, mock_list_files): |
| 171 | + """Tests that files in nested directory files are detected and downloaded correctly""" |
| 172 | + test_files = [ |
| 173 | + "tokenizer.json", |
| 174 | + "original/tokenizer.model", |
| 175 | + "original/consolidated.00.pth", # Should NOT be downloaded (no .pth pattern) |
| 176 | + "config.json", |
| 177 | + ] |
| 178 | + self._setup_mocks(mock_download, mock_list_files, repo_files=test_files) |
| 179 | + |
| 180 | + self._call_download_hf_assets(asset_types=["tokenizer", "config"]) |
| 181 | + |
| 182 | + # Verify nested tokenizer file is downloaded but .pth file is not |
| 183 | + expected_files = ["tokenizer.json", "original/tokenizer.model", "config.json"] |
| 184 | + self._assert_files_downloaded(mock_download, expected_files) |
| 185 | + |
| 186 | + # Verify .pth file was NOT downloaded |
| 187 | + downloaded_files = self._get_downloaded_files(mock_download) |
| 188 | + self.assertNotIn("original/consolidated.00.pth", downloaded_files) |
| 189 | + |
| 190 | + @patch("huggingface_hub.list_repo_files") |
| 191 | + def test_missing_files_warning(self, mock_list_files): |
| 192 | + """Test warning when requested files are not found""" |
| 193 | + mock_list_files.return_value = ["config.json", "README.md"] |
| 194 | + |
| 195 | + with patch("builtins.print") as mock_print: |
| 196 | + self._call_download_hf_assets(asset_types=["tokenizer"]) |
| 197 | + |
| 198 | + # Check that warning was printed |
| 199 | + warning_calls = [ |
| 200 | + call |
| 201 | + for call in mock_print.call_args_list |
| 202 | + if "Warning: No matching files found for asset_type 'tokenizer'" |
| 203 | + in str(call) |
| 204 | + ] |
| 205 | + self.assertTrue(len(warning_calls) > 0) |
| 206 | + |
| 207 | + @patch("huggingface_hub.list_repo_files") |
| 208 | + @patch("huggingface_hub.hf_hub_download") |
| 209 | + def test_download_failure_handling(self, mock_download, mock_list_files): |
| 210 | + """Test handling of download failures""" |
| 211 | + from requests.exceptions import HTTPError |
| 212 | + |
| 213 | + self._setup_mocks( |
| 214 | + mock_download, |
| 215 | + mock_list_files, |
| 216 | + repo_files=["tokenizer.json", "missing_file.json"], |
| 217 | + ) |
| 218 | + |
| 219 | + # Mock 404 error for missing file |
| 220 | + def download_side_effect(*args, **kwargs): |
| 221 | + if kwargs["filename"] == "missing_file.json": |
| 222 | + response = Mock() |
| 223 | + response.status_code = 404 |
| 224 | + raise HTTPError(response=response) |
| 225 | + return None |
| 226 | + |
| 227 | + mock_download.side_effect = download_side_effect |
| 228 | + |
| 229 | + with patch("builtins.print") as mock_print: |
| 230 | + self._call_download_hf_assets( |
| 231 | + asset_types=["tokenizer"], additional_patterns=["missing_file.json"] |
| 232 | + ) |
| 233 | + |
| 234 | + # Check that 404 error was handled gracefully |
| 235 | + error_calls = [ |
| 236 | + call |
| 237 | + for call in mock_print.call_args_list |
| 238 | + if "File missing_file.json not found, skipping..." in str(call) |
| 239 | + ] |
| 240 | + self.assertTrue(len(error_calls) > 0) |
| 241 | + |
| 242 | + def test_invalid_repo_id_format(self): |
| 243 | + """Test error handling for invalid repo_id format""" |
| 244 | + with self.assertRaises(ValueError) as context: |
| 245 | + self._call_download_hf_assets( |
| 246 | + repo_id="invalid-repo-id", asset_types=["tokenizer"] |
| 247 | + ) |
| 248 | + self.assertIn("Invalid repo_id format", str(context.exception)) |
| 249 | + |
| 250 | + def test_unknown_asset_type(self): |
| 251 | + """Test error handling for unknown asset type""" |
| 252 | + with self.assertRaises(ValueError) as context: |
| 253 | + self._call_download_hf_assets(asset_types=["unknown_type"]) |
| 254 | + self.assertIn("Unknown asset type unknown_type", str(context.exception)) |
| 255 | + |
| 256 | + |
| 257 | +if __name__ == "__main__": |
| 258 | + unittest.main() |
0 commit comments