
Commit 31b0f0f

[https://nvbugs/5445466][fix] Eliminate race when loading HF dynamic modules (#7268)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 2e43753 commit 31b0f0f
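
For context: with `trust_remote_code=True`, `transformers.AutoConfig.from_pretrained` copies a checkpoint's custom Python modules into the HF modules cache (`HF_MODULES_CACHE`) before importing them. When several ranks of a multi-GPU job load the same model at once, one process can import a module file that another process is still writing. Below is a minimal sketch of that unsynchronized pattern; the checkpoint path and process count are illustrative placeholders, not taken from the commit.

# Illustrative sketch of the race this commit eliminates. The checkpoint
# path is a placeholder; any repo that ships remote code has this shape.
import multiprocessing

import transformers

CHECKPOINT = "/path/to/checkpoint"  # placeholder, not from the commit


def load(rank: int) -> None:
    # Without synchronization, every process may write the checkpoint's
    # dynamic modules into HF_MODULES_CACHE at the same time, so one rank
    # can import a half-written file.
    transformers.AutoConfig.from_pretrained(CHECKPOINT,
                                            trust_remote_code=True)


if __name__ == "__main__":
    procs = [multiprocessing.Process(target=load, args=(i,))
             for i in range(8)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()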

File tree

2 files changed: +45 −11 lines


tensorrt_llm/_torch/model_config.py

Lines changed: 45 additions & 9 deletions
@@ -1,11 +1,14 @@
+import contextlib
 import json
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, Generic, List, Optional, TypeVar

+import filelock
 import torch
 import transformers
+from transformers.utils import HF_MODULES_CACHE

 from tensorrt_llm import logger
 from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
@@ -58,6 +61,35 @@ def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
         return None


+@contextlib.contextmanager
+def config_file_lock(timeout: int = 10):
+    """
+    Context manager for file locking when loading pretrained configs.
+
+    This prevents race conditions when multiple processes try to download/load
+    the same model configuration simultaneously.
+
+    Args:
+        timeout: Maximum time to wait for lock acquisition in seconds
+    """
+    # Use a single global lock file in HF cache directory
+    # This serializes all model loading operations to prevent race conditions
+    lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"
+
+    # Create and acquire the lock
+    lock = filelock.FileLock(str(lock_path), timeout=timeout)
+
+    try:
+        with lock:
+            yield
+    except filelock.Timeout:
+        logger.warning(
+            f"Failed to acquire config lock within {timeout} seconds, proceeding without lock"
+        )
+        # Fallback: proceed without locking to avoid blocking indefinitely
+        yield
+
+
 @dataclass(kw_only=True)
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
@@ -358,16 +390,20 @@ def from_pretrained(cls,
                         checkpoint_dir: str,
                         trust_remote_code=False,
                         **kwargs):
-        pretrained_config = transformers.AutoConfig.from_pretrained(
-            checkpoint_dir,
-            trust_remote_code=trust_remote_code,
-        )
+        # Use file lock to prevent race conditions when multiple processes
+        # try to import/cache the same remote model config file
+        with config_file_lock():
+            pretrained_config = transformers.AutoConfig.from_pretrained(
+                checkpoint_dir,
+                trust_remote_code=trust_remote_code,
+            )
+
+            # Find the cache path by looking for the config.json file which should be in all
+            # huggingface models
+            model_dir = Path(
+                transformers.utils.hub.cached_file(checkpoint_dir,
+                                                   'config.json')).parent

-        # Find the cache path by looking for the config.json file which should be in all
-        # huggingface models
-        model_dir = Path(
-            transformers.utils.hub.cached_file(checkpoint_dir,
-                                               'config.json')).parent
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
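
As a standalone illustration of the pattern `config_file_lock` adds, the sketch below (hypothetical names, independent of the TensorRT-LLM sources) serializes a critical section across processes with `filelock.FileLock` and, like the code above, degrades to lock-free execution if the lock cannot be acquired within the timeout.

# Standalone sketch of the filelock pattern used by config_file_lock.
# demo_file_lock and critical_section are hypothetical names.
import contextlib
import multiprocessing
import tempfile
import time
from pathlib import Path

import filelock


@contextlib.contextmanager
def demo_file_lock(lock_path: Path, timeout: int = 10):
    lock = filelock.FileLock(str(lock_path), timeout=timeout)
    try:
        with lock:
            yield
    except filelock.Timeout:
        # Fall back to running unlocked rather than blocking forever,
        # mirroring the fallback in config_file_lock.
        yield


def critical_section(rank: int, lock_path: Path) -> None:
    with demo_file_lock(lock_path):
        # Only one process at a time executes this block; a real caller
        # would populate a shared cache directory here.
        print(f"rank {rank} holds the lock")
        time.sleep(0.1)


if __name__ == "__main__":
    lock_path = Path(tempfile.gettempdir()) / "demo_remote_code.lock"
    procs = [multiprocessing.Process(target=critical_section,
                                     args=(i, lock_path)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

One design note: because the timeout fallback yields without the lock, a process that holds the lock longer than the timeout re-opens the race; the commit accepts that trade-off to avoid deadlocking model loads.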

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions
@@ -281,8 +281,6 @@ triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482)
 triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] SKIP (https://nvbugs/5445466)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5445466)
 llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796)
 accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525)
 examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-mini-128k-instruct] SKIP (https://nvbugs/5465143)
