25 changes: 12 additions & 13 deletions tensorrt_llm/executor/executor.py
@@ -21,9 +21,11 @@
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import TorchLlmArgs
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
need_spawn_mpi_workers)
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
enable_worker_single_process_for_tp1, print_colored,
print_colored_debug)
@@ -354,7 +356,9 @@ def create(
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy
@@ -381,6 +385,9 @@
"engine": engine,
"executor_config": executor_config,
"batched_logits_processor": batched_logits_processor,
"hf_model_dir": hf_model_dir,
"tokenizer": tokenizer,
"llm_args": llm_args,
}

if lora_config:
@@ -398,9 +405,7 @@
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)

# WAR: For the performance of gathering logits, we use single process worker
# for TP1 to avoid the large overhead of IPC.
@@ -411,9 +416,7 @@
"Using single process worker for TP1, this may hurt streaming generation performance."
)
return GenerationExecutorWorker(**worker_kwargs,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)

# For single-gpu case:
# Partition the workload to multiple process for streaming performance.
@@ -425,9 +428,7 @@
model_world_size=model_world_size,
mpi_session=None, # use mpi4py
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)
else:
ctx = multiprocessing.get_context("spawn")
# The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -438,9 +439,7 @@
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)

def wait_first_completed(
self, futures: List[GenerationResult]
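Review note: `garbage_collection_gen0_threshold` is removed from `create()` and from every proxy/worker construction in this file; the value now travels inside `llm_args`, next to the new `hf_model_dir` and `tokenizer` entries of `worker_kwargs`. A hedged call-site sketch follows; the `engine`/`executor_config` keywords and the import path are assumed from the surrounding context, and the checkpoint directory is hypothetical.

```python
from pathlib import Path

from tensorrt_llm.executor import GenerationExecutor  # import path assumed

my_tokenizer = ...  # assumed TokenizerBase prepared by the LLM API
my_llm_args = ...   # assumed TorchLlmArgs; carries garbage_collection_gen0_threshold

executor = GenerationExecutor.create(
    engine=None,                    # assumed unused on the PyTorch path
    executor_config=None,           # the worker derives one from llm_args
    is_llm_executor=True,
    hf_model_dir=Path("/path/to/hf/checkpoint"),  # hypothetical location
    tokenizer=my_tokenizer,
    llm_args=my_llm_args,
)
```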
9 changes: 4 additions & 5 deletions tensorrt_llm/executor/proxy.py
@@ -45,7 +45,6 @@ def __init__(
worker_cls: type = GenerationExecutorWorker,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
) -> None:
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
)
@@ -87,14 +86,14 @@ def __init__(

self.model_world_size = model_world_size

self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
self.garbage_collection_gen0_threshold = worker_kwargs[
"llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
"llm_args", None) is not None else None

worker_kwargs = dict(**worker_kwargs,
worker_queues=self._setup_queues(),
postproc_worker_config=postproc_worker_config,
is_llm_executor=False,
garbage_collection_gen0_threshold=self.
garbage_collection_gen0_threshold)
is_llm_executor=False)

if "log_level" not in worker_kwargs:
worker_kwargs["log_level"] = logger.level
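Review note: the proxy no longer takes the threshold as its own keyword; it reads it from `worker_kwargs["llm_args"]` when one is present. A minimal, self-contained sketch of that lookup, using a stand-in dataclass instead of the real `TorchLlmArgs`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeLlmArgs:
    # Stand-in for TorchLlmArgs; only the field the proxy reads.
    garbage_collection_gen0_threshold: Optional[int] = None


def resolve_gen0_threshold(worker_kwargs: dict) -> Optional[int]:
    # Mirrors the proxy's logic: use llm_args when supplied, otherwise fall
    # back to None (the TensorRT-engine path passes no llm_args).
    llm_args = worker_kwargs.get("llm_args")
    return llm_args.garbage_collection_gen0_threshold if llm_args is not None else None


print(resolve_gen0_threshold({"llm_args": FakeLlmArgs(20000)}))  # 20000
print(resolve_gen0_threshold({"llm_args": None}))                # None
print(resolve_gen0_threshold({}))                                # None
```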
101 changes: 64 additions & 37 deletions tensorrt_llm/executor/worker.py
@@ -18,8 +18,9 @@
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import PybindMirror
from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug,
@@ -60,7 +61,9 @@ def __init__(
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
) -> None:
postproc_config = postproc_worker_config or PostprocWorkerConfig()
super().__init__(
@@ -81,54 +84,49 @@
self._await_response_helper = AwaitResponseHelper(
self) # TODO: make it weakref
self._executor_config = executor_config
self._is_pytorch_backend = getattr(self._executor_config, "backend",
None) == "pytorch"
self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
self.llm_args = llm_args

if global_mpi_size() > 1:
logger.set_rank(self.global_rank)

if isinstance(engine, list):
engine = engine[self.rank]

if executor_config is None:
executor_config = tllm.ExecutorConfig(1)

executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)

def _create_engine():
def _get_comm_ranks_device_id():
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)

# Make sure C++ executor would use same devices/ranks as py_executor
global_rank = global_mpi_rank()
comm_ranks = mpi_comm().allgather(global_rank)
device_ids = mpi_comm().allgather(device_id)
return comm_ranks, device_ids

def _create_py_executor(executor_config):
assert executor_config is None, "expect an empty executor_config in _create_py_executor"
executor_config = llm_args.get_executor_config(
hf_model_dir, tokenizer)
# Persist so downstream code (e.g., default max_tokens deduction) has access
self._executor_config = executor_config
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)

if isinstance(engine, Engine):
return tllm.Executor(engine.engine,
json.dumps(engine.config.to_dict(),
cls=ConfigEncoder),
tllm.ModelType.DECODER_ONLY,
executor_config=executor_config,
managed_weights=engine.managed_weights)

if not hasattr(executor_config, "backend"):
return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
executor_config)
args = {
"executor_config": executor_config,
"checkpoint_dir": executor_config.hf_model_dir,
}
assert hasattr(
executor_config, "backend"
), "executor_config should be with backend in _create_py_executor"
if executor_config.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["lora_config"] = lora_config
args[
"garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
"garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
elif executor_config.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
@@ -138,7 +136,30 @@ def _create_engine():
f"Unsupported backend config: {executor_config.backend}")
return create_executor(**args)

self.engine = _create_engine()
def _create_engine(executor_config):
if executor_config is None:
executor_config = tllm.ExecutorConfig(1)
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)

if isinstance(engine, Engine):
return tllm.Executor(engine.engine,
json.dumps(engine.config.to_dict(),
cls=ConfigEncoder),
tllm.ModelType.DECODER_ONLY,
executor_config=executor_config,
managed_weights=engine.managed_weights)

assert not hasattr(executor_config, "backend")
return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
executor_config)

self.engine = _create_py_executor(
executor_config) if llm_args is not None else _create_engine(
executor_config)

self._lora_manager: Optional[LoraManager] = None
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
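Review note: engine construction now forks on `llm_args`. With `llm_args` set, `_create_py_executor` asserts the incoming `executor_config` is `None` and builds one via `llm_args.get_executor_config(hf_model_dir, tokenizer)`; otherwise `_create_engine` keeps the original C++ `tllm.Executor` path. A condensed, stand-alone sketch of that dispatch (the returned strings are placeholders for the real executor objects):

```python
from typing import Any, Optional


def build_engine(llm_args: Optional[Any], executor_config: Optional[Any],
                 engine: Any) -> str:
    """Toy model of the dispatch performed in GenerationExecutorWorker.__init__."""
    if llm_args is not None:
        # PyTorch / autodeploy path: start from an empty config and derive it
        # from llm_args, hf_model_dir and the tokenizer.
        assert executor_config is None, "llm_args path expects no prebuilt config"
        return "py_executor created from llm_args"
    # TensorRT path: default the config if missing, then wrap the prebuilt engine.
    config = executor_config if executor_config is not None else "ExecutorConfig(1)"
    return f"tllm.Executor for {engine!r} with {config!r}"
```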
@@ -161,7 +182,7 @@ def _create_engine():
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()

if getattr(executor_config, "backend",
if getattr(self._executor_config, "backend",
"") == "pytorch" and lora_config is not None:
from tensorrt_llm._torch.pyexecutor.resource_manager import \
ResourceManagerType
@@ -430,14 +451,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
context_phase_params = request.disaggregated_params.get_context_phase_params(
)

is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
if is_overlap_enabled:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (
request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
raise ValueError(
"Context only requests are not supported in pytorch backend when overlap is enabled."
)
if self._is_pytorch_backend:
assert isinstance(self.llm_args, TorchLlmArgs)
if not self.llm_args.disable_overlap_scheduler:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (
request_type
== tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
raise ValueError(
"Context only requests are not supported in pytorch backend when overlap is enabled."
)

assert request.id is not None

@@ -641,7 +664,9 @@ def worker_main(
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
) -> None:
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +793,9 @@ def notify_proxy_threads_to_quit():
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
lora_config=lora_config,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
llm_args=llm_args)
except Exception as e:
logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
logger.error(traceback.format_exc())