11 changes: 11 additions & 0 deletions components/src/dynamo/trtllm/constants.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from enum import Enum
+
+
+class DisaggregationMode(Enum):
+    AGGREGATED = "prefill_and_decode"
+    PREFILL = "prefill"
+    DECODE = "decode"
+    ENCODE = "encode"
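A note on intended use: Enum lookup is by value, so config strings map directly onto members. A minimal sketch of wiring a CLI flag onto this enum (the --disaggregation-mode flag name and the argparse setup are illustrative assumptions, not code from this PR):

# Hypothetical wiring; the flag name and parser are assumptions for illustration.
import argparse

from dynamo.trtllm.constants import DisaggregationMode

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disaggregation-mode",
    default=DisaggregationMode.AGGREGATED.value,
    choices=[m.value for m in DisaggregationMode],
)
args = parser.parse_args()

# Lookup by value: "encode" -> DisaggregationMode.ENCODE
mode = DisaggregationMode(args.disaggregation_mode)
if mode == DisaggregationMode.ENCODE:
    print("This worker runs the encoder path.")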
175 changes: 125 additions & 50 deletions components/src/dynamo/trtllm/encode_helper.py
@@ -2,11 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+from dataclasses import asdict
 from typing import Any, Dict, Union
 
 import torch
+from tensorrt_llm.inputs import default_multimodal_input_loader
 
 import dynamo.nixl_connect as nixl_connect
+from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec
 
 
 class EncodeHelper:
@@ -185,10 +188,14 @@ async def read_embeddings_from_encode_response(
         return encodings_tensor
 
     @staticmethod
-    async def process_embedding_request(
+    async def process_encode_request(
         request: Dict[str, Any],
         multimodal_processor,
         connector: nixl_connect.Connector,
+        tokenizer=None,
+        model_dir=None,
+        model_type=None,
+        engine=None,
     ):
         """
         Process embedding request by loading embeddings and creating NIXL readable operation.
@@ -203,57 +210,125 @@ async def process_embedding_request(
         """
         # Load embeddings first to get the actual shape
         messages = request.get("messages", [])
-        _, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages)
-
-        if not embedding_paths:
-            # Placeholder for TRTLLM Encoder to be called
-            # TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings
-            logging.warning(
-                "No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet"
+        (
+            text_prompt,
+            image_urls,
+            embedding_paths,
+        ) = multimodal_processor.extract_prompt_and_media(messages)
+        if embedding_paths:
+            loaded_data = multimodal_processor.load_tensor_from_path_or_url(
+                embedding_paths[0]
             )
-            yield {"error": "No embedding paths found"}
-            return
+            # Handle both tensor and dictionary formats
+            if isinstance(loaded_data, dict):
+                # Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
+                encodings = loaded_data.get("mm_embeddings")
+                if encodings is None:
+                    yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
+                    return
 
-        # Load the embeddings data
-        loaded_data = multimodal_processor.load_tensor_from_path_or_url(
-            embedding_paths[0]
-        )
-
-        # Handle both tensor and dictionary formats
-        if isinstance(loaded_data, dict):
-            # Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
-            encodings = loaded_data.get("mm_embeddings")
-            if encodings is None:
-                yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
-                return
+                # Store auxiliary data for later transmission
+                auxiliary_data = {
+                    k: v for k, v in loaded_data.items() if k != "mm_embeddings"
+                }
+            else:
+                # Tensor format (e.g., llava_next_mm_embed_seashore.pt)
+                encodings = loaded_data
+                auxiliary_data = {}
 
-            # Store auxiliary data for later transmission
-            auxiliary_data = {
-                k: v for k, v in loaded_data.items() if k != "mm_embeddings"
-            }
-        else:
-            # Tensor format (e.g., llava_next_mm_embed_seashore.pt)
-            encodings = loaded_data
-            auxiliary_data = {}
+            # Create readable operation with main embeddings tensor (works for both formats)
+            descriptor = nixl_connect.Descriptor(encodings)
+            with connector.create_readable(descriptor) as readable_op:
+                # Get the metadata for the readable operation
+                op_metadata = readable_op.metadata()
 
-        # Create readable operation with main embeddings tensor (works for both formats)
-        descriptor = nixl_connect.Descriptor(encodings)
-        with connector.create_readable(descriptor) as readable_op:
-            # Get the metadata for the readable operation
-            op_metadata = readable_op.metadata()
+                # Send back shape info, readable metadata, and serialized auxiliary data
+                response = {
+                    "nixl_readable_metadata": op_metadata.model_dump(),
+                    "embeddings_shape": list(encodings.shape),
+                    "embeddings_dtype": str(encodings.dtype),
+                    "auxiliary_data": EncodeHelper.serialize_tensor_dict(
+                        auxiliary_data
+                    ),
+                }
+                yield response
 
-            # Send back shape info, readable metadata, and serialized auxiliary data
-            response = {
-                "nixl_readable_metadata": op_metadata.model_dump(),
-                "embeddings_shape": list(encodings.shape),
-                "embeddings_dtype": str(encodings.dtype),
-                "auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
-            }
-            yield response
-
-            # Wait for the prefill worker to complete the read operation
-            logging.debug(
-                "EncodeHelper waiting for PrefillHandler to read embeddings..."
-            )
-            await readable_op.wait_for_completion()
-            logging.debug("EncodeHelper completed readable operation.")
+                # Wait for the prefill worker to complete the read operation
+                logging.debug(
+                    "EncodeHelper waiting for PrefillHandler to read embeddings..."
+                )
+                await readable_op.wait_for_completion()
+                logging.debug("EncodeHelper completed readable operation.")
+        else:
+            logging.info(
+                "========== ENCODE WORKER: Full EPD - Using MultimodalEncoder =========="
+            )
+            inputs = default_multimodal_input_loader(
+                tokenizer=tokenizer,
+                model_dir=model_dir,
+                model_type=model_type,
+                modality="image",
+                prompts=[text_prompt],
+                media=image_urls[0],
+            )
+            # engine.llm is the MultimodalEncoder instance
+            # MultimodalEncoder.generate() returns a list of GenerationResult objects
+            encoder_outputs = list(engine.llm.generate(inputs))
+            if not encoder_outputs:
+                logging.error("ENCODE WORKER: encoder_outputs is empty")
+                yield {"ep_disaggregated_params": None}
+                return
+
+            ep_disaggregated_params = encoder_outputs[0].disaggregated_params
+            if ep_disaggregated_params is None:
+                logging.error(
+                    "ENCODE WORKER: encoder_outputs[0].disaggregated_params is None"
+                )
+                yield {"ep_disaggregated_params": None}
+                return
+
+            if (
+                hasattr(ep_disaggregated_params, "multimodal_embedding_handles")
+                and ep_disaggregated_params.multimodal_embedding_handles
+            ):
+                logging.info(
+                    f"ENCODE WORKER: Generated {len(ep_disaggregated_params.multimodal_embedding_handles)} embedding handle(s)"
+                )
+            else:
+                logging.warning(
+                    "ENCODE WORKER: ep_disaggregated_params has no multimodal_embedding_handles"
+                )
+            # Prepare for Network Transfer
+            encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params)
+            params_dict = asdict(encoded_params)
+            # Also send the processed prompt (which includes <image> tokens)
+            # default_multimodal_input_loader returns a list, get the first element
+            processed_prompt = None
+            prompt_token_ids = None
+
+            if isinstance(inputs, list) and len(inputs) > 0:
+                first_input = inputs[0]
+                if isinstance(first_input, dict):
+                    processed_prompt = first_input.get("prompt")
+                else:
+                    processed_prompt = getattr(first_input, "prompt", None)
+
+            # Tokenize the processed prompt for prefill worker
+            if processed_prompt and tokenizer is not None:
+                prompt_token_ids = tokenizer.encode(
+                    processed_prompt, add_special_tokens=False
+                )
+                logging.info(
+                    f"ENCODE WORKER: Tokenized processed_prompt (length={len(prompt_token_ids)})"
+                )
+
+            logging.info(
+                f"ENCODE WORKER: Extracted processed_prompt: {processed_prompt}"
+            )
+
+            yield {
+                "ep_disaggregated_params": params_dict,
+                "processed_prompt": processed_prompt,  # Prompt with <image> tokens
+                "prompt_token_ids": prompt_token_ids,  # Token IDs for consistency
+            }
+            return
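process_encode_request now yields one of two response shapes: the pre-computed-embeddings path responds with nixl_readable_metadata plus tensor shape/dtype, while the full-EPD path responds with ep_disaggregated_params plus the processed prompt and token IDs. A minimal sketch of a caller that tells them apart (the handler function itself is an assumption for illustration; only the response keys come from the code above):

# Hypothetical caller; the response keys match process_encode_request above.
from dynamo.trtllm.encode_helper import EncodeHelper

async def handle_encode_request(request, multimodal_processor, connector,
                                tokenizer, model_dir, model_type, engine):
    async for response in EncodeHelper.process_encode_request(
        request,
        multimodal_processor,
        connector,
        tokenizer=tokenizer,
        model_dir=model_dir,
        model_type=model_type,
        engine=engine,
    ):
        if "nixl_readable_metadata" in response:
            # Pre-computed embeddings: the prefill worker uses this metadata
            # to read a tensor of embeddings_shape / embeddings_dtype via NIXL.
            yield response
        elif response.get("ep_disaggregated_params") is not None:
            # Full EPD: encoder embedding handles plus processed_prompt and
            # prompt_token_ids for the prefill worker.
            yield response
        else:
            # Error dict, or a failed encode ({"ep_disaggregated_params": None}).
            yield response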
37 changes: 27 additions & 10 deletions components/src/dynamo/trtllm/engine.py
@@ -3,25 +3,40 @@
 
 import logging
 from contextlib import asynccontextmanager
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Optional, Union
 
-from tensorrt_llm import LLM
+from tensorrt_llm import LLM, MultimodalEncoder
+
+from dynamo.trtllm.constants import DisaggregationMode
 
 logging.basicConfig(level=logging.DEBUG)
 
 
 class TensorRTLLMEngine:
-    def __init__(self, engine_args):
+    def __init__(self, engine_args, disaggregation_mode: DisaggregationMode):
         self.engine_args = engine_args
         self._llm: Optional[LLM] = None
+        self._disaggregation_mode = disaggregation_mode
 
     async def initialize(self):
         if not self._llm:
             model = self.engine_args.pop("model")
-            self._llm = LLM(
-                model=model,
-                **self.engine_args,
-            )
+            if self._disaggregation_mode == DisaggregationMode.ENCODE:
+                # Initialize the multimodal encoder for full EPD
+                max_batch_size = self.engine_args.pop("max_batch_size", 1)
+                logging.info(
+                    f"Initializing multimodal encoder with max_batch_size: {max_batch_size}"
+                )
+                self._llm = MultimodalEncoder(
+                    model=model,
+                    max_batch_size=max_batch_size,
+                )
+            else:
+                # Initialize the regular LLM for decode-only or prefill-decode
+                self._llm = LLM(
+                    model=model,
+                    **self.engine_args,
+                )
 
     async def cleanup(self):
         if self._llm:
@@ -33,15 +48,17 @@ async def cleanup(self):
             self._llm = None
 
     @property
-    def llm(self):
+    def llm(self) -> Union[LLM, MultimodalEncoder]:
         if not self._llm:
             raise RuntimeError("Engine not initialized")
         return self._llm
 
 
 @asynccontextmanager
-async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
-    engine = TensorRTLLMEngine(engine_args)
+async def get_llm_engine(
+    engine_args, disaggregation_mode
+) -> AsyncGenerator[TensorRTLLMEngine, None]:
+    engine = TensorRTLLMEngine(engine_args, disaggregation_mode)
     try:
         await engine.initialize()
         yield engine
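Call sites now pick the engine class by mode while keeping the same context-manager shape; a minimal usage sketch (the engine_args values here are placeholders):

# Minimal sketch of the new call shape; engine_args contents are placeholders.
import asyncio

from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import get_llm_engine

async def run_encode_worker():
    engine_args = {"model": "/path/to/model", "max_batch_size": 4}
    async with get_llm_engine(engine_args, DisaggregationMode.ENCODE) as engine:
        # engine.llm is a MultimodalEncoder here; any other mode yields an LLM.
        print(type(engine.llm).__name__)

asyncio.run(run_encode_worker())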
14 changes: 8 additions & 6 deletions components/src/dynamo/trtllm/main.py
@@ -33,7 +33,7 @@
 from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
 from tensorrt_llm.metrics import MetricsCollector
 from torch.cuda import device_count
-from transformers import AutoConfig
+from transformers import AutoConfig, GenerationConfig
 
 import dynamo.nixl_connect as nixl_connect
 from dynamo.common.config_dump import dump_config
@@ -239,8 +239,13 @@ async def init(runtime: DistributedRuntime, config: Config):
 
     # Populate default sampling params from the model
     tokenizer = tokenizer_factory(arg_map["model"])
+
+    model_config = AutoConfig.from_pretrained(arg_map["model"], trust_remote_code=True)
+    generation_config = GenerationConfig.from_pretrained(
+        arg_map["model"], trust_remote_code=True
+    )
     default_sampling_params = SamplingParams()
-    default_sampling_params._setup(tokenizer)
+    default_sampling_params._setup(tokenizer, model_config, generation_config)
     default_sampling_params.stop = None
     model_input = ModelInput.Tokens
 
@@ -262,9 +267,6 @@ async def init(runtime: DistributedRuntime, config: Config):
     if modality == "multimodal":
         engine_args["skip_tokenizer_init"] = False
         model_input = ModelInput.Text
-        model_config = AutoConfig.from_pretrained(
-            config.model_path, trust_remote_code=True
-        )
         multimodal_processor = MultimodalRequestProcessor(
             model_type=model_config.model_type,
             model_dir=config.model_path,
@@ -286,7 +288,7 @@
         config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
     )
 
-    async with get_llm_engine(engine_args) as engine:
+    async with get_llm_engine(engine_args, config.disaggregation_mode) as engine:
         endpoint = component.endpoint(config.endpoint)
 
         # should ideally call get_engine_runtime_config
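One caveat on the sampling-params change: GenerationConfig.from_pretrained raises rather than returning defaults when a checkpoint ships no generation_config.json, which may be worth verifying against the models this worker targets. A guarded variant, as an illustration rather than a change this PR necessarily needs:

# Illustrative fallback, not part of this PR: guard against checkpoints
# that ship no generation_config.json.
from transformers import GenerationConfig

model_id = "/path/to/model"  # placeholder for arg_map["model"] above

try:
    generation_config = GenerationConfig.from_pretrained(
        model_id, trust_remote_code=True
    )
except OSError:
    # No generation_config.json in the checkpoint; use library defaults.
    generation_config = GenerationConfig()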