@@ -0,0 +1,77 @@
# Serving with trtllm-serve

AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate with a simple request.

## Quick start

Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`:

```bash
trtllm-serve \
  meta-llama/Llama-3.1-8B-Instruct \
  --backend _autodeploy
```

- `model`: Hugging Face model name or local checkpoint path
- `--backend _autodeploy`: runs the model through the AutoDeploy runtime
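
Startup can take a while for large models. If your build of `trtllm-serve` exposes the standard `/health` endpoint (an assumption here; check your version), you can poll it to know when the server is ready. A minimal sketch using the third-party `requests` package:

```python
import time

import requests  # third-party HTTP client, assumed installed

# Poll the (assumed) /health endpoint until the server reports ready.
url = "http://localhost:8000/health"
for _ in range(120):
    try:
        if requests.get(url, timeout=2).status_code == 200:
            print("server is ready")
            break
    except requests.ConnectionError:
        pass  # server not accepting connections yet
    time.sleep(5)
```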

Once the server is ready, test with an OpenAI-compatible request:

```bash
curl -s http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Where is New York? Tell me in a single sentence."}
    ],
    "max_tokens": 32
  }'
```
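
Because the endpoint is OpenAI-compatible, the standard `openai` Python client can also be pointed at it. A minimal sketch, assuming the quick-start server is running on `localhost:8000` and no API key is enforced (the key below is a placeholder):

```python
from openai import OpenAI  # pip install openai

# Point the client at the local trtllm-serve instance; the API key is a
# placeholder since the local server does not authenticate requests.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Where is New York? Tell me in a single sentence."},
    ],
    max_tokens=32,
)
print(response.choices[0].message.content)
```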

## Configuration via YAML

Use `--extra_llm_api_options` to supply a YAML file that augments or overrides server/runtime settings.

```bash
trtllm-serve \
  meta-llama/Llama-3.1-8B \
  --backend _autodeploy \
  --extra_llm_api_options autodeploy_config.yaml
```

Example `autodeploy_config.yaml`:

```yaml
# Compilation backend for AutoDeploy
compile_backend: torch-opt # options: torch-simple, torch-compile, torch-cudagraph, torch-opt

# Runtime engine
runtime: trtllm # options: trtllm, demollm

# Model loading
skip_loading_weights: false # set true for architecture-only perf runs

# KV cache memory
free_mem_ratio: 0.8 # fraction of free GPU mem for KV cache

# CUDA graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]

# Attention backend
attn_backend: flashinfer # recommended for best performance
```
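
The same options can also be passed programmatically if you embed AutoDeploy through the LLM API instead of serving over HTTP. A rough sketch, assuming the keyword names match the YAML fields above (check `AutoDeployConfig` in `tensorrt_llm/_torch/auto_deploy/llm_args.py` for the authoritative set):

```python
from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM

# Mirrors autodeploy_config.yaml; adjust values to your hardware and traffic.
llm = AutoDeployLLM(
    model="meta-llama/Llama-3.1-8B",
    compile_backend="torch-opt",
    runtime="trtllm",
    skip_loading_weights=False,
    free_mem_ratio=0.8,
    cuda_graph_batch_sizes=[1, 2, 4, 8, 16, 32, 64],
    attn_backend="flashinfer",
)

for output in llm.generate(["Where is New York?"]):
    print(output.outputs[0].text)
```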

## Limitations and tips

- KV cache block reuse is disabled automatically for the AutoDeploy backend
- Disaggregated serving is not yet supported by the AutoDeploy backend (work in progress)
- For best performance:
  - Prefer `compile_backend: torch-opt`
  - Use `attn_backend: flashinfer`
  - Set `cuda_graph_batch_sizes` to match the batch sizes you expect in production traffic
  - Tune `free_mem_ratio` to 0.8–0.9

## See also

- [AutoDeploy overview](../auto-deploy.md)
- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md)
1 change: 1 addition & 0 deletions docs/source/torch/auto_deploy/auto-deploy.md
@@ -59,6 +59,7 @@ The exported graph then undergoes a series of automated transformations, includi
- [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md)
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)

## Roadmap

@@ -198,7 +198,6 @@ def prepare_flashinfer_metadata(
        flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size),
        position_ids.numel(),
    )

    # return metadata
    return (
        qo_indptr,
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -274,6 +274,16 @@ def quant_config(self, value: QuantConfig):
        self._quant_config = value

    ### VALIDATION #################################################################################
    @field_validator("max_seq_len", mode="before")
    @classmethod
    def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
        if value is None:
            # Fallback to the AutoDeployConfig default when not provided
            return AutoDeployConfig.model_fields["max_seq_len"].get_default(
                call_default_factory=True
            )
        return value

    @field_validator("build_config", mode="before")
    @classmethod
    def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
24 changes: 17 additions & 7 deletions tensorrt_llm/commands/serve.py
@@ -14,6 +14,7 @@
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@@ -109,7 +110,7 @@ def get_llm_args(model: str,
        capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
        dynamic_batch_config=dynamic_batch_config,
    )

    backend = backend if backend in ["pytorch", "_autodeploy"] else None
    llm_args = {
        "model":
        model,
@@ -140,7 +141,7 @@
"kv_cache_config":
kv_cache_config,
"backend":
backend if backend == "pytorch" else None,
backend,
"num_postprocess_workers":
num_postprocess_workers,
"postprocess_tokenizer_dir":
@@ -162,9 +163,15 @@ def launch_server(host: str,

    backend = llm_args["backend"]
    model = llm_args["model"]

    if backend == 'pytorch':
        llm = PyTorchLLM(**llm_args)
    elif backend == '_autodeploy':
        # AutoDeploy does not support build_config
        llm_args.pop("build_config", None)
        # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142):
        # AutoDeploy does not support cache reuse yet.
        llm_args["kv_cache_config"].enable_block_reuse = False
        llm = AutoDeployLLM(**llm_args)
    else:
        llm = LLM(**llm_args)

@@ -204,10 +211,13 @@ def launch_mm_encoder_server(
default="localhost",
help="Hostname of the server.")
@click.option("--port", type=int, default=8000, help="Port of the server.")
@click.option("--backend",
type=click.Choice(["pytorch", "trt"]),
default="pytorch",
help="Set to 'pytorch' for pytorch path. Default is cpp path.")
@click.option(
"--backend",
type=click.Choice(["pytorch", "trt", "_autodeploy"]),
default="pytorch",
help=
"Set to 'pytorch' for pytorch path and '_autodeploy' for autodeploy path. Default is pytorch path."
)
@click.option('--log_level',
type=click.Choice(severity_map.keys()),
default='info',
2 changes: 2 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -30,6 +30,8 @@ def get_default_kwargs(self):
        return {
            'skip_tokenizer_init': False,
            'trust_remote_code': True,
            # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142):
            # AutoDeploy does not support cache reuse yet.
            'kv_cache_config': {
                'enable_block_reuse': False,
            },