101 changes: 95 additions & 6 deletions tensorrt_llm/llmapi/mpi_session.py
@@ -48,6 +48,10 @@ def task():
'''

state = None
# Global MPICommExecutor instance to be reused across multiple MpiCommSession instances
# This is necessary because MPICommExecutor can only be created once per MPI process
_global_comm_executor = None
_global_mpi_pool = None

@staticmethod
def is_initialized() -> bool:
@@ -183,6 +187,7 @@ def __init__(self, comm=None, n_workers: int = 1):
self.n_workers = n_workers
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.mpi_pool: Optional[MPIPoolExecutor] = None
self.owns_mpi_pool = False # Track if this instance owns the mpi_pool

if n_workers <= 0:
raise ValueError(
@@ -230,9 +235,11 @@ def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
return [future.result() for future in futures]

def shutdown(self, wait=True):
if self.mpi_pool is not None:
# Only shutdown the mpi_pool if this instance created it
# For shared global mpi_pool, we don't shut it down
if self.mpi_pool is not None and self.owns_mpi_pool:
self.mpi_pool.shutdown(wait=wait)
self.mpi_pool = None
self.mpi_pool = None
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=wait)
self.thread_pool = None
@@ -244,8 +251,36 @@ def _start_mpi_pool(self):
assert not self.mpi_pool, 'MPI session already started'

self.thread_pool = ThreadPoolExecutor(max_workers=2)
comm_executor = MPICommExecutor(self.comm)
self.mpi_pool = comm_executor.__enter__()

# Use global MPICommExecutor if using COMM_WORLD
# This is necessary because MPICommExecutor can only be created once per MPI process
logger_debug(
f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
"grey")
if ENABLE_MULTI_DEVICE:
logger_debug(
f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
"grey")
if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
if MPINodeState._global_comm_executor is None:
logger_debug("Creating global MPICommExecutor for COMM_WORLD\n",
"yellow")
MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__(
)
else:
logger_debug("Reusing global MPICommExecutor for COMM_WORLD\n",
"yellow")
self.mpi_pool = MPINodeState._global_mpi_pool
self.owns_mpi_pool = False
else:
logger_debug(
f"_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
"grey")
# For non-COMM_WORLD communicators, create a new executor
comm_executor = MPICommExecutor(self.comm)
self.mpi_pool = comm_executor.__enter__()
self.owns_mpi_pool = True

def __del__(self):
self.shutdown_abort()
@@ -264,9 +299,35 @@ class RemoteTask(NamedTuple):
class RemoteMpiCommSessionClient(MpiSession):
'''
RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.

Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
one connection at a time. Multiple LLM instances will reuse the same client connection.
'''
_global_instance = None
_global_instance_lock = threading.Lock()

def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
# Implement singleton pattern to reuse the same client connection
# for multiple LLM instances, since PAIR sockets only support one connection
with cls._global_instance_lock:
if cls._global_instance is None or cls._global_instance.addr != addr:
logger_debug(
f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
"yellow")
instance = super().__new__(cls)
cls._global_instance = instance
instance._initialized = False
else:
logger_debug(
f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
"yellow")
return cls._global_instance

def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
# Only initialize once
if self._initialized:
return

# FIXME: this is a hack to avoid circular import, resolve later
from tensorrt_llm.executor.ipc import ZeroMqQueue
self.addr = addr
@@ -277,6 +338,7 @@ def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
socket_type=zmq.PAIR,
use_hmac_encryption=bool(hmac_key))
self._is_shutdown = False
self._initialized = True

def submit(self,
task: Callable[..., T],
@@ -329,10 +391,16 @@ def abort(self):
self.shutdown()

def shutdown(self, wait=True):
pass
# NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
# The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
# LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
# using it. The connection stays open for the entire lifetime of the mgmn setup.
logger_debug(
f"RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
"grey")

def shutdown_abort(self, grace: float = 60, reason=None):
pass
self.shutdown()


class RemoteMpiCommSessionServer():
@@ -393,7 +461,26 @@ def task_wrapper(task: Callable[..., T], *args, **kwargs) -> T:
def serve(self):
logger_debug(f"RemoteMpiCommSessionServer listening on {self.addr}\n",
"yellow")
pending_futures = []
while True:
# Wait for any pending futures from previous tasks to complete
# This ensures all ranks are ready before accepting the next task
if pending_futures:
logger_debug(
f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
"grey")
for future in pending_futures:
try:
future.result() # Wait for completion
except Exception as e:
print_colored(
f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
"red")
pending_futures.clear()
logger_debug(
"RemoteMpiCommSessionServer all pending futures completed\n",
"grey")

message: Optional[RemoteTask] = self.queue.get()
if message is None:
logger_debug(
@@ -410,6 +497,8 @@ def serve(self):
*message.args, **message.kwargs)
self.num_results = self.session.n_workers
assert len(futures) == self.num_results == mpi_world_size()
# Store futures to wait for them before the next task
pending_futures = list(futures)
if message.sync:
for future in futures:
future.add_done_callback(self.mpi_future_callback)
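The reuse logic in this file leans on an mpi4py constraint: an MPICommExecutor bound to COMM_WORLD should only be entered once per process, so the pool has to be cached and shared rather than recreated for every session. Below is a minimal standalone sketch of that caching idea, not the PR's code; the function and variable names are illustrative, and it assumes mpi4py is installed and the script is launched under mpirun.

# Minimal sketch: cache a COMM_WORLD executor at module scope so repeated
# session start-ups reuse one pool instead of re-entering MPICommExecutor.
from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

_cached_executor = None  # context-manager object, kept open for the process lifetime
_cached_pool = None  # usable executor on the root rank, None elsewhere


def get_comm_world_pool():
    """Create the COMM_WORLD executor once, then hand back the cached pool."""
    global _cached_executor, _cached_pool
    if _cached_executor is None:
        _cached_executor = MPICommExecutor(MPI.COMM_WORLD, root=0)
        # Non-root ranks block inside __enter__ serving tasks; the root gets the pool.
        _cached_pool = _cached_executor.__enter__()
    return _cached_pool


if __name__ == "__main__":
    pool = get_comm_world_pool()
    if pool is not None:  # only the root rank receives a usable executor
        print(list(pool.map(abs, [-3, -2, -1])))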
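The RemoteMpiCommSessionClient.__new__ override above applies the same reuse idea at the IPC level: because a ZeroMQ PAIR socket accepts exactly one peer, the client is an address-keyed singleton shared by all LLM instances in the process. A stripped-down sketch of that pattern follows; the class and attribute names are hypothetical and the real socket wiring is omitted.

# Minimal sketch of an address-keyed singleton via __new__, mirroring the
# RemoteMpiCommSessionClient pattern; names here are illustrative only.
import threading


class SingleConnectionClient:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, addr: str):
        with cls._lock:
            # Reuse the existing instance unless the target address changed.
            if cls._instance is None or cls._instance.addr != addr:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, addr: str):
        # __init__ runs on every construction, so guard against re-initialization.
        if self._initialized:
            return
        self.addr = addr  # the real client opens its ZeroMQ PAIR queue here
        self._initialized = True


a = SingleConnectionClient("ipc:///tmp/proxy")
b = SingleConnectionClient("ipc:///tmp/proxy")
assert a is b  # same address -> the connection is shared, not reopened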
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -43,6 +43,8 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
# llmapi
- unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
- condition:
ranges:
system_gpu_count:
33 changes: 33 additions & 0 deletions tests/unittest/llmapi/_run_multi_llm_tasks.py
@@ -0,0 +1,33 @@
import os
import sys

cur_dir = os.path.dirname(os.path.abspath(__file__))

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.utils import print_colored

# isort: off
sys.path.append(os.path.join(cur_dir, '..'))
from utils.llm_data import llm_models_root
# isort: on

model_path = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"


def run_llm_tp2():
with LLM(model=model_path, tensor_parallel_size=2) as llm:
sampling_params = SamplingParams(max_tokens=10, end_id=-1)
for output in llm.generate(["Hello, my name is"], sampling_params):
print(output)


def run_multi_llm_tasks():
for i in range(3):
print_colored(f"Running LLM task {i}\n", "green")
run_llm_tp2()
print_colored(f"LLM task {i} completed\n", "green")


if __name__ == "__main__":
run_multi_llm_tasks()
43 changes: 43 additions & 0 deletions tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py
@@ -0,0 +1,43 @@
import os
from typing import Literal

import click

from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
from tensorrt_llm.llmapi.utils import print_colored


def run_task(task_type: Literal["submit", "submit_sync"]):
tasks = range(10)
assert os.environ.get(
LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR
) is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
client = RemoteMpiCommSessionClient(
os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR])

for task in tasks:
if task_type == "submit":
client.submit(print_colored, f"{task}\n", "green")
elif task_type == "submit_sync":
res = client.submit_sync(print_colored, f"{task}\n", "green")
print(res)


def run_multi_tasks(task_type: Literal["submit", "submit_sync"]):
for i in range(3):
print_colored(f"Running MPI comm task {i}\n", "green")
run_task(task_type)
print_colored(f"MPI comm task {i} completed\n", "green")


@click.command()
@click.option("--task_type",
type=click.Choice(["submit", "submit_sync"]),
default="submit")
def main(task_type: Literal["submit", "submit_sync"]):
run_multi_tasks(task_type)


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions tests/unittest/llmapi/test_mpi_session.py
@@ -5,13 +5,20 @@
from subprocess import PIPE, Popen
from typing import Literal

cur_dir = os.path.dirname(os.path.abspath(__file__))

import pytest

from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.llmapi.mpi_session import (MPINodeState, MpiPoolSession,
RemoteMpiCommSessionClient,
split_mpi_env)

# isort: off
sys.path.append(os.path.join(cur_dir, '..'))
from utils.util import skip_single_gpu
# isort: on


def task0():
if MPINodeState.state is None:
@@ -108,3 +115,54 @@ def task1():
def test_split_mpi_env():
session = MpiPoolSession(n_workers=4)
session.submit_sync(task1)


@skip_single_gpu
@pytest.mark.parametrize(
"task_script",
["_run_mpi_comm_task.py", "_run_multi_mpi_comm_tasks.py", "_run_multi_llm_tasks.py"])
def test_llmapi_launch_multiple_tasks(task_script: str):
"""
Test that the trtllm-llmapi-launch can run multiple tasks.
"""
cur_dir = os.path.dirname(os.path.abspath(__file__))
test_file = os.path.join(cur_dir, task_script)
assert os.path.exists(test_file), f"Test file {test_file} does not exist"
command = [
"mpirun", "-n", "2", "--allow-run-as-root", "trtllm-llmapi-launch",
"python3", test_file
]
print(' '.join(command))

with Popen(command,
env=os.environ,
stdout=PIPE,
stderr=PIPE,
bufsize=1,
start_new_session=True,
universal_newlines=True,
cwd=os.path.dirname(os.path.abspath(__file__))) as process:
# Function to read from a stream and write to output
def read_stream(stream, output_stream):
for line in stream:
output_stream.write(line)
output_stream.flush()

# Create threads to read stdout and stderr concurrently
stdout_thread = threading.Thread(target=read_stream,
args=(process.stdout, sys.stdout))
stderr_thread = threading.Thread(target=read_stream,
args=(process.stderr, sys.stderr))

# Start both threads
stdout_thread.start()
stderr_thread.start()

# Wait for the process to complete
return_code = process.wait()

# Wait for both threads to finish reading
stdout_thread.join()
stderr_thread.join()

if return_code != 0:
raise subprocess.CalledProcessError(return_code, command)