33import functools
44import logging
55import os
6+ import shutil
67import subprocess
78import sys
89from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union , overload
@@ -926,19 +927,98 @@ def args_bounds_check(
926927 return args [i ] if len (args ) > i and args [i ] is not None else replacement
927928
928929
930+ def install_wget (platform : str ) -> None :
931+ if shutil .which ("wget" ):
932+ _LOGGER .debug ("wget is already installed" )
933+ return
934+ if platform .startswith ("linux" ):
935+ try :
936+ # if its root
937+ if os .geteuid () == 0 :
938+ subprocess .run (["apt-get" , "update" ], check = True )
939+ subprocess .run (["apt-get" , "install" , "-y" , "wget" ], check = True )
940+ else :
941+ _LOGGER .debug ("Please run with sudo permissions" )
942+ subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
943+ subprocess .run (["sudo" , "apt-get" , "install" , "-y" , "wget" ], check = True )
944+ except subprocess .CalledProcessError as e :
945+ _LOGGER .debug ("Error installing wget:" , e )
946+
947+
948+ def install_mpi (platform : str ) -> None :
949+ if platform .startswith ("linux" ):
950+ try :
951+ # if its root
952+ if os .geteuid () == 0 :
953+ subprocess .run (["apt-get" , "update" ], check = True )
954+ subprocess .run (["apt-get" , "install" , "-y" , "libmpich-dev" ], check = True )
955+ subprocess .run (
956+ ["apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
957+ )
958+ else :
959+ _LOGGER .debug ("Please run with sudo permissions" )
960+ subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
961+ subprocess .run (
962+ ["sudo" , "apt-get" , "install" , "-y" , "libmpich-dev" ], check = True
963+ )
964+ subprocess .run (
965+ ["sudo" , "apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
966+ )
967+ except subprocess .CalledProcessError as e :
968+ _LOGGER .debug ("Error installing mpi libs:" , e )
969+
970+
971+ def download_plugin_lib_path (py_version : str , platform : str ) -> str :
972+ plugin_lib_path = None
973+ if py_version not in ("cp310" , "cp312" ):
974+ _LOGGER .warning (
975+ "No available wheel for python versions other than py3.10 and py3.12"
976+ )
977+ install_wget (platform )
978+ base_url = "https://pypi.nvidia.com/tensorrt-llm/"
979+ file_name = f"tensorrt_llm-0.17.0.post1-{ py_version } -{ py_version } -{ platform } .whl"
980+ download_url = base_url + file_name
981+ cmd = ["wget" , download_url ]
982+ try :
983+ if not (os .path .exists (file_name )):
984+ _LOGGER .info (f"Running command: { ' ' .join (cmd )} " )
985+ subprocess .run (cmd )
986+ _LOGGER .info ("Download complete of wheel" )
987+ if os .path .exists (file_name ):
988+ _LOGGER .info ("filename now present" )
989+ if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
990+ plugin_lib_path = (
991+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
992+ )
993+ else :
994+ import zipfile
995+
996+ with zipfile .ZipFile (file_name , "r" ) as zip_ref :
997+ zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
998+ plugin_lib_path = (
999+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1000+ )
1001+ except subprocess .CalledProcessError as e :
1002+ _LOGGER .debug (f"Error occurred while trying to download: { e } " )
1003+ except Exception as e :
1004+ _LOGGER .debug (f"An unexpected error occurred: { e } " )
1005+ return plugin_lib_path
1006+
1007+
9291008def load_tensorrt_llm () -> bool :
9301009 """
9311010 Attempts to load the TensorRT-LLM plugin and initialize it.
9321011
9331012 Returns:
9341013 bool: True if the plugin was successfully loaded and initialized, False otherwise.
9351014 """
936-
9371015 plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
9381016 if not plugin_lib_path :
9391017 _LOGGER .warning (
9401018 "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library" ,
9411019 )
1020+ for key , value in os .environ .items ():
1021+ print (f"{ key } : { value } " )
9421022 use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
9431023 "1" ,
9441024 "true" ,
@@ -953,38 +1033,12 @@ def load_tensorrt_llm() -> bool:
9531033 else :
9541034 py_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
9551035 platform = Platform .current_platform ()
956- if Platform == Platform .LINUX_X86_64 :
957- platform = "linux_x86_64"
958- elif Platform == Platform .LINUX_AARCH64 :
959- platform = "linux_aarch64"
960-
961- if py_version not in ("cp310" , "cp312" ):
962- _LOGGER .warning (
963- "No available wheel for python versions other than py3.10 and py3.12"
964- )
965- if py_version == "cp310" and platform == "linux_aarch64" :
966- _LOGGER .warning ("No available wheel for python3.10 with Linux aarch64" )
9671036
968- base_url = "https://pypi.nvidia.com/tensorrt-llm/"
969- file_name = (
970- "tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
971- )
972- download_url = base_url + file_name
973- cmd = ["wget" , download_url ]
974- subprocess .run (cmd )
975- if os .path .exists (file_name ):
976- _LOGGER .info ("filename download is completed" )
977- import zipfile
978-
979- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
980- zip_ref .extractall (
981- "./tensorrt_llm"
982- ) # Extract to a folder named 'tensorrt_llm'
983- plugin_lib_path = (
984- "./tensorrt_llm" + "libnvinfer_plugin_tensorrt_llm.so"
985- )
1037+ platform = str (platform ).lower ()
1038+ plugin_lib_path = download_plugin_lib_path (py_version , platform )
9861039 try :
987- # Load the shared library
1040+ # Load the shared
1041+ install_mpi (platform )
9881042 handle = ctypes .CDLL (plugin_lib_path )
9891043 _LOGGER .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
9901044 except OSError as e_os_error :
0 commit comments