diff --git a/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md
new file mode 100644
index 00000000000..5a73d047ea4
--- /dev/null
+++ b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md
@@ -0,0 +1,77 @@
+# Serving with trtllm-serve
+
+AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate the deployment with a simple request.
+
+## Quick start
+
+Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`:
+
+```bash
+trtllm-serve \
+  meta-llama/Llama-3.1-8B-Instruct \
+  --backend _autodeploy
+```
+
+- `model`: Hugging Face model name or local path to a checkpoint
+- `--backend _autodeploy`: selects the AutoDeploy runtime
+
+Once the server is ready, test it with an OpenAI-compatible request:
+
+```bash
+curl -s http://localhost:8000/v1/chat/completions \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "messages":[{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Where is New York? Tell me in a single sentence."}],
+    "max_tokens": 32
+  }'
+```
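+
+Because the endpoint is OpenAI-compatible, any OpenAI client library can be used for a scripted check. The snippet below is a minimal sketch, not part of the `trtllm-serve` API itself: it assumes the optional `openai` Python package is installed, that the server is listening on the default `localhost:8000`, and that no authentication is enforced (the API key value only satisfies the client constructor).
+
+```python
+# Minimal client-side sketch (assumes `pip install openai` and a running server)
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
+response = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Where is New York? Tell me in a single sentence."},
+    ],
+    max_tokens=32,
+)
+print(response.choices[0].message.content)
+```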
+
+## Configuration via YAML
+
+Use `--extra_llm_api_options` to supply a YAML file that augments or overrides server/runtime settings.
+
+```bash
+trtllm-serve \
+  meta-llama/Llama-3.1-8B \
+  --backend _autodeploy \
+  --extra_llm_api_options autodeploy_config.yaml
+```
+
+Example `autodeploy_config.yaml`:
+
+```yaml
+# Compilation backend for AutoDeploy
+compile_backend: torch-opt  # options: torch-simple, torch-compile, torch-cudagraph, torch-opt
+
+# Runtime engine
+runtime: trtllm  # options: trtllm, demollm
+
+# Model loading
+skip_loading_weights: false  # set true for architecture-only perf runs
+
+# KV cache memory
+free_mem_ratio: 0.8  # fraction of free GPU mem for KV cache
+
+# CUDA graph optimization
+cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
+
+# Attention backend
+attn_backend: flashinfer  # recommended for best performance
+```
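+
+When embedding AutoDeploy through the LLM API instead of `trtllm-serve`, the same options can be passed programmatically using the `AutoDeployLLM` import that `trtllm-serve` itself uses. The snippet below is an illustrative sketch only: it assumes the YAML keys above map one-to-one onto keyword arguments of the AutoDeploy `LLM` class, so treat the exact argument names as an assumption rather than a guarantee.
+
+```python
+# Hypothetical programmatic equivalent of autodeploy_config.yaml (argument names assumed)
+from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM
+
+llm = AutoDeployLLM(
+    model="meta-llama/Llama-3.1-8B",
+    compile_backend="torch-opt",
+    runtime="trtllm",
+    skip_loading_weights=False,
+    free_mem_ratio=0.8,
+    cuda_graph_batch_sizes=[1, 2, 4, 8, 16, 32, 64],
+    attn_backend="flashinfer",
+)
+outputs = llm.generate(["Where is New York? Tell me in a single sentence."])
+print(outputs[0].outputs[0].text)
+```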
+
+## Limitations and tips
+
+- KV cache block reuse is disabled automatically for the AutoDeploy backend
+- The AutoDeploy backend does not yet support disaggregated serving (work in progress)
+- For best performance:
+  - Prefer `compile_backend: torch-opt`
+  - Use `attn_backend: flashinfer`
+  - Set realistic `cuda_graph_batch_sizes` that match expected traffic
+  - Tune `free_mem_ratio` to 0.8–0.9
+
+## See also
+
+- [AutoDeploy overview](../auto-deploy.md)
+- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md)
diff --git a/docs/source/torch/auto_deploy/auto-deploy.md b/docs/source/torch/auto_deploy/auto-deploy.md
index fc00c0ccc3e..185e1f321ae 100644
--- a/docs/source/torch/auto_deploy/auto-deploy.md
+++ b/docs/source/torch/auto_deploy/auto-deploy.md
@@ -59,6 +59,7 @@ The exported graph then undergoes a series of automated transformations, includi
 - [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md)
 - [Expert Configurations](./advanced/expert_configurations.md)
 - [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
+- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)

 ## Roadmap

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
index 414039a5065..01fb0deb576 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
@@ -198,7 +198,6 @@ def prepare_flashinfer_metadata(
         flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size),
         position_ids.numel(),
     )

-    # return metadata
     return (
         qo_indptr,
diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py
index 812dfea29cd..9811274a8bc 100644
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -274,6 +274,16 @@ def quant_config(self, value: QuantConfig):
         self._quant_config = value

     ### VALIDATION #################################################################################
+    @field_validator("max_seq_len", mode="before")
+    @classmethod
+    def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
+        if value is None:
+            # Fall back to the AutoDeployConfig default when no value is provided
+            return AutoDeployConfig.model_fields["max_seq_len"].get_default(
+                call_default_factory=True
+            )
+        return value
+
     @field_validator("build_config", mode="before")
     @classmethod
     def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 07eb13d7968..c1013eb3c5c 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -14,6 +14,7 @@
 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@@ -109,7 +110,7 @@ def get_llm_args(model: str,
         capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         dynamic_batch_config=dynamic_batch_config,
     )
-
+    backend = backend if backend in ["pytorch", "_autodeploy"] else None
     llm_args = {
         "model":
         model,
@@ -140,7 +141,7 @@ def get_llm_args(model: str,
         "kv_cache_config":
         kv_cache_config,
         "backend":
-        backend if backend == "pytorch" else None,
+        backend,
         "num_postprocess_workers":
         num_postprocess_workers,
         "postprocess_tokenizer_dir":
@@ -162,9 +163,15 @@ def launch_server(host: str,

     backend = llm_args["backend"]
     model = llm_args["model"]
-
     if backend == 'pytorch':
         llm = PyTorchLLM(**llm_args)
+    elif backend == '_autodeploy':
+        # AutoDeploy does not support build_config
+        llm_args.pop("build_config", None)
+        # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142):
+        # AutoDeploy does not support cache reuse yet.
+        llm_args["kv_cache_config"].enable_block_reuse = False
+        llm = AutoDeployLLM(**llm_args)
     else:
         llm = LLM(**llm_args)

@@ -204,10 +211,13 @@ def launch_mm_encoder_server(
               default="localhost",
               help="Hostname of the server.")
 @click.option("--port", type=int, default=8000, help="Port of the server.")
-@click.option("--backend",
-              type=click.Choice(["pytorch", "trt"]),
-              default="pytorch",
-              help="Set to 'pytorch' for pytorch path. Default is cpp path.")
+@click.option(
+    "--backend",
+    type=click.Choice(["pytorch", "trt", "_autodeploy"]),
+    default="pytorch",
+    help=
+    "Set to 'pytorch' for the PyTorch path and '_autodeploy' for the AutoDeploy path. Default is the PyTorch path."
+)
 @click.option('--log_level',
               type=click.Choice(severity_map.keys()),
               default='info',
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index da64969337e..d761ae6851d 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -30,6 +30,8 @@ def get_default_kwargs(self):
         return {
             'skip_tokenizer_init': False,
             'trust_remote_code': True,
+            # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142):
+            # AutoDeploy does not support cache reuse yet.
             'kv_cache_config': {
                 'enable_block_reuse': False,
             },