1414import json
1515import os
1616from pathlib import Path
17- from typing import cast
1817
1918import torch
2019from huggingface_hub import snapshot_download
@@ -81,8 +80,9 @@ def download_model_checkpoint_from_hub(
8180 logger .info (f"Downloaded a model checkpoint from HuggingFace to: { local_path } " )
8281 return Path (local_path )
8382 except Exception as hf_exception :
84- logger .error (f"Failed to download checkpoint: { hf_exception } " )
85- raise FileNotFoundError (f"Checkpoint not found: { model_id } " ) from hf_exception
83+ raise FileNotFoundError (
84+ f"Failed to download checkpoint for { model_id } : { hf_exception } "
85+ ) from hf_exception
8686
8787
8888def check_download_model_checkpoint (
@@ -111,6 +111,8 @@ def check_download_model_checkpoint(
111111 :raises TypeError: If model is not a supported type
112112 :raises ValueError: If local path is not a directory
113113 """
114+ logger .debug (f"Checking download model checkpoint for: { model } " )
115+
114116 if isinstance (model , (PreTrainedModel , nn .Module )):
115117 logger .debug ("Model is already a PreTrainedModel or nn.Module instance" )
116118 return model
@@ -122,26 +124,20 @@ def check_download_model_checkpoint(
122124
123125 checkpoint_path = Path (model )
124126
125- if not checkpoint_path .exists ():
126- logger .debug (
127- f"Model path does not exist, downloading from hub: { checkpoint_path } "
128- )
129- return download_model_checkpoint_from_hub (
130- model_id = str (checkpoint_path ),
131- cache_dir = cache_dir ,
132- force_download = force_download ,
133- local_files_only = local_files_only ,
134- token = token ,
135- revision = revision ,
136- ** kwargs ,
137- )
138-
139- if not checkpoint_path .is_dir ():
140- raise ValueError (
141- f"Expected a directory for checkpoint, got file: { checkpoint_path } "
142- )
143-
144- return checkpoint_path .resolve ()
127+ if checkpoint_path .exists ():
128+ logger .debug (f"Model path exists locally: { checkpoint_path } " )
129+ return checkpoint_path .resolve ()
130+
131+ logger .debug (f"Model path does not exist, downloading from hub: { checkpoint_path } " )
132+ return download_model_checkpoint_from_hub (
133+ model_id = str (checkpoint_path ),
134+ cache_dir = cache_dir ,
135+ force_download = force_download ,
136+ local_files_only = local_files_only ,
137+ token = token ,
138+ revision = revision ,
139+ ** kwargs ,
140+ )
145141
146142
147143def check_download_model_config (
@@ -189,31 +185,36 @@ def check_download_model_config(
189185 )
190186
191187 config_path = Path (config )
192- if not config_path .exists ():
193- logger .debug (f"Config path does not exist, downloading from hub: { config_path } " )
194- return AutoConfig .from_pretrained (
195- str (config_path ),
188+ if config_path .exists () and config_path .is_dir ():
189+ logger .debug ("Config path is a directory, looking for config.json in it" )
190+ config_path = config_path / "config.json"
191+
192+ if config_path .exists () and config_path .is_file ():
193+ logger .debug (f"Using local config file: { config_path } " )
194+ return config_path .resolve ()
195+
196+ try :
197+ local_path = snapshot_download (
198+ str (config ),
199+ allow_patterns = ["config.json" ],
196200 cache_dir = cache_dir ,
197201 force_download = force_download ,
198202 local_files_only = local_files_only ,
199203 token = token ,
200204 revision = revision ,
201205 ** kwargs ,
202206 )
207+ logger .info (f"Downloaded a model config from HuggingFace to: { local_path } " )
203208
204- logger .debug (f"Using local config path: { config_path } " )
205-
206- if not config_path .is_file ():
207- config_path = config_path / "config.json"
208-
209- if not config_path .exists ():
210- raise FileNotFoundError (f"No config.json found at { config_path } " )
211-
212- return config_path .resolve ()
209+ return Path (local_path ) / "config.json"
210+ except Exception as hf_exception :
211+ raise FileNotFoundError (
212+ f"Failed to download config for { config } : { hf_exception } "
213+ ) from hf_exception
213214
214215
215216def load_model_config (
216- model : str | os .PathLike | PreTrainedModel | PretrainedConfig ,
217+ model : str | os .PathLike | PreTrainedModel | PretrainedConfig | dict ,
217218 cache_dir : str | Path | None = None ,
218219 force_download : bool = False ,
219220 local_files_only : bool = False ,
@@ -240,35 +241,45 @@ def load_model_config(
240241 """
241242 logger .debug (f"Loading model config from: { model } " )
242243
243- if isinstance (model , PretrainedConfig ):
244- logger .debug ("Model is already a PretrainedConfig instance" )
245- return model
244+ config = check_download_model_config (
245+ model ,
246+ cache_dir = cache_dir ,
247+ force_download = force_download ,
248+ local_files_only = local_files_only ,
249+ token = token ,
250+ revision = revision ,
251+ ** kwargs ,
252+ )
246253
247- if isinstance (model , PreTrainedModel ):
248- logger .debug ("Model is a PreTrainedModel instance, returning its config " )
249- return model . config # type: ignore[attr-defined]
254+ if isinstance (config , PretrainedConfig ):
255+ logger .debug ("Model is already a PretrainedConfig instance" )
256+ return config
250257
251- if not isinstance (model , (str , os .PathLike )):
252- raise TypeError (f"Expected model to be a string or Path, got { type (model )} " )
258+ if isinstance (config , dict ):
259+ logger .debug ("Model is a dictionary, loading config from dict" )
260+ return PretrainedConfig .from_dict (config )
253261
254262 try :
255263 logger .debug (f"Loading config with AutoConfig from: { model } " )
264+ # use model to ensure proper handling of HF args
265+ # it will resolve to the previously downloaded config path
256266 return AutoConfig .from_pretrained (
257- model ,
267+ str ( model ) ,
258268 cache_dir = cache_dir ,
259269 force_download = force_download ,
260270 local_files_only = local_files_only ,
261271 token = token ,
262272 revision = revision ,
263273 ** kwargs ,
264274 )
265- except ValueError as err :
266- logger .error (f"Failed to load config from { model } : { err } " )
267- raise FileNotFoundError (f"Config not found for model: { model } " ) from err
275+ except Exception as hf_exception :
276+ raise FileNotFoundError (
277+ f"Failed to download model config for { model } : { hf_exception } "
278+ ) from hf_exception
268279
269280
270281def load_model_checkpoint_config_dict (
271- config : str | os .PathLike | PretrainedConfig | PreTrainedModel | dict ,
282+ model : str | os .PathLike | PretrainedConfig | PreTrainedModel | dict ,
272283 cache_dir : str | Path | None = None ,
273284 force_download : bool = False ,
274285 local_files_only : bool = False ,
@@ -282,7 +293,7 @@ def load_model_checkpoint_config_dict(
282293 Supports loading from local config.json files, checkpoint directories,
283294 or extracting from existing model/config instances.
284295
285- :param config : Local path, PretrainedConfig, PreTrainedModel, or dict
296+ :param model : Local path, PretrainedConfig, PreTrainedModel, or dict
286297 :param cache_dir: Directory to cache downloaded files
287298 :param force_download: Whether to force re-download existing files
288299 :param local_files_only: Only use cached files without downloading
@@ -293,45 +304,30 @@ def load_model_checkpoint_config_dict(
293304 :raises TypeError: If config is not a supported type
294305 :raises FileNotFoundError: If config.json cannot be found
295306 """
307+ logger .debug (f"Loading model config dict from: { model } " )
308+ config = check_download_model_config (
309+ model ,
310+ cache_dir = cache_dir ,
311+ force_download = force_download ,
312+ local_files_only = local_files_only ,
313+ token = token ,
314+ revision = revision ,
315+ ** kwargs ,
316+ )
317+
296318 if isinstance (config , dict ):
297319 logger .debug ("Config is already a dictionary, returning as is" )
298320 return config
299321
300- if isinstance (config , PreTrainedModel ):
301- logger .debug ("Config is a PreTrainedModel instance, returning its config dict" )
302- return config .config .to_dict () # type: ignore[attr-defined]
303-
304322 if isinstance (config , PretrainedConfig ):
305323 logger .debug ("Config is a PretrainedConfig instance, returning its dict" )
306324 return config .to_dict ()
307325
308- if not isinstance (config , (str , os .PathLike )):
309- raise TypeError (
310- f"Expected config to be a string, Path, PreTrainedModel, "
311- f"or PretrainedConfig, got { type (config )} "
312- )
326+ if not isinstance (config , Path ):
327+ raise TypeError (f"Expected config to be a Path, got { type (config )} " )
313328
314- path = cast (
315- "Path" ,
316- check_download_model_config (
317- config ,
318- cache_dir = cache_dir ,
319- force_download = force_download ,
320- local_files_only = local_files_only ,
321- token = token ,
322- revision = revision ,
323- ** kwargs ,
324- ),
325- )
326-
327- if path .is_dir ():
328- path = path / "config.json"
329-
330- if not path .exists ():
331- raise FileNotFoundError (f"No config.json found at { path } " )
332-
333- logger .debug (f"Loading config from: { path } " )
334- with path .open () as file :
329+ logger .debug (f"Loading config from: { config } " )
330+ with config .open () as file :
335331 return json .load (file )
336332
337333
@@ -353,12 +349,10 @@ def load_model_checkpoint_index_weight_files(
353349 if not isinstance (path , (str , os .PathLike )):
354350 raise TypeError (f"Expected path to be a string or Path, got { type (path )} " )
355351
356- path = Path (path )
357-
358- if not path .exists ():
352+ if not (path := Path (path )).exists ():
359353 raise FileNotFoundError (f"Model checkpoint path does not exist: { path } " )
360354
361- if path .is_file () and path .suffix == ".index.json" :
355+ if path .is_file () and path .name . endswith ( ".index.json" ) :
362356 logger .debug (f"Single index file provided: { path } " )
363357 index_files = [path ]
364358 elif path .is_dir () and (glob_files := list (path .glob ("*.index.json" ))):
@@ -368,28 +362,38 @@ def load_model_checkpoint_index_weight_files(
368362 logger .debug (f"No index files found in directory: { path } " )
369363 return []
370364
371- files = []
365+ files = set ()
372366
373367 for index_file in index_files :
374368 if not index_file .exists ():
375369 raise FileNotFoundError (
376370 f"Index file under { path } at { index_file } does not exist"
377371 )
372+
378373 logger .debug (f"Reading index file: { index_file } " )
379374 with index_file .open () as file_handle :
380375 index_data = json .load (file_handle )
381- if not index_data .get ("weight_map" ):
382- raise ValueError (f"Index file { index_file } does not contain a weight_map" )
376+
377+ if (
378+ not isinstance (index_data , dict )
379+ or not index_data .get ("weight_map" )
380+ or not isinstance (index_data ["weight_map" ], dict )
381+ ):
382+ raise ValueError (
383+ f"Index file { index_file } does not contain a valid weight_map"
384+ )
385+
383386 for weight_file in set (index_data ["weight_map" ].values ()):
384387 # Resolve relative paths to the index file's directory
385- weight_file_path = Path (index_file ).parent / weight_file
386- if not weight_file_path .exists ():
388+ if not (
389+ weight_file_path := Path (index_file ).parent / str (weight_file )
390+ ).exists ():
387391 raise FileNotFoundError (
388392 f"Weight file for { path } at { weight_file_path } does not exist"
389393 )
390- files .append (weight_file_path )
394+ files .add (weight_file_path . resolve () )
391395
392- return files
396+ return list ( files )
393397
394398
395399def load_model_checkpoint_weight_files (path : str | os .PathLike ) -> list [Path ]:
@@ -408,14 +412,12 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
408412 if not isinstance (path , (str , os .PathLike )):
409413 raise TypeError (f"Expected path to be a string or Path, got { type (path )} " )
410414
411- path = Path (path )
412-
413- if not path .exists ():
415+ if not (path := Path (path )).exists ():
414416 raise FileNotFoundError (f"Model checkpoint path does not exist: { path } " )
415417
416- if index_files := load_model_checkpoint_index_weight_files (path ):
417- logger .debug (f"Found index files at { path } : { index_files } " )
418- return index_files
418+ if weight_index_files := load_model_checkpoint_index_weight_files (path ):
419+ logger .debug (f"Found index files at { path } : { weight_index_files } " )
420+ return weight_index_files
419421
420422 if path .is_file () and path .suffix in {".bin" , ".safetensors" }:
421423 logger .debug (f"Single weight file provided: { path } " )
@@ -460,25 +462,26 @@ def load_model_checkpoint_state_dict(
460462 :return: Dictionary mapping parameter names to tensors
461463 :raises ValueError: If unsupported file format is encountered
462464 """
465+ logger .debug (f"Loading model state dict from: { model } " )
466+
467+ model = check_download_model_checkpoint (
468+ model ,
469+ cache_dir = cache_dir ,
470+ force_download = force_download ,
471+ local_files_only = local_files_only ,
472+ token = token ,
473+ revision = revision ,
474+ ** kwargs ,
475+ )
476+
463477 if isinstance (model , (PreTrainedModel , nn .Module )):
464478 logger .debug ("Model is already a PreTrainedModel or nn.Module instance" )
465479 return model .state_dict () # type: ignore[union-attr]
466480
467481 logger .debug (f"Loading model weights from: { model } " )
468- weight_files = load_model_checkpoint_weight_files (
469- check_download_model_checkpoint (
470- model ,
471- cache_dir = cache_dir ,
472- force_download = force_download ,
473- local_files_only = local_files_only ,
474- token = token ,
475- revision = revision ,
476- ** kwargs ,
477- )
478- )
482+ weight_files = load_model_checkpoint_weight_files (model )
479483
480484 state_dict = {}
481-
482485 for file in weight_files :
483486 if file .suffix == ".safetensors" :
484487 logger .debug (f"Loading safetensors weights from: { file } " )
0 commit comments