2 changes: 2 additions & 0 deletions examples/auto_deploy/.gitignore
@@ -2,3 +2,5 @@
 !.vscode
 benchmark_results.json
 *.png
+# ignore config files that users might put here for debugging
+*.yaml
56 changes: 45 additions & 11 deletions examples/auto_deploy/build_and_run_ad.py
@@ -26,6 +26,9 @@
 # Global torch config, set the torch compile cache to fix up to llama 405B
 torch._dynamo.config.cache_size_limit = 20
 
+# simple string, TRT-LLM style text-only prompt or full-scale HF message template
+PromptInput = Union[str, Dict, List[Dict]]
+
 
 class PromptConfig(BaseModel):
     """Prompt configuration.
@@ -35,17 +38,27 @@ class PromptConfig(BaseModel):
     """
 
     batch_size: int = Field(default=2, description="Number of queries")
-    queries: Union[str, List[str]] = Field(
+    queries: Union[PromptInput, List[PromptInput]] = Field(
         default_factory=lambda: [
+            # OPTION 1: simple text prompt
             "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-        ]
+            # OPTION 2: wrapped text prompt for TRT-LLM
+            {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
+            # OPTION 3: a full-scale HF message template (this one works for text-only models!)
+            # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
+            # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
+            [
+                {
+                    "role": "user",
+                    "content": "How to fix slicing in golf?",
+                }
+            ],
+            # More prompts...
+            {"prompt": "Where is the capital of Iceland? "},
+        ],
+        description="Example queries to prompt the model with. We support both TRT-LLM text-only "
+        "queries via the 'prompt' key and full-scale HF message template called via "
+        "apply_chat_template.",
     )
     sp_kwargs: Dict[str, Any] = Field(
         default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
Expand All @@ -59,10 +72,28 @@ def model_post_init(self, __context: Any):
NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
validators are only run if a value is provided.
"""
queries = [self.queries] if isinstance(self.queries, str) else self.queries
queries = self.queries if isinstance(self.queries, list) else [self.queries]
batch_size = self.batch_size
queries = queries * (batch_size // len(queries) + 1)
self.queries = queries[:batch_size]
queries = queries[:batch_size]

# now let's standardize the queries for the LLM api to understand them
queries_processed = []
for query in queries:
if isinstance(query, str):
queries_processed.append({"prompt": query})
elif isinstance(query, dict):
queries_processed.append(query)
elif isinstance(query, list):
queries_processed.append(
{
"prompt": "Fake prompt. Check out messages field for the HF chat template.",
"messages": query, # contains the actual HF chat template
}
)
else:
raise ValueError(f"Invalid query type: {type(query)}")
self.queries = queries_processed

@field_validator("sp_kwargs", mode="after")
@classmethod
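
Reviewer note: a minimal sketch (not part of the diff) of how the new normalization in model_post_init behaves. It uses only PromptConfig as modified above, and the expected output is read directly off the loop logic:

# Sketch: exercising the new PromptConfig normalization (assumes the imports
# already present in build_and_run_ad.py).
cfg = PromptConfig(
    batch_size=3,
    queries=[
        "How big is the universe? ",  # OPTION 1: plain string
        {"prompt": "Explain gravity in one sentence: "},  # OPTION 2: TRT-LLM dict
        [{"role": "user", "content": "How to fix slicing in golf?"}],  # OPTION 3: HF messages
    ],
)
# model_post_init pads/truncates to batch_size, then wraps every entry in a dict:
#   {'prompt': 'How big is the universe? '}
#   {'prompt': 'Explain gravity in one sentence: '}
#   {'prompt': 'Fake prompt. Check out messages field for the HF chat template.',
#    'messages': [{'role': 'user', 'content': 'How to fix slicing in golf?'}]}
print(cfg.queries)
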
@@ -239,6 +270,9 @@ def main(config: Optional[ExperimentConfig] = None):
 
     # prompt the model and print its output
     ad_logger.info("Running example prompts...")
+
+    # now let's try piping through multimodal data
+
     outs = llm.generate(
         config.prompt.queries,
         sampling_params=SamplingParams(**config.prompt.sp_kwargs),
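Usage note (a sketch, not from the diff): the normalized queries flow straight into the LLM API, so generation over mixed prompt types looks the same as before. This assumes an LLM instance llm constructed as in main(), and result objects that expose generated text via .outputs[0].text, as used elsewhere in this example:

# Sketch: generating with the normalized queries from the PromptConfig above.
outs = llm.generate(
    cfg.queries,
    sampling_params=SamplingParams(**cfg.sp_kwargs),
)
for out in outs:
    # assumption: each result is a RequestOutput-style object
    print(out.outputs[0].text)
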
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
@@ -7,6 +7,7 @@
 from .linear import *
 from .mla import *
 from .quant import *
+from .qwen_ops import *
 from .rms_norm import *
 from .torch_attention import *
 from .torch_backend_attention import *