Skip to content

Commit a3c0f80

Browse files
committed
Updates for rebasing, fixes and simplifications, resolve styling
Signed-off-by: Mark Kurtz <[email protected]>
1 parent 041051a commit a3c0f80

File tree

6 files changed

+1164
-765
lines changed

6 files changed

+1164
-765
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ ignore = [
156156
"COM812",
157157
"ISC001",
158158
"TC002",
159+
"TC003", # allow imports outside of type checking blocks
159160
"S311", # allow random number generators
160161
"PLW1514", # allow Path.open without encoding
161162
"RET505", # allow `else` blocks

src/speculators/utils/transformers_utils.py

Lines changed: 113 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import json
1515
import os
1616
from pathlib import Path
17-
from typing import cast
1817

1918
import torch
2019
from huggingface_hub import snapshot_download
@@ -81,8 +80,9 @@ def download_model_checkpoint_from_hub(
8180
logger.info(f"Downloaded a model checkpoint from HuggingFace to: {local_path}")
8281
return Path(local_path)
8382
except Exception as hf_exception:
84-
logger.error(f"Failed to download checkpoint: {hf_exception}")
85-
raise FileNotFoundError(f"Checkpoint not found: {model_id}") from hf_exception
83+
raise FileNotFoundError(
84+
f"Failed to download checkpoint for {model_id}: {hf_exception}"
85+
) from hf_exception
8686

8787

8888
def check_download_model_checkpoint(
@@ -111,6 +111,8 @@ def check_download_model_checkpoint(
111111
:raises TypeError: If model is not a supported type
112112
:raises ValueError: If local path is not a directory
113113
"""
114+
logger.debug(f"Checking download model checkpoint for: {model}")
115+
114116
if isinstance(model, (PreTrainedModel, nn.Module)):
115117
logger.debug("Model is already a PreTrainedModel or nn.Module instance")
116118
return model
@@ -122,26 +124,20 @@ def check_download_model_checkpoint(
122124

123125
checkpoint_path = Path(model)
124126

125-
if not checkpoint_path.exists():
126-
logger.debug(
127-
f"Model path does not exist, downloading from hub: {checkpoint_path}"
128-
)
129-
return download_model_checkpoint_from_hub(
130-
model_id=str(checkpoint_path),
131-
cache_dir=cache_dir,
132-
force_download=force_download,
133-
local_files_only=local_files_only,
134-
token=token,
135-
revision=revision,
136-
**kwargs,
137-
)
138-
139-
if not checkpoint_path.is_dir():
140-
raise ValueError(
141-
f"Expected a directory for checkpoint, got file: {checkpoint_path}"
142-
)
143-
144-
return checkpoint_path.resolve()
127+
if checkpoint_path.exists():
128+
logger.debug(f"Model path exists locally: {checkpoint_path}")
129+
return checkpoint_path.resolve()
130+
131+
logger.debug(f"Model path does not exist, downloading from hub: {checkpoint_path}")
132+
return download_model_checkpoint_from_hub(
133+
model_id=str(checkpoint_path),
134+
cache_dir=cache_dir,
135+
force_download=force_download,
136+
local_files_only=local_files_only,
137+
token=token,
138+
revision=revision,
139+
**kwargs,
140+
)
145141

146142

147143
def check_download_model_config(
@@ -189,31 +185,36 @@ def check_download_model_config(
189185
)
190186

191187
config_path = Path(config)
192-
if not config_path.exists():
193-
logger.debug(f"Config path does not exist, downloading from hub: {config_path}")
194-
return AutoConfig.from_pretrained(
195-
str(config_path),
188+
if config_path.exists() and config_path.is_dir():
189+
logger.debug("Config path is a directory, looking for config.json in it")
190+
config_path = config_path / "config.json"
191+
192+
if config_path.exists() and config_path.is_file():
193+
logger.debug(f"Using local config file: {config_path}")
194+
return config_path.resolve()
195+
196+
try:
197+
local_path = snapshot_download(
198+
str(config),
199+
allow_patterns=["config.json"],
196200
cache_dir=cache_dir,
197201
force_download=force_download,
198202
local_files_only=local_files_only,
199203
token=token,
200204
revision=revision,
201205
**kwargs,
202206
)
207+
logger.info(f"Downloaded a model config from HuggingFace to: {local_path}")
203208

204-
logger.debug(f"Using local config path: {config_path}")
205-
206-
if not config_path.is_file():
207-
config_path = config_path / "config.json"
208-
209-
if not config_path.exists():
210-
raise FileNotFoundError(f"No config.json found at {config_path}")
211-
212-
return config_path.resolve()
209+
return Path(local_path) / "config.json"
210+
except Exception as hf_exception:
211+
raise FileNotFoundError(
212+
f"Failed to download config for {config}: {hf_exception}"
213+
) from hf_exception
213214

214215

215216
def load_model_config(
216-
model: str | os.PathLike | PreTrainedModel | PretrainedConfig,
217+
model: str | os.PathLike | PreTrainedModel | PretrainedConfig | dict,
217218
cache_dir: str | Path | None = None,
218219
force_download: bool = False,
219220
local_files_only: bool = False,
@@ -240,35 +241,45 @@ def load_model_config(
240241
"""
241242
logger.debug(f"Loading model config from: {model}")
242243

243-
if isinstance(model, PretrainedConfig):
244-
logger.debug("Model is already a PretrainedConfig instance")
245-
return model
244+
config = check_download_model_config(
245+
model,
246+
cache_dir=cache_dir,
247+
force_download=force_download,
248+
local_files_only=local_files_only,
249+
token=token,
250+
revision=revision,
251+
**kwargs,
252+
)
246253

247-
if isinstance(model, PreTrainedModel):
248-
logger.debug("Model is a PreTrainedModel instance, returning its config")
249-
return model.config # type: ignore[attr-defined]
254+
if isinstance(config, PretrainedConfig):
255+
logger.debug("Model is already a PretrainedConfig instance")
256+
return config
250257

251-
if not isinstance(model, (str, os.PathLike)):
252-
raise TypeError(f"Expected model to be a string or Path, got {type(model)}")
258+
if isinstance(config, dict):
259+
logger.debug("Model is a dictionary, loading config from dict")
260+
return PretrainedConfig.from_dict(config)
253261

254262
try:
255263
logger.debug(f"Loading config with AutoConfig from: {model}")
264+
# use model to ensure proper handling of HF args
265+
# it will resolve to the previously downloaded config path
256266
return AutoConfig.from_pretrained(
257-
model,
267+
str(model),
258268
cache_dir=cache_dir,
259269
force_download=force_download,
260270
local_files_only=local_files_only,
261271
token=token,
262272
revision=revision,
263273
**kwargs,
264274
)
265-
except ValueError as err:
266-
logger.error(f"Failed to load config from {model}: {err}")
267-
raise FileNotFoundError(f"Config not found for model: {model}") from err
275+
except Exception as hf_exception:
276+
raise FileNotFoundError(
277+
f"Failed to download model config for {model}: {hf_exception}"
278+
) from hf_exception
268279

269280

270281
def load_model_checkpoint_config_dict(
271-
config: str | os.PathLike | PretrainedConfig | PreTrainedModel | dict,
282+
model: str | os.PathLike | PretrainedConfig | PreTrainedModel | dict,
272283
cache_dir: str | Path | None = None,
273284
force_download: bool = False,
274285
local_files_only: bool = False,
@@ -282,7 +293,7 @@ def load_model_checkpoint_config_dict(
282293
Supports loading from local config.json files, checkpoint directories,
283294
or extracting from existing model/config instances.
284295
285-
:param config: Local path, PretrainedConfig, PreTrainedModel, or dict
296+
:param model: Local path, PretrainedConfig, PreTrainedModel, or dict
286297
:param cache_dir: Directory to cache downloaded files
287298
:param force_download: Whether to force re-download existing files
288299
:param local_files_only: Only use cached files without downloading
@@ -293,45 +304,30 @@ def load_model_checkpoint_config_dict(
293304
:raises TypeError: If config is not a supported type
294305
:raises FileNotFoundError: If config.json cannot be found
295306
"""
307+
logger.debug(f"Loading model config dict from: {model}")
308+
config = check_download_model_config(
309+
model,
310+
cache_dir=cache_dir,
311+
force_download=force_download,
312+
local_files_only=local_files_only,
313+
token=token,
314+
revision=revision,
315+
**kwargs,
316+
)
317+
296318
if isinstance(config, dict):
297319
logger.debug("Config is already a dictionary, returning as is")
298320
return config
299321

300-
if isinstance(config, PreTrainedModel):
301-
logger.debug("Config is a PreTrainedModel instance, returning its config dict")
302-
return config.config.to_dict() # type: ignore[attr-defined]
303-
304322
if isinstance(config, PretrainedConfig):
305323
logger.debug("Config is a PretrainedConfig instance, returning its dict")
306324
return config.to_dict()
307325

308-
if not isinstance(config, (str, os.PathLike)):
309-
raise TypeError(
310-
f"Expected config to be a string, Path, PreTrainedModel, "
311-
f"or PretrainedConfig, got {type(config)}"
312-
)
326+
if not isinstance(config, Path):
327+
raise TypeError(f"Expected config to be a Path, got {type(config)}")
313328

314-
path = cast(
315-
"Path",
316-
check_download_model_config(
317-
config,
318-
cache_dir=cache_dir,
319-
force_download=force_download,
320-
local_files_only=local_files_only,
321-
token=token,
322-
revision=revision,
323-
**kwargs,
324-
),
325-
)
326-
327-
if path.is_dir():
328-
path = path / "config.json"
329-
330-
if not path.exists():
331-
raise FileNotFoundError(f"No config.json found at {path}")
332-
333-
logger.debug(f"Loading config from: {path}")
334-
with path.open() as file:
329+
logger.debug(f"Loading config from: {config}")
330+
with config.open() as file:
335331
return json.load(file)
336332

337333

@@ -353,12 +349,10 @@ def load_model_checkpoint_index_weight_files(
353349
if not isinstance(path, (str, os.PathLike)):
354350
raise TypeError(f"Expected path to be a string or Path, got {type(path)}")
355351

356-
path = Path(path)
357-
358-
if not path.exists():
352+
if not (path := Path(path)).exists():
359353
raise FileNotFoundError(f"Model checkpoint path does not exist: {path}")
360354

361-
if path.is_file() and path.suffix == ".index.json":
355+
if path.is_file() and path.name.endswith(".index.json"):
362356
logger.debug(f"Single index file provided: {path}")
363357
index_files = [path]
364358
elif path.is_dir() and (glob_files := list(path.glob("*.index.json"))):
@@ -368,28 +362,38 @@ def load_model_checkpoint_index_weight_files(
368362
logger.debug(f"No index files found in directory: {path}")
369363
return []
370364

371-
files = []
365+
files = set()
372366

373367
for index_file in index_files:
374368
if not index_file.exists():
375369
raise FileNotFoundError(
376370
f"Index file under {path} at {index_file} does not exist"
377371
)
372+
378373
logger.debug(f"Reading index file: {index_file}")
379374
with index_file.open() as file_handle:
380375
index_data = json.load(file_handle)
381-
if not index_data.get("weight_map"):
382-
raise ValueError(f"Index file {index_file} does not contain a weight_map")
376+
377+
if (
378+
not isinstance(index_data, dict)
379+
or not index_data.get("weight_map")
380+
or not isinstance(index_data["weight_map"], dict)
381+
):
382+
raise ValueError(
383+
f"Index file {index_file} does not contain a valid weight_map"
384+
)
385+
383386
for weight_file in set(index_data["weight_map"].values()):
384387
# Resolve relative paths to the index file's directory
385-
weight_file_path = Path(index_file).parent / weight_file
386-
if not weight_file_path.exists():
388+
if not (
389+
weight_file_path := Path(index_file).parent / str(weight_file)
390+
).exists():
387391
raise FileNotFoundError(
388392
f"Weight file for {path} at {weight_file_path} does not exist"
389393
)
390-
files.append(weight_file_path)
394+
files.add(weight_file_path.resolve())
391395

392-
return files
396+
return list(files)
393397

394398

395399
def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
@@ -408,14 +412,12 @@ def load_model_checkpoint_weight_files(path: str | os.PathLike) -> list[Path]:
408412
if not isinstance(path, (str, os.PathLike)):
409413
raise TypeError(f"Expected path to be a string or Path, got {type(path)}")
410414

411-
path = Path(path)
412-
413-
if not path.exists():
415+
if not (path := Path(path)).exists():
414416
raise FileNotFoundError(f"Model checkpoint path does not exist: {path}")
415417

416-
if index_files := load_model_checkpoint_index_weight_files(path):
417-
logger.debug(f"Found index files at {path}: {index_files}")
418-
return index_files
418+
if weight_index_files := load_model_checkpoint_index_weight_files(path):
419+
logger.debug(f"Found index files at {path}: {weight_index_files}")
420+
return weight_index_files
419421

420422
if path.is_file() and path.suffix in {".bin", ".safetensors"}:
421423
logger.debug(f"Single weight file provided: {path}")
@@ -460,25 +462,26 @@ def load_model_checkpoint_state_dict(
460462
:return: Dictionary mapping parameter names to tensors
461463
:raises ValueError: If unsupported file format is encountered
462464
"""
465+
logger.debug(f"Loading model state dict from: {model}")
466+
467+
model = check_download_model_checkpoint(
468+
model,
469+
cache_dir=cache_dir,
470+
force_download=force_download,
471+
local_files_only=local_files_only,
472+
token=token,
473+
revision=revision,
474+
**kwargs,
475+
)
476+
463477
if isinstance(model, (PreTrainedModel, nn.Module)):
464478
logger.debug("Model is already a PreTrainedModel or nn.Module instance")
465479
return model.state_dict() # type: ignore[union-attr]
466480

467481
logger.debug(f"Loading model weights from: {model}")
468-
weight_files = load_model_checkpoint_weight_files(
469-
check_download_model_checkpoint(
470-
model,
471-
cache_dir=cache_dir,
472-
force_download=force_download,
473-
local_files_only=local_files_only,
474-
token=token,
475-
revision=revision,
476-
**kwargs,
477-
)
478-
)
482+
weight_files = load_model_checkpoint_weight_files(model)
479483

480484
state_dict = {}
481-
482485
for file in weight_files:
483486
if file.suffix == ".safetensors":
484487
logger.debug(f"Loading safetensors weights from: {file}")

tests/unit/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Load all fixtures within mock package
2+
pytest_plugins = ["tests.unit.mock"]

tests/unit/mock/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
# tests/unit/mock/__init__.py
#
# Re-export the HuggingFace mock-factory helpers so tests can import them
# directly from ``tests.unit.mock`` instead of the internal submodule.
from .hf_factory import (
    MockPretrainedTransformersFactory,
    PretrainedBundle,
    mock_llama3_2m_config_dict,
    mock_llama3_2m_state_dict,
)

# Public API of the mock package.
__all__ = [
    "MockPretrainedTransformersFactory",
    "PretrainedBundle",
    "mock_llama3_2m_config_dict",
    "mock_llama3_2m_state_dict",
]

# Expose all sub-plugins within the mock package to pytest so the fixtures
# declared in hf_factory are registered when this package is loaded as a
# pytest plugin (see tests/unit/conftest.py).
pytest_plugins = ["tests.unit.mock.hf_factory"]

0 commit comments

Comments
 (0)