From abaa4af5aa5a456410270d32c7e0ff50e30261e5 Mon Sep 17 00:00:00 2001
From: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Date: Fri, 17 Oct 2025 12:14:43 +0800
Subject: [PATCH] [https://nvbugs/5437384][test] fix trtllm-llmapi-launch multi
 tests with single launch (#8397)

Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
---
 tensorrt_llm/llmapi/mpi_session.py            | 101 ++++++++++++++++--
 .../test_lists/test-db/l0_dgx_h100.yml        |   2 +
 tests/unittest/llmapi/_run_multi_llm_tasks.py |  33 ++++++
 .../llmapi/_run_multi_mpi_comm_tasks.py       |  43 ++++++++
 tests/unittest/llmapi/test_mpi_session.py     |  58 ++++++++++
 5 files changed, 231 insertions(+), 6 deletions(-)
 create mode 100644 tests/unittest/llmapi/_run_multi_llm_tasks.py
 create mode 100644 tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py

diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py
index b5417f71241..f0275d7f90a 100644
--- a/tensorrt_llm/llmapi/mpi_session.py
+++ b/tensorrt_llm/llmapi/mpi_session.py
@@ -48,6 +48,10 @@ def task():
     '''
     state = None
 
+    # Global MPICommExecutor instance to be reused across multiple MpiCommSession instances
+    # This is necessary because MPICommExecutor can only be created once per MPI process
+    _global_comm_executor = None
+    _global_mpi_pool = None
 
     @staticmethod
     def is_initialized() -> bool:
@@ -183,6 +187,7 @@ def __init__(self, comm=None, n_workers: int = 1):
         self.n_workers = n_workers
         self.thread_pool: Optional[ThreadPoolExecutor] = None
         self.mpi_pool: Optional[MPIPoolExecutor] = None
+        self.owns_mpi_pool = False  # Track if this instance owns the mpi_pool
 
         if n_workers <= 0:
             raise ValueError(
@@ -230,9 +235,11 @@ def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
         return [future.result() for future in futures]
 
     def shutdown(self, wait=True):
-        if self.mpi_pool is not None:
+        # Only shutdown the mpi_pool if this instance created it
+        # For shared global mpi_pool, we don't shut it down
+        if self.mpi_pool is not None and self.owns_mpi_pool:
             self.mpi_pool.shutdown(wait=wait)
-            self.mpi_pool = None
+        self.mpi_pool = None
         if self.thread_pool is not None:
             self.thread_pool.shutdown(wait=wait)
             self.thread_pool = None
@@ -244,8 +251,36 @@ def _start_mpi_pool(self):
         assert not self.mpi_pool, 'MPI session already started'
 
         self.thread_pool = ThreadPoolExecutor(max_workers=2)
-        comm_executor = MPICommExecutor(self.comm)
-        self.mpi_pool = comm_executor.__enter__()
+
+        # Use global MPICommExecutor if using COMM_WORLD
+        # This is necessary because MPICommExecutor can only be created once per MPI process
+        logger_debug(
+            f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
+            "grey")
+        if ENABLE_MULTI_DEVICE:
+            logger_debug(
+                f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
+                "grey")
+        if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
+            if MPINodeState._global_comm_executor is None:
+                logger_debug("Creating global MPICommExecutor for COMM_WORLD\n",
+                             "yellow")
+                MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
+                MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__(
+                )
+            else:
+                logger_debug("Reusing global MPICommExecutor for COMM_WORLD\n",
+                             "yellow")
+            self.mpi_pool = MPINodeState._global_mpi_pool
+            self.owns_mpi_pool = False
+        else:
+            logger_debug(
+                f"_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
+                "grey")
+            # For non-COMM_WORLD communicators, create a new executor
+            comm_executor = MPICommExecutor(self.comm)
+            self.mpi_pool = comm_executor.__enter__()
+            self.owns_mpi_pool = True
 
     def __del__(self):
         self.shutdown_abort()
@@ -264,9 +299,35 @@ class RemoteTask(NamedTuple):
 class RemoteMpiCommSessionClient(MpiSession):
     '''
     RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.
+
+    Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
+    one connection at a time. Multiple LLM instances will reuse the same client connection.
     '''
+    _global_instance = None
+    _global_instance_lock = threading.Lock()
+
+    def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
+        # Implement singleton pattern to reuse the same client connection
+        # for multiple LLM instances, since PAIR sockets only support one connection
+        with cls._global_instance_lock:
+            if cls._global_instance is None or cls._global_instance.addr != addr:
+                logger_debug(
+                    f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+                instance = super().__new__(cls)
+                cls._global_instance = instance
+                instance._initialized = False
+            else:
+                logger_debug(
+                    f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+            return cls._global_instance
 
     def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
+        # Only initialize once
+        if self._initialized:
+            return
+
         # FIXME: this is a hack to avoid circular import, resolve later
         from tensorrt_llm.executor.ipc import ZeroMqQueue
         self.addr = addr
@@ -277,6 +338,7 @@ def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
                                  socket_type=zmq.PAIR,
                                  use_hmac_encryption=bool(hmac_key))
         self._is_shutdown = False
+        self._initialized = True
 
     def submit(self,
                task: Callable[..., T],
@@ -329,10 +391,16 @@ def abort(self):
         self.shutdown()
 
     def shutdown(self, wait=True):
-        pass
+        # NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
+        # The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
+        # LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
+        # using it. The connection stays open for the entire lifetime of the mgmn setup.
+        logger_debug(
+            f"RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
+            "grey")
 
     def shutdown_abort(self, grace: float = 60, reason=None):
-        pass
+        self.shutdown()
 
 
 class RemoteMpiCommSessionServer():
@@ -393,7 +461,26 @@ def task_wrapper(task: Callable[..., T], *args, **kwargs) -> T:
     def serve(self):
         logger_debug(f"RemoteMpiCommSessionServer listening on {self.addr}\n",
                      "yellow")
+        pending_futures = []
         while True:
+            # Wait for any pending futures from previous tasks to complete
+            # This ensures all ranks are ready before accepting the next task
+            if pending_futures:
+                logger_debug(
+                    f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
+                    "grey")
+                for future in pending_futures:
+                    try:
+                        future.result()  # Wait for completion
+                    except Exception as e:
+                        print_colored(
+                            f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
+                            "red")
+                pending_futures.clear()
+                logger_debug(
+                    "RemoteMpiCommSessionServer all pending futures completed\n",
+                    "grey")
+
             message: Optional[RemoteTask] = self.queue.get()
             if message is None:
                 logger_debug(
@@ -410,6 +497,8 @@ def serve(self):
                                               *message.args, **message.kwargs)
                 self.num_results = self.session.n_workers
                 assert len(futures) == self.num_results == mpi_world_size()
+                # Store futures to wait for them before the next task
+                pending_futures = list(futures)
                 if message.sync:
                     for future in futures:
                         future.add_done_callback(self.mpi_future_callback)
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 1518613c1d0..89054801879 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -43,6 +43,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
+  # llmapi
+  - unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
 - condition:
     ranges:
       system_gpu_count:
diff --git a/tests/unittest/llmapi/_run_multi_llm_tasks.py b/tests/unittest/llmapi/_run_multi_llm_tasks.py
new file mode 100644
index 00000000000..50b75d14c6e
--- /dev/null
+++ b/tests/unittest/llmapi/_run_multi_llm_tasks.py
@@ -0,0 +1,33 @@
+import os
+import sys
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi.utils import print_colored
+
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.llm_data import llm_models_root
+# isort: on
+
+model_path = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"
+
+
+def run_llm_tp2():
+    with LLM(model=model_path, tensor_parallel_size=2) as llm:
+        sampling_params = SamplingParams(max_tokens=10, end_id=-1)
+        for output in llm.generate(["Hello, my name is"], sampling_params):
+            print(output)
+
+
+def run_multi_llm_tasks():
+    for i in range(3):
+        print_colored(f"Running LLM task {i}\n", "green")
+        run_llm_tp2()
+        print_colored(f"LLM task {i} completed\n", "green")
+
+
+if __name__ == "__main__":
+    run_multi_llm_tasks()
diff --git a/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py
new file mode 100644
index 00000000000..5b50df94f2d
--- /dev/null
+++ b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py
@@ -0,0 +1,43 @@
+import os
+from typing import Literal
+
+import click
+
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
+from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
+from tensorrt_llm.llmapi.utils import print_colored
+
+
+def run_task(task_type: Literal["submit", "submit_sync"]):
+    tasks = range(10)
+    assert os.environ[
+        LlmLauncherEnvs.
+        TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
+    client = RemoteMpiCommSessionClient(
+        os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR])
+
+    for task in tasks:
+        if task_type == "submit":
+            client.submit(print_colored, f"{task}\n", "green")
+        elif task_type == "submit_sync":
+            res = client.submit_sync(print_colored, f"{task}\n", "green")
+            print(res)
+
+
+def run_multi_tasks(task_type: Literal["submit", "submit_sync"]):
+    for i in range(3):
+        print_colored(f"Running MPI comm task {i}\n", "green")
+        run_task(task_type)
+        print_colored(f"MPI comm task {i} completed\n", "green")
+
+
+@click.command()
+@click.option("--task_type",
+              type=click.Choice(["submit", "submit_sync"]),
+              default="submit")
+def main(task_type: Literal["submit", "submit_sync"]):
+    run_multi_tasks(task_type)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/unittest/llmapi/test_mpi_session.py b/tests/unittest/llmapi/test_mpi_session.py
index bedce258c26..eac5d877b05 100644
--- a/tests/unittest/llmapi/test_mpi_session.py
+++ b/tests/unittest/llmapi/test_mpi_session.py
@@ -5,6 +5,8 @@
 from subprocess import PIPE, Popen
 from typing import Literal
 
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
 import pytest
 
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
@@ -12,6 +14,11 @@
                                             RemoteMpiCommSessionClient,
                                             split_mpi_env)
 
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.util import skip_single_gpu
+# isort: on
+
 
 def task0():
     if MPINodeState.state is None:
@@ -108,3 +115,54 @@ def task1():
 def test_split_mpi_env():
     session = MpiPoolSession(n_workers=4)
     session.submit_sync(task1)
+
+
+@skip_single_gpu
+@pytest.mark.parametrize(
+    "task_script", ["_run_mpi_comm_task.py", "_run_multi_mpi_comm_tasks.py"])
+def test_llmapi_launch_multiple_tasks(task_script: str):
+    """
+    Test that the trtllm-llmapi-launch can run multiple tasks.
+    """
+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    test_file = os.path.join(cur_dir, "_run_multi_llm_tasks.py")
+    assert os.path.exists(test_file), f"Test file {test_file} does not exist"
+    command = [
+        "mpirun", "-n", "2", "--allow-run-as-root", "trtllm-llmapi-launch",
+        "python3", test_file
+    ]
+    print(' '.join(command))
+
+    with Popen(command,
+               env=os.environ,
+               stdout=PIPE,
+               stderr=PIPE,
+               bufsize=1,
+               start_new_session=True,
+               universal_newlines=True,
+               cwd=os.path.dirname(os.path.abspath(__file__))) as process:
+        # Function to read from a stream and write to output
+        def read_stream(stream, output_stream):
+            for line in stream:
+                output_stream.write(line)
+                output_stream.flush()
+
+        # Create threads to read stdout and stderr concurrently
+        stdout_thread = threading.Thread(target=read_stream,
+                                         args=(process.stdout, sys.stdout))
+        stderr_thread = threading.Thread(target=read_stream,
+                                         args=(process.stderr, sys.stderr))
+
+        # Start both threads
+        stdout_thread.start()
+        stderr_thread.start()
+
+        # Wait for the process to complete
+        return_code = process.wait()
+
+        # Wait for both threads to finish reading
+        stdout_thread.join()
+        stderr_thread.join()
+
+        if return_code != 0:
+            raise subprocess.CalledProcessError(return_code, command)
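
Note on the RemoteMpiCommSessionClient singleton introduced above: Python still calls __init__ on every RemoteMpiCommSessionClient(addr) construction, even when __new__ hands back the cached instance, which is why the patch pairs the cached __new__ with an _initialized guard in __init__. The toy sketch below isolates that reuse-or-create pattern on a stand-in class; the names are illustrative only and assume nothing beyond the Python standard library, so it is not TensorRT-LLM API.

import threading


class _SharedClient:
    """Toy stand-in for the reuse-or-create pattern used by RemoteMpiCommSessionClient."""

    _global_instance = None
    _global_instance_lock = threading.Lock()

    def __new__(cls, addr: str):
        with cls._global_instance_lock:
            # Reuse the cached instance unless the address changed.
            if cls._global_instance is None or cls._global_instance.addr != addr:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._global_instance = instance
            return cls._global_instance

    def __init__(self, addr: str):
        # __init__ runs on every construction; the guard makes it a no-op after the first.
        if self._initialized:
            return
        self.addr = addr  # a real client would open its single PAIR connection here
        self._initialized = True


if __name__ == "__main__":
    a = _SharedClient("ipc:///tmp/proxy")
    b = _SharedClient("ipc:///tmp/proxy")
    assert a is b  # the second "LLM instance" reuses the same connection

The same guard explains why RemoteMpiCommSessionClient.shutdown() is a no-op in the patch: the cached connection has to outlive any single LLM instance for the lifetime of one trtllm-llmapi-launch run.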