1414import json
1515import os
1616from pathlib import Path
17+ from typing import cast
1718
1819import torch
1920from huggingface_hub import snapshot_download
@@ -248,10 +249,7 @@ def load_model_config(
248249 return model .config # type: ignore[attr-defined]
249250
250251 if not isinstance (model , (str , os .PathLike )):
251- raise TypeError (
252- "Expected model to be a string, Path, or PreTrainedModel, "
253- f"got { type (model )} "
254- )
252+ raise TypeError (f"Expected model to be a string or Path, got { type (model )} " )
255253
256254 try :
257255 logger .debug (f"Loading config with AutoConfig from: { model } " )
@@ -271,6 +269,12 @@ def load_model_config(
271269
272270def load_model_checkpoint_config_dict (
273271 config : str | os .PathLike | PretrainedConfig | PreTrainedModel | dict ,
272+ cache_dir : str | Path | None = None ,
273+ force_download : bool = False ,
274+ local_files_only : bool = False ,
275+ token : str | bool | None = None ,
276+ revision : str | None = None ,
277+ ** kwargs ,
274278) -> dict :
275279 """
276280 Load model configuration as dictionary from various sources.
@@ -279,6 +283,12 @@ def load_model_checkpoint_config_dict(
279283 or extracting from existing model/config instances.
280284
281285 :param config: Local path, PretrainedConfig, PreTrainedModel, or dict
286+ :param cache_dir: Directory to cache downloaded files
287+ :param force_download: Whether to force re-download existing files
288+ :param local_files_only: Only use cached files without downloading
289+ :param token: Authentication token for private models
290+ :param revision: Model revision (branch, tag, or commit hash)
291+ :param kwargs: Additional arguments for `check_download_model_config`
282292 :return: Configuration dictionary
283293 :raises TypeError: If config is not a supported type
284294 :raises FileNotFoundError: If config.json cannot be found
@@ -301,7 +311,18 @@ def load_model_checkpoint_config_dict(
301311 f"or PretrainedConfig, got { type (config )} "
302312 )
303313
304- path = Path (config )
314+ path = cast (
315+ "Path" ,
316+ check_download_model_config (
317+ config ,
318+ cache_dir = cache_dir ,
319+ force_download = force_download ,
320+ local_files_only = local_files_only ,
321+ token = token ,
322+ revision = revision ,
323+ ** kwargs ,
324+ ),
325+ )
305326
306327 if path .is_dir ():
307328 path = path / "config.json"
@@ -378,7 +399,7 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
378399 Searches for weight files in various formats (.bin, .safetensors) through
379400 automatic detection of different organization patterns.
380401
381- :param path: Local checkpoint directory, index file, or weight file path
402+ :param path: HF ID, local checkpoint directory, index file, or weight file path
382403 :return: List of paths to weight files
383404 :raises TypeError: If path is not a string or Path-like object
384405 :raises FileNotFoundError: If path doesn't exist or no weight files found
@@ -416,14 +437,26 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
416437
417438def load_model_checkpoint_state_dict (
418439 model : str | os .PathLike | PreTrainedModel | nn .Module ,
440+ cache_dir : str | Path | None = None ,
441+ force_download : bool = False ,
442+ local_files_only : bool = False ,
443+ token : str | bool | None = None ,
444+ revision : str | None = None ,
445+ ** kwargs ,
419446) -> dict [str , Tensor ]:
420447 """
421448 Load complete model state dictionary from various sources.
422449
423450 Supports loading from model instances, local checkpoint directories,
424451 or individual weight files with automatic format detection.
425452
426- :param model: Model instance, checkpoint directory, or weight file path
453+ :param model: Model instance, HF ID, checkpoint directory, or weight file path
454+ :param cache_dir: Directory to cache downloaded files
455+ :param force_download: Whether to force re-download existing files
456+ :param local_files_only: Only use cached files without downloading
457+ :param token: Authentication token for private models
458+ :param revision: Model revision (branch, tag, or commit hash)
459+ :param kwargs: Additional arguments for `check_download_model_checkpoint`
427460 :return: Dictionary mapping parameter names to tensors
428461 :raises ValueError: If unsupported file format is encountered
429462 """
@@ -432,7 +465,17 @@ def load_model_checkpoint_state_dict(
432465 return model .state_dict () # type: ignore[union-attr]
433466
434467 logger .debug (f"Loading model weights from: { model } " )
435- weight_files = load_model_checkpoint_weight_files (model )
468+ weight_files = load_model_checkpoint_weight_files (
469+ check_download_model_checkpoint (
470+ model ,
471+ cache_dir = cache_dir ,
472+ force_download = force_download ,
473+ local_files_only = local_files_only ,
474+ token = token ,
475+ revision = revision ,
476+ ** kwargs ,
477+ )
478+ )
436479
437480 state_dict = {}
438481
0 commit comments