@@ -1,11 +1,14 @@
+import contextlib
 import json
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, Generic, List, Optional, TypeVar
 
+import filelock
 import torch
 import transformers
+from transformers.utils import HF_MODULES_CACHE
 
 from tensorrt_llm import logger
 from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
@@ -58,6 +61,35 @@ def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
         return None
 
 
+@contextlib.contextmanager
+def config_file_lock(timeout: int = 10):
+    """
+    Context manager for file locking when loading pretrained configs.
+
+    This prevents race conditions when multiple processes try to download/load
+    the same model configuration simultaneously.
+
+    Args:
+        timeout: Maximum time to wait for lock acquisition, in seconds
+    """
+    # Use a single global lock file in the HF cache directory.
+    # This serializes all model loading operations to prevent race conditions.
+    lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"
+
+    # Create and acquire the lock
+    lock = filelock.FileLock(str(lock_path), timeout=timeout)
+
+    try:
+        with lock:
+            yield
+    except filelock.Timeout:
+        logger.warning(
+            f"Failed to acquire config lock within {timeout} seconds, proceeding without lock"
+        )
+        # Fallback: proceed without locking to avoid blocking indefinitely
+        yield
+
+
 @dataclass(kw_only=True)
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
@@ -358,16 +390,20 @@ def from_pretrained(cls,
                         checkpoint_dir: str,
                         trust_remote_code=False,
                         **kwargs):
-        pretrained_config = transformers.AutoConfig.from_pretrained(
-            checkpoint_dir,
-            trust_remote_code=trust_remote_code,
-        )
+        # Use a file lock to prevent race conditions when multiple processes
+        # try to import/cache the same remote model config file
+        with config_file_lock():
+            pretrained_config = transformers.AutoConfig.from_pretrained(
+                checkpoint_dir,
+                trust_remote_code=trust_remote_code,
+            )
+
+            # Find the cache path by looking for the config.json file, which
+            # should be present in all huggingface models
+            model_dir = Path(
+                transformers.utils.hub.cached_file(checkpoint_dir,
+                                                   'config.json')).parent
 
-        # Find the cache path by looking for the config.json file which should be in all
-        # huggingface models
-        model_dir = Path(
-            transformers.utils.hub.cached_file(checkpoint_dir,
-                                               'config.json')).parent
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
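
For reference, below is a minimal, self-contained sketch of the filelock semantics this patch relies on, independent of TensorRT-LLM. It shows a second process blocking on an already-held lock file and raising filelock.Timeout once the wait exceeds the timeout, which is exactly the case the fallback branch in config_file_lock logs and survives. The LOCK_PATH location and the hold_lock helper are illustrative stand-ins, not part of the patch.

import multiprocessing
import tempfile
import time
from pathlib import Path

import filelock

# Illustrative stand-in for Path(HF_MODULES_CACHE) / "_remote_code.lock"
LOCK_PATH = str(Path(tempfile.gettempdir()) / "_remote_code_demo.lock")


def hold_lock(seconds: float) -> None:
    # Simulates a peer process busy downloading/importing a remote config.
    with filelock.FileLock(LOCK_PATH):
        time.sleep(seconds)


if __name__ == "__main__":
    holder = multiprocessing.Process(target=hold_lock, args=(3.0,))
    holder.start()
    time.sleep(0.5)  # give the holder time to acquire the lock first

    try:
        # Same pattern as config_file_lock: wait at most 1 second.
        with filelock.FileLock(LOCK_PATH, timeout=1):
            print("lock acquired")
    except filelock.Timeout:
        # The branch the patch handles by warning and proceeding unlocked.
        print("timed out; proceeding without the lock")

    holder.join()

Falling back to an unlocked load after a timeout trades strict serialization for liveness: a crashed or stalled peer can never deadlock model loading, at the cost of occasionally re-admitting the race the lock exists to prevent.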