14
14
import json
15
15
import os
16
16
from pathlib import Path
17
- from typing import cast
18
17
19
18
import torch
20
19
from huggingface_hub import snapshot_download
@@ -81,8 +80,9 @@ def download_model_checkpoint_from_hub(
81
80
logger .info (f"Downloaded a model checkpoint from HuggingFace to: { local_path } " )
82
81
return Path (local_path )
83
82
except Exception as hf_exception :
84
- logger .error (f"Failed to download checkpoint: { hf_exception } " )
85
- raise FileNotFoundError (f"Checkpoint not found: { model_id } " ) from hf_exception
83
+ raise FileNotFoundError (
84
+ f"Failed to download checkpoint for { model_id } : { hf_exception } "
85
+ ) from hf_exception
86
86
87
87
88
88
def check_download_model_checkpoint (
@@ -111,6 +111,8 @@ def check_download_model_checkpoint(
111
111
:raises TypeError: If model is not a supported type
112
112
:raises ValueError: If local path is not a directory
113
113
"""
114
+ logger .debug (f"Checking download model checkpoint for: { model } " )
115
+
114
116
if isinstance (model , (PreTrainedModel , nn .Module )):
115
117
logger .debug ("Model is already a PreTrainedModel or nn.Module instance" )
116
118
return model
@@ -122,26 +124,20 @@ def check_download_model_checkpoint(
122
124
123
125
checkpoint_path = Path (model )
124
126
125
- if not checkpoint_path .exists ():
126
- logger .debug (
127
- f"Model path does not exist, downloading from hub: { checkpoint_path } "
128
- )
129
- return download_model_checkpoint_from_hub (
130
- model_id = str (checkpoint_path ),
131
- cache_dir = cache_dir ,
132
- force_download = force_download ,
133
- local_files_only = local_files_only ,
134
- token = token ,
135
- revision = revision ,
136
- ** kwargs ,
137
- )
138
-
139
- if not checkpoint_path .is_dir ():
140
- raise ValueError (
141
- f"Expected a directory for checkpoint, got file: { checkpoint_path } "
142
- )
143
-
144
- return checkpoint_path .resolve ()
127
+ if checkpoint_path .exists ():
128
+ logger .debug (f"Model path exists locally: { checkpoint_path } " )
129
+ return checkpoint_path .resolve ()
130
+
131
+ logger .debug (f"Model path does not exist, downloading from hub: { checkpoint_path } " )
132
+ return download_model_checkpoint_from_hub (
133
+ model_id = str (checkpoint_path ),
134
+ cache_dir = cache_dir ,
135
+ force_download = force_download ,
136
+ local_files_only = local_files_only ,
137
+ token = token ,
138
+ revision = revision ,
139
+ ** kwargs ,
140
+ )
145
141
146
142
147
143
def check_download_model_config (
@@ -189,31 +185,36 @@ def check_download_model_config(
189
185
)
190
186
191
187
config_path = Path (config )
192
- if not config_path .exists ():
193
- logger .debug (f"Config path does not exist, downloading from hub: { config_path } " )
194
- return AutoConfig .from_pretrained (
195
- str (config_path ),
188
+ if config_path .exists () and config_path .is_dir ():
189
+ logger .debug ("Config path is a directory, looking for config.json in it" )
190
+ config_path = config_path / "config.json"
191
+
192
+ if config_path .exists () and config_path .is_file ():
193
+ logger .debug (f"Using local config file: { config_path } " )
194
+ return config_path .resolve ()
195
+
196
+ try :
197
+ local_path = snapshot_download (
198
+ str (config ),
199
+ allow_patterns = ["config.json" ],
196
200
cache_dir = cache_dir ,
197
201
force_download = force_download ,
198
202
local_files_only = local_files_only ,
199
203
token = token ,
200
204
revision = revision ,
201
205
** kwargs ,
202
206
)
207
+ logger .info (f"Downloaded a model config from HuggingFace to: { local_path } " )
203
208
204
- logger .debug (f"Using local config path: { config_path } " )
205
-
206
- if not config_path .is_file ():
207
- config_path = config_path / "config.json"
208
-
209
- if not config_path .exists ():
210
- raise FileNotFoundError (f"No config.json found at { config_path } " )
211
-
212
- return config_path .resolve ()
209
+ return Path (local_path ) / "config.json"
210
+ except Exception as hf_exception :
211
+ raise FileNotFoundError (
212
+ f"Failed to download config for { config } : { hf_exception } "
213
+ ) from hf_exception
213
214
214
215
215
216
def load_model_config (
216
- model : str | os .PathLike | PreTrainedModel | PretrainedConfig ,
217
+ model : str | os .PathLike | PreTrainedModel | PretrainedConfig | dict ,
217
218
cache_dir : str | Path | None = None ,
218
219
force_download : bool = False ,
219
220
local_files_only : bool = False ,
@@ -240,35 +241,45 @@ def load_model_config(
240
241
"""
241
242
logger .debug (f"Loading model config from: { model } " )
242
243
243
- if isinstance (model , PretrainedConfig ):
244
- logger .debug ("Model is already a PretrainedConfig instance" )
245
- return model
244
+ config = check_download_model_config (
245
+ model ,
246
+ cache_dir = cache_dir ,
247
+ force_download = force_download ,
248
+ local_files_only = local_files_only ,
249
+ token = token ,
250
+ revision = revision ,
251
+ ** kwargs ,
252
+ )
246
253
247
- if isinstance (model , PreTrainedModel ):
248
- logger .debug ("Model is a PreTrainedModel instance, returning its config " )
249
- return model . config # type: ignore[attr-defined]
254
+ if isinstance (config , PretrainedConfig ):
255
+ logger .debug ("Model is already a PretrainedConfig instance" )
256
+ return config
250
257
251
- if not isinstance (model , (str , os .PathLike )):
252
- raise TypeError (f"Expected model to be a string or Path, got { type (model )} " )
258
+ if isinstance (config , dict ):
259
+ logger .debug ("Model is a dictionary, loading config from dict" )
260
+ return PretrainedConfig .from_dict (config )
253
261
254
262
try :
255
263
logger .debug (f"Loading config with AutoConfig from: { model } " )
264
+ # use model to ensure proper handling of HF args
265
+ # it will resolve to the previously downloaded config path
256
266
return AutoConfig .from_pretrained (
257
- model ,
267
+ str ( model ) ,
258
268
cache_dir = cache_dir ,
259
269
force_download = force_download ,
260
270
local_files_only = local_files_only ,
261
271
token = token ,
262
272
revision = revision ,
263
273
** kwargs ,
264
274
)
265
- except ValueError as err :
266
- logger .error (f"Failed to load config from { model } : { err } " )
267
- raise FileNotFoundError (f"Config not found for model: { model } " ) from err
275
+ except Exception as hf_exception :
276
+ raise FileNotFoundError (
277
+ f"Failed to download model config for { model } : { hf_exception } "
278
+ ) from hf_exception
268
279
269
280
270
281
def load_model_checkpoint_config_dict (
271
- config : str | os .PathLike | PretrainedConfig | PreTrainedModel | dict ,
282
+ model : str | os .PathLike | PretrainedConfig | PreTrainedModel | dict ,
272
283
cache_dir : str | Path | None = None ,
273
284
force_download : bool = False ,
274
285
local_files_only : bool = False ,
@@ -282,7 +293,7 @@ def load_model_checkpoint_config_dict(
282
293
Supports loading from local config.json files, checkpoint directories,
283
294
or extracting from existing model/config instances.
284
295
285
- :param config : Local path, PretrainedConfig, PreTrainedModel, or dict
296
+ :param model : Local path, PretrainedConfig, PreTrainedModel, or dict
286
297
:param cache_dir: Directory to cache downloaded files
287
298
:param force_download: Whether to force re-download existing files
288
299
:param local_files_only: Only use cached files without downloading
@@ -293,45 +304,30 @@ def load_model_checkpoint_config_dict(
293
304
:raises TypeError: If config is not a supported type
294
305
:raises FileNotFoundError: If config.json cannot be found
295
306
"""
307
+ logger .debug (f"Loading model config dict from: { model } " )
308
+ config = check_download_model_config (
309
+ model ,
310
+ cache_dir = cache_dir ,
311
+ force_download = force_download ,
312
+ local_files_only = local_files_only ,
313
+ token = token ,
314
+ revision = revision ,
315
+ ** kwargs ,
316
+ )
317
+
296
318
if isinstance (config , dict ):
297
319
logger .debug ("Config is already a dictionary, returning as is" )
298
320
return config
299
321
300
- if isinstance (config , PreTrainedModel ):
301
- logger .debug ("Config is a PreTrainedModel instance, returning its config dict" )
302
- return config .config .to_dict () # type: ignore[attr-defined]
303
-
304
322
if isinstance (config , PretrainedConfig ):
305
323
logger .debug ("Config is a PretrainedConfig instance, returning its dict" )
306
324
return config .to_dict ()
307
325
308
- if not isinstance (config , (str , os .PathLike )):
309
- raise TypeError (
310
- f"Expected config to be a string, Path, PreTrainedModel, "
311
- f"or PretrainedConfig, got { type (config )} "
312
- )
326
+ if not isinstance (config , Path ):
327
+ raise TypeError (f"Expected config to be a Path, got { type (config )} " )
313
328
314
- path = cast (
315
- "Path" ,
316
- check_download_model_config (
317
- config ,
318
- cache_dir = cache_dir ,
319
- force_download = force_download ,
320
- local_files_only = local_files_only ,
321
- token = token ,
322
- revision = revision ,
323
- ** kwargs ,
324
- ),
325
- )
326
-
327
- if path .is_dir ():
328
- path = path / "config.json"
329
-
330
- if not path .exists ():
331
- raise FileNotFoundError (f"No config.json found at { path } " )
332
-
333
- logger .debug (f"Loading config from: { path } " )
334
- with path .open () as file :
329
+ logger .debug (f"Loading config from: { config } " )
330
+ with config .open () as file :
335
331
return json .load (file )
336
332
337
333
@@ -353,12 +349,10 @@ def load_model_checkpoint_index_weight_files(
353
349
if not isinstance (path , (str , os .PathLike )):
354
350
raise TypeError (f"Expected path to be a string or Path, got { type (path )} " )
355
351
356
- path = Path (path )
357
-
358
- if not path .exists ():
352
+ if not (path := Path (path )).exists ():
359
353
raise FileNotFoundError (f"Model checkpoint path does not exist: { path } " )
360
354
361
- if path .is_file () and path .suffix == ".index.json" :
355
+ if path .is_file () and path .name . endswith ( ".index.json" ) :
362
356
logger .debug (f"Single index file provided: { path } " )
363
357
index_files = [path ]
364
358
elif path .is_dir () and (glob_files := list (path .glob ("*.index.json" ))):
@@ -368,28 +362,38 @@ def load_model_checkpoint_index_weight_files(
368
362
logger .debug (f"No index files found in directory: { path } " )
369
363
return []
370
364
371
- files = []
365
+ files = set ()
372
366
373
367
for index_file in index_files :
374
368
if not index_file .exists ():
375
369
raise FileNotFoundError (
376
370
f"Index file under { path } at { index_file } does not exist"
377
371
)
372
+
378
373
logger .debug (f"Reading index file: { index_file } " )
379
374
with index_file .open () as file_handle :
380
375
index_data = json .load (file_handle )
381
- if not index_data .get ("weight_map" ):
382
- raise ValueError (f"Index file { index_file } does not contain a weight_map" )
376
+
377
+ if (
378
+ not isinstance (index_data , dict )
379
+ or not index_data .get ("weight_map" )
380
+ or not isinstance (index_data ["weight_map" ], dict )
381
+ ):
382
+ raise ValueError (
383
+ f"Index file { index_file } does not contain a valid weight_map"
384
+ )
385
+
383
386
for weight_file in set (index_data ["weight_map" ].values ()):
384
387
# Resolve relative paths to the index file's directory
385
- weight_file_path = Path (index_file ).parent / weight_file
386
- if not weight_file_path .exists ():
388
+ if not (
389
+ weight_file_path := Path (index_file ).parent / str (weight_file )
390
+ ).exists ():
387
391
raise FileNotFoundError (
388
392
f"Weight file for { path } at { weight_file_path } does not exist"
389
393
)
390
- files .append (weight_file_path )
394
+ files .add (weight_file_path . resolve () )
391
395
392
- return files
396
+ return list ( files )
393
397
394
398
395
399
def load_model_checkpoint_weight_files (path : str | os .PathLike ) -> list [Path ]:
@@ -408,14 +412,12 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
408
412
if not isinstance (path , (str , os .PathLike )):
409
413
raise TypeError (f"Expected path to be a string or Path, got { type (path )} " )
410
414
411
- path = Path (path )
412
-
413
- if not path .exists ():
415
+ if not (path := Path (path )).exists ():
414
416
raise FileNotFoundError (f"Model checkpoint path does not exist: { path } " )
415
417
416
- if index_files := load_model_checkpoint_index_weight_files (path ):
417
- logger .debug (f"Found index files at { path } : { index_files } " )
418
- return index_files
418
+ if weight_index_files := load_model_checkpoint_index_weight_files (path ):
419
+ logger .debug (f"Found index files at { path } : { weight_index_files } " )
420
+ return weight_index_files
419
421
420
422
if path .is_file () and path .suffix in {".bin" , ".safetensors" }:
421
423
logger .debug (f"Single weight file provided: { path } " )
@@ -460,25 +462,26 @@ def load_model_checkpoint_state_dict(
460
462
:return: Dictionary mapping parameter names to tensors
461
463
:raises ValueError: If unsupported file format is encountered
462
464
"""
465
+ logger .debug (f"Loading model state dict from: { model } " )
466
+
467
+ model = check_download_model_checkpoint (
468
+ model ,
469
+ cache_dir = cache_dir ,
470
+ force_download = force_download ,
471
+ local_files_only = local_files_only ,
472
+ token = token ,
473
+ revision = revision ,
474
+ ** kwargs ,
475
+ )
476
+
463
477
if isinstance (model , (PreTrainedModel , nn .Module )):
464
478
logger .debug ("Model is already a PreTrainedModel or nn.Module instance" )
465
479
return model .state_dict () # type: ignore[union-attr]
466
480
467
481
logger .debug (f"Loading model weights from: { model } " )
468
- weight_files = load_model_checkpoint_weight_files (
469
- check_download_model_checkpoint (
470
- model ,
471
- cache_dir = cache_dir ,
472
- force_download = force_download ,
473
- local_files_only = local_files_only ,
474
- token = token ,
475
- revision = revision ,
476
- ** kwargs ,
477
- )
478
- )
482
+ weight_files = load_model_checkpoint_weight_files (model )
479
483
480
484
state_dict = {}
481
-
482
485
for file in weight_files :
483
486
if file .suffix == ".safetensors" :
484
487
logger .debug (f"Loading safetensors weights from: { file } " )
0 commit comments