# Architecture Overview

The `LLM` class is a core entry point for TensorRT-LLM, providing a simplified `generate()` API for efficient large language model inference. This abstraction aims to streamline the user experience, as demonstrated with TinyLlama:

```python
from tensorrt_llm import LLM

# Initialize the LLM with a specified model
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Generate text using the model
output = llm.generate("Hello, my name is")
```

The `LLM` class automatically manages essential pre- and post-processing steps, including tokenization (encoding input prompts into numerical representations) and detokenization (decoding model outputs back into human-readable text).
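
For reference, a slightly extended variant of the example above, following the LLM API quick start pattern, shows how sampling options are passed in and how the detokenized text is read back from the returned results (field names reflect the current LLM API and may differ between versions):

```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Sampling behavior is configured through SamplingParams.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(["Hello, my name is"], sampling_params)

# Each result carries the original prompt and the detokenized completion text.
for output in outputs:
    print(f"Prompt: {output.prompt!r}")
    print(f"Generated: {output.outputs[0].text!r}")
```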

Internally, the `LLM` class orchestrates the creation of a dedicated `PyExecutor(Worker)` process on each rank.
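
For multi-GPU execution, one such worker is created per rank. A minimal sketch, assuming the `tensor_parallel_size` argument of the LLM API:

```python
from tensorrt_llm import LLM

# Shard the model across two GPUs; each rank runs its own PyExecutor worker.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    tensor_parallel_size=2,
)
```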

This `PyExecutor` operates in a continuous background loop, designed for efficient, asynchronous processing of inference requests.
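
Because requests are processed asynchronously in the background, they can also be submitted without blocking the caller. A hypothetical streaming sketch, assuming the `generate_async` method and `streaming` flag of the LLM API (check the API reference for exact signatures):

```python
import asyncio

from tensorrt_llm import LLM


async def main() -> None:
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # The request is handed to the background PyExecutor loop; partial
    # results are yielded as new tokens are produced.
    async for partial in llm.generate_async("Hello, my name is", streaming=True):
        print(partial.outputs[0].text)


asyncio.run(main())
```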

The `PyExecutor`'s functionality is built upon several key components:

- `Scheduler`: Responsible for determining which active requests are ready for execution at each processing step.

- `KVCacheManager`: Manages the allocation, deallocation, and maintenance of the Key-Value (KV) Cache. This is a critical optimization for Transformer models, significantly enhancing performance during autoregressive text generation by storing previously computed attention keys and values. A configuration sketch follows this list.

- `ModelEngine`: Handles the loading and efficient execution of the language model on the GPU.

- `Sampler`: Takes the raw outputs (logits) from the `ModelEngine` and applies the appropriate sampling strategy (e.g., greedy, top-k, top-p, or beam search) to generate the final output tokens.
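
While most of these components are internal, some expose user-facing configuration. For example, the memory pool that the `KVCacheManager` draws from can be sized when constructing the `LLM`; a minimal sketch, assuming the `KvCacheConfig` class and its `free_gpu_memory_fraction` field from the LLM API:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Reserve a fraction of the remaining GPU memory for the KV cache pool.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    kv_cache_config=kv_cache_config,
)
```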

During each iteration of its background loop, the `PyExecutor` performs the following sequence of operations (a simplified sketch of the loop follows the list):

- Request Fetching: Retrieves new inference requests from an internal request queue, if any are available.

- Scheduling: Interacts with the `Scheduler` to identify and prioritize the requests that are ready to be processed in the current step.

- Resource Preparation: Coordinates with the `KVCacheManager` to ensure that the necessary KV cache resources are allocated for the selected requests.

- Model Execution: Invokes the `ModelEngine` to perform a forward pass on the scheduled requests, predicting the next output tokens.

- Output Handling: Updates the partial outputs for ongoing requests, finalizes the results for any requests that have reached completion, and returns them to the user.
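
Putting these steps together, one iteration of the loop can be pictured roughly as follows. This is an illustrative sketch, not the actual implementation; helper names such as `_fetch_new_requests`, `prepare_resources`, and `_handle_responses` are assumed for readability, while `_schedule`, `_forward_step`, and `_sample_async` mirror the snippet shown later in the Overlap Scheduler section:

```python
def _loop_iteration(self):
    # Request fetching: pull any newly submitted requests off the queue.
    new_requests = self._fetch_new_requests()
    self.active_requests.extend(new_requests)

    # Scheduling: pick the requests to run in this step.
    scheduled_batch, _, _ = self._schedule()

    # Resource preparation: make sure KV cache blocks are allocated.
    self.kv_cache_manager.prepare_resources(scheduled_batch)

    # Model execution: one forward pass over the scheduled batch.
    batch_outputs = self._forward_step(scheduled_batch)

    # Sampling and output handling: turn logits into tokens, update partial
    # outputs, and return finished requests to their callers.
    sample_state = self._sample_async(scheduled_batch, batch_outputs)
    self._handle_responses(scheduled_batch, sample_state)
```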

## Runtime Optimizations

TensorRT-LLM improves inference throughput and reduces latency through a suite of runtime optimizations, including CUDA Graphs, the [Overlap Scheduler](../features/overlap-scheduler.md), and [speculative decoding](../features/speculative-decoding.md).

### CUDA Graph

CUDA Graphs drastically reduce the CPU-side overhead associated with launching GPU kernels, which is particularly impactful in PyTorch-based inference where Python's host-side code can be a bottleneck. By capturing a sequence of CUDA operations as a single graph, the entire sequence can be launched with one API call, minimizing CPU-GPU synchronization and driver overhead.
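
The underlying mechanism is the standard CUDA Graph capture-and-replay pattern. The sketch below uses plain PyTorch on a stand-in model to illustrate the idea; it is not TensorRT-LLM's internal implementation:

```python
import torch

# Stand-in model and static input/output buffers reused across replays.
model = torch.nn.Linear(1024, 1024).cuda().eval()
static_input = torch.randn(8, 1024, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream), torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture the forward pass once...
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_output = model(static_input)

# ...then replay it with a single launch call per batch, copying new data
# into the static input buffer instead of re-running Python-side launch code.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()
print(static_output.shape)
```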

To maximize the "hit rate" of these cached graphs, TensorRT-LLM employs CUDA Graph padding. If an incoming batch's size doesn't match a captured graph, it is padded to the nearest larger supported size for which a graph exists. While this incurs minor overhead from computing "wasted" tokens, it is often a better trade-off than falling back to slower eager-mode execution. This optimization has a significant impact, demonstrating up to a 22% end-to-end throughput increase on certain models and hardware.
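
The padding decision itself is simple to picture. Below is a self-contained sketch of the size-selection logic; the actual captured sizes and policy in TensorRT-LLM are configuration-dependent:

```python
# Batch sizes for which graphs were captured ahead of time (illustrative values).
CAPTURED_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]


def select_graph_batch_size(batch_size: int) -> int | None:
    """Return the smallest captured size >= batch_size, or None to run eagerly."""
    for size in CAPTURED_BATCH_SIZES:
        if size >= batch_size:
            return size
    return None  # batch too large for any captured graph: fall back to eager mode


# A batch of 11 requests is padded up to the 16-request graph; the 5 padded
# slots produce "wasted" tokens but avoid a slower eager-mode forward pass.
print(select_graph_batch_size(11))   # 16
print(select_graph_batch_size(200))  # None
```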

### Overlap Scheduler

The Overlap Scheduler maximizes GPU utilization by hiding CPU-bound latency behind GPU computation.

The key strategy is to launch the GPU's work for the next step (n+1) immediately, without waiting for the CPU to finish processing the results of the current step (n). This allows the CPU to handle tasks like checking stop criteria or updating responses for one batch while the GPU is already executing the model for the subsequent batch.

This concurrent execution pipeline is illustrated in the `PyExecutor`'s logic:

```python
# Schedule and launch GPU work for the current step (n)
scheduled_batch, _, _ = self._schedule()
batch_outputs = self._forward_step(scheduled_batch, previous_tensors_device)
sample_state = self._sample_async(scheduled_batch, batch_outputs)

# While the GPU is busy, process the CPU-bound results from the previous step (n-1)
if self.previous_batch is not None:
    self._process_previous_batch()
```

This approach effectively reduces GPU idle time and improves overall hardware occupancy. While it introduces one extra decoding step into the pipeline, the resulting throughput gain makes this a worthwhile trade-off. For this reason, the Overlap Scheduler is enabled by default in TensorRT-LLM.
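
If you need to compare against strictly sequential execution, the scheduler can be turned off. A hypothetical sketch, assuming the `disable_overlap_scheduler` option exposed by the LLM API in recent releases (the option name and location have changed across versions, so check the configuration reference for your release):

```python
from tensorrt_llm import LLM

# Assumed option name; disables the Overlap Scheduler for debugging/comparison.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    disable_overlap_scheduler=True,
)
```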