Commit 6aa0dfc

lancelly authored and dominicshanshan committed
[TRTLLM-6835][fix] Fix potential hang caused by python multiprocessing when prefetching weights (NVIDIA#6927)
Signed-off-by: Lance Liao <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 4c4141a commit 6aa0dfc

1 file changed: +5 -4 lines changed

tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,7 @@
 import glob
 import multiprocessing
 import os
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List
 
 import psutil
@@ -120,7 +121,7 @@ def prefetch_files(self, file_names: List[str]):
         if len(local_file_names) == 0:
             return
 
-        max_processes = min(multiprocessing.cpu_count() * 2, 16,
-                            len(local_file_names))
-        with multiprocessing.Pool(processes=max_processes) as pool:
-            pool.map(self._prefetch_one_file, local_file_names)
+        max_workers = min(multiprocessing.cpu_count() * 2, 16,
+                          len(local_file_names))
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            list(executor.map(self._prefetch_one_file, local_file_names))
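
The pattern in the new lines can be tried in isolation. The following is a minimal sketch, not the code from `weight_loader.py`: the module-level `prefetch_one_file` and `prefetch_files` functions are hypothetical stand-ins, and the assumption that prefetching means reading each file once to warm the OS page cache is mine, since the body of `_prefetch_one_file` is not shown in this diff. Only the worker-count formula and the `ThreadPoolExecutor` usage mirror the commit.

```python
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import List


def prefetch_one_file(file_name: str, chunk_size: int = 1 << 20) -> int:
    """Hypothetical stand-in for _prefetch_one_file (assumption):
    read the file in chunks and return the number of bytes read."""
    total = 0
    with open(file_name, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            total += len(chunk)
    return total


def prefetch_files(file_names: List[str]) -> List[int]:
    """Read all files concurrently using threads instead of processes."""
    if len(file_names) == 0:
        return []

    # Same cap as in the commit: 2x CPU count, at most 16, at most one per file.
    max_workers = min(multiprocessing.cpu_count() * 2, 16, len(file_names))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map returns a lazy iterator; list() drains it so that any
        # exception raised in a worker is re-raised here, in input order.
        return list(executor.map(prefetch_one_file, file_names))
```

Threads suit this workload because blocking file reads release the GIL, so the reads can overlap without spawning child processes. The commit title attributes the potential hang to Python multiprocessing; staying within a single process avoids starting child processes from a parent that may already be running threads, which is one plausible source of such a hang, though the diff itself does not state the root cause.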
