33import functools
44import logging
55import os
6+ import shutil
67import subprocess
78import sys
89from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union , overload
@@ -926,19 +927,100 @@ def args_bounds_check(
926927 return args [i ] if len (args ) > i and args [i ] is not None else replacement
927928
928929
def install_wget(platform: str) -> None:
    """Best-effort installation of ``wget`` through ``apt-get``.

    Nothing is done when ``wget`` is already on ``PATH`` or when ``platform``
    does not start with ``"linux"``. Otherwise ``apt-get update`` and
    ``apt-get install -y wget`` are run, prefixed with ``sudo`` when the
    current process is not running as root.

    Failures are logged and swallowed; a failed install never raises.

    Args:
        platform: Platform identifier. Only values starting with ``"linux"``
            trigger an installation attempt.
    """
    if shutil.which("wget"):
        _LOGGER.debug("wget is already installed")
        return
    if not platform.startswith("linux"):
        return

    # Root can call apt-get directly; any other user has to go through sudo.
    prefix = [] if os.geteuid() == 0 else ["sudo"]
    if prefix:
        _LOGGER.debug("Please run with sudo permissions")

    try:
        subprocess.run([*prefix, "apt-get", "update"], check=True)
        subprocess.run([*prefix, "apt-get", "install", "-y", "wget"], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing apt-get/sudo executable, which would
        # otherwise escape this best-effort helper as FileNotFoundError.
        _LOGGER.warning("Error installing wget: %s", e)
946+
947+
def install_mpi(platform: str) -> None:
    """Best-effort installation of the MPI development libraries via ``apt-get``.

    When ``platform`` starts with ``"linux"``, runs ``apt-get update`` and then
    installs ``libmpich-dev`` followed by ``libopenmpi-dev``, one ``apt-get``
    invocation per package. Commands are prefixed with ``sudo`` when the
    current process is not running as root. On any other platform the
    function returns without doing anything.

    Failures are logged and swallowed; a failed install never raises. The
    first failing command aborts the remaining ones.

    Args:
        platform: Platform identifier. Only values starting with ``"linux"``
            trigger an installation attempt.
    """
    if not platform.startswith("linux"):
        return

    # Root can call apt-get directly; any other user has to go through sudo.
    prefix = [] if os.geteuid() == 0 else ["sudo"]
    if prefix:
        _LOGGER.debug("Please run with sudo permissions")

    try:
        subprocess.run([*prefix, "apt-get", "update"], check=True)
        for package in ("libmpich-dev", "libopenmpi-dev"):
            subprocess.run(
                [*prefix, "apt-get", "install", "-y", package], check=True
            )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing apt-get/sudo executable, which would
        # otherwise escape this best-effort helper as FileNotFoundError.
        _LOGGER.warning("Error installing mpi libs: %s", e)
969+
970+
def download_plugin_lib_path(py_version: str, platform: str) -> Optional[str]:
    """Locate the TensorRT-LLM plugin library, downloading the wheel if needed.

    The library is looked up at
    ``./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so`` relative to the
    current working directory. If it is already there, its path is returned
    immediately. Otherwise the ``tensorrt_llm`` 0.17.0.post1 wheel for the given
    Python tag and platform is fetched with ``wget`` (unless a file with the
    wheel's name already exists in the current directory) and extracted into
    the current directory.

    Args:
        py_version: CPython tag used in the wheel file name, e.g. ``"cp310"``.
            A warning is logged for tags other than ``"cp310"``/``"cp312"``,
            but the download is still attempted.
        platform: Platform tag used in the wheel file name.

    Returns:
        The relative path of the plugin library, or ``None`` when downloading
        or extracting the wheel failed.
    """
    import zipfile

    lib_path = "./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"

    if py_version not in ("cp310", "cp312"):
        _LOGGER.warning(
            "No available wheel for python versions other than py3.10 and py3.12"
        )

    # A previous call may already have extracted the library; in that case
    # there is no reason to install wget or download the wheel again.
    if os.path.exists(lib_path):
        return lib_path

    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
    file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
    _LOGGER.debug("Wheel file name: %s", file_name)

    try:
        if not os.path.exists(file_name):
            # wget is only required when something actually has to be fetched.
            install_wget(platform)
            cmd = ["wget", base_url + file_name]
            _LOGGER.info(f"Running command: {' '.join(cmd)}")
            # check=True so a failed download is reported rather than being
            # logged as a completed one.
            subprocess.run(cmd, check=True)
            _LOGGER.info("Download complete of wheel")

        # A wheel is a zip archive; extracting into "." creates ./tensorrt_llm/.
        with zipfile.ZipFile(file_name, "r") as zip_ref:
            zip_ref.extractall(".")
    except subprocess.CalledProcessError as e:
        _LOGGER.warning(f"Error occurred while trying to download: {e}")
        return None
    except Exception as e:
        # Deliberately broad: this is a best-effort helper and callers only
        # expect a path or None (covers missing wget, bad zip, I/O errors).
        _LOGGER.warning(f"An unexpected error occurred: {e}")
        return None

    return lib_path
1007+
1008+
9291009def load_tensorrt_llm () -> bool :
9301010 """
9311011 Attempts to load the TensorRT-LLM plugin and initialize it.
9321012
9331013 Returns:
9341014 bool: True if the plugin was successfully loaded and initialized, False otherwise.
9351015 """
936-
1016+ print ( "coming to check load_tensorrt_llm!!!!" )
9371017 plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
9381018 if not plugin_lib_path :
9391019 _LOGGER .warning (
9401020 "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library" ,
9411021 )
1022+ for key , value in os .environ .items ():
1023+ print (f"{ key } : { value } " )
9421024 use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
9431025 "1" ,
9441026 "true" ,
@@ -953,38 +1035,12 @@ def load_tensorrt_llm() -> bool:
9531035 else :
9541036 py_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
9551037 platform = Platform .current_platform ()
956- if Platform == Platform .LINUX_X86_64 :
957- platform = "linux_x86_64"
958- elif Platform == Platform .LINUX_AARCH64 :
959- platform = "linux_aarch64"
960-
961- if py_version not in ("cp310" , "cp312" ):
962- _LOGGER .warning (
963- "No available wheel for python versions other than py3.10 and py3.12"
964- )
965- if py_version == "cp310" and platform == "linux_aarch64" :
966- _LOGGER .warning ("No available wheel for python3.10 with Linux aarch64" )
9671038
968- base_url = "https://pypi.nvidia.com/tensorrt-llm/"
969- file_name = (
970- "tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
971- )
972- download_url = base_url + file_name
973- cmd = ["wget" , download_url ]
974- subprocess .run (cmd )
975- if os .path .exists (file_name ):
976- _LOGGER .info ("filename download is completed" )
977- import zipfile
978-
979- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
980- zip_ref .extractall (
981- "./tensorrt_llm"
982- ) # Extract to a folder named 'tensorrt_llm'
983- plugin_lib_path = (
984- "./tensorrt_llm" + "libnvinfer_plugin_tensorrt_llm.so"
985- )
1039+ platform = str (platform ).lower ()
1040+ plugin_lib_path = download_plugin_lib_path (py_version , platform )
9861041 try :
987- # Load the shared library
1042+ # Load the shared
1043+ install_mpi (platform )
9881044 handle = ctypes .CDLL (plugin_lib_path )
9891045 _LOGGER .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
9901046 except OSError as e_os_error :
0 commit comments