14 changes: 1 addition & 13 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, Optional, Tuple
from typing import Dict, Generic, Optional, Tuple

import torch
from torch import nn
@@ -293,18 +293,6 @@ def load_weights_from_target_model(self,
if self.load_lm_head_from_target:
self.lm_head = target_model.lm_head

# TODO: should input/position IDs be included in this? Keeping it implicit
# for now since the shapes/dtypes are the same across all models we have.
def get_warmup_extra_inputs(self, batch_size: int,
num_tokens: int) -> Dict[str, Any]:

hidden_states = torch.empty(batch_size * num_tokens,
self.model.hidden_size,
dtype=self.model.dtype,
device='cuda')

return {'hidden_states': hidden_states}

def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Hack for eagle3. We might need to run a matmul to reduce
49 changes: 44 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -10,7 +10,7 @@
import weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch._dynamo.config
@@ -21,6 +21,7 @@
from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
from tensorrt_llm._torch.speculative import (
get_num_extra_kv_tokens, update_spec_config_from_model_config)
from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
str_dtype_to_torch, torch_dtype_to_str,
@@ -276,6 +277,8 @@ def __init__(
spec_config: Optional["DecodingBaseConfig"] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
torch.nn.Module]] = None,
):
self.ub_buffers = None
self.batch_size = batch_size
@@ -311,7 +314,8 @@ def __init__(
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
moe_load_balancer=pytorch_backend_config.moe_load_balancer,
lora_config=lora_config)
lora_config=lora_config,
drafting_loop_wrapper=drafting_loop_wrapper)
# In case that some tests use stub models and override `_load_model`.
if not hasattr(self.model, 'extra_attrs'):
self.model.extra_attrs = {}
@@ -403,7 +407,7 @@ def __init__(
dtype=torch.int,
device='cuda')
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
)
) or self.model_is_wrapped
self.max_draft_len = spec_config.max_draft_len
else:
self.without_logits = False
@@ -562,21 +566,33 @@ def warmup(self, resource_manager: ResourceManager) -> None:
# Reset the global cuda graph dummy request to None in warmup.
self.cuda_graph_runner.padding_dummy_request = None

def get_num_extra_decoding_steps():
if isinstance(self.model, ChainDrafter):
return self.model.max_draft_len
else:
assert not self.model_is_wrapped, (
f"Please add logic to determine num_extra_decoding_steps for drafting loop {type(self.model)}"
)
return 0

def get_cuda_graph_warmup_request(batch_size, draft_len):
# Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
available_blocks = kv_cache_manager.get_num_free_blocks(
) // self.max_beam_width
if available_blocks >= batch_size:
result = ScheduledRequests()
result.context_requests = []
num_extra_decoding_steps = get_num_extra_decoding_steps()

# Add (batch_size - 1) dummy requests with seq_len=1.
# Should only need one more page per request.
requests = kv_cache_manager.add_dummy_requests(
list(range(batch_size - 1)),
is_gen=True,
max_num_draft_tokens=draft_len,
use_mrope=use_mrope,
max_beam_width=self.max_beam_width)
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
# Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
available_tokens = kv_cache_manager.get_num_available_tokens(
draft_len)
@@ -592,13 +608,20 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
if max_position_embeddings is not None:
token_num = min(token_num,
max_position_embeddings - draft_len)

assert token_num > num_extra_decoding_steps, (
"Cannot fuse drafting loop. We do not have enough KV cache space "
"for all of the draft tokens.")
token_num -= num_extra_decoding_steps

max_seq_len_request = kv_cache_manager.add_dummy_requests(
request_ids=[batch_size - 1],
token_nums=[token_num],
is_gen=True,
max_num_draft_tokens=draft_len,
use_mrope=use_mrope,
max_beam_width=self.max_beam_width)[0]
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)[0]
                # Add the longest request before all other seq_len=1 requests to simulate the padding CUDA graph case.
# This batch contains both the longest request and the shortest requests,
# it also contains the maximum number of requests and the maximum token number,
@@ -620,6 +643,13 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
return None

num_extra_decoding_steps = get_num_extra_decoding_steps()
if num_extra_decoding_steps > 0:
# Disable autotuning for fused drafting loops for now.
# There are a few bugs that can cause illegal memory accesses
# during warmup.
return None

num_ctx_tokens = num_tokens - num_gen_tokens
num_ctx_requests = 0
ctx_requests = []
@@ -905,6 +935,8 @@ def _load_model(self,
moe_max_num_tokens: Optional[int] = None,
moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
lora_config: Optional[LoraConfig] = None,
drafting_loop_wrapper: Optional[Callable[
[torch.nn.Module], torch.nn.Module]] = None,
**kwargs) -> DecoderModelForCausalLM:
config = checkpoint_loader.load_config(
checkpoint_dir,
@@ -1008,6 +1040,13 @@ def init_meta_tensor(t: torch.Tensor):
logger.info("moe_load_balancer finalize model done")

torch.cuda.current_stream().synchronize()

if drafting_loop_wrapper is not None:
model = drafting_loop_wrapper(model)
self.model_is_wrapped = True
else:
self.model_is_wrapped = False

return model

def _call_load_weights(self, load_method, weights, weight_mapper):
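The engine-side contract introduced here is small: `_load_model` applies the wrapper, if one was supplied, only after weights are loaded and the model is finalized, and the resulting `model_is_wrapped` flag is what forces `without_logits` for wrapped draft models. A minimal sketch of that contract; `apply_drafting_loop_wrapper` is an illustrative helper, not a function in this diff:

from typing import Callable, Optional, Tuple

import torch


def apply_drafting_loop_wrapper(
    model: torch.nn.Module,
    drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                             torch.nn.Module]] = None,
) -> Tuple[torch.nn.Module, bool]:
    # Mirrors the tail of _load_model: wrap only after weight loading so the
    # wrapper sees a fully initialized draft model, and report whether wrapping
    # happened so the engine can treat the output as draft tokens, not logits.
    if drafting_loop_wrapper is not None:
        return drafting_loop_wrapper(model), True
    return model, False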
23 changes: 20 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -260,13 +260,29 @@ def create_py_executor(
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
draft_spec_config = copy.copy(spec_config)
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
if spec_config.load_format == "dummy":
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
# The draft model won't have any draft tokens attached to
# generation requests when we invoke it autoregressively
draft_spec_config.max_draft_len = 0

use_chain_drafter = (
executor_config.guided_decoding_config is None
and not pytorch_backend_config.enable_mixed_sampler
and pytorch_backend_config.attn_backend == "TRTLLM")

if use_chain_drafter:

def drafting_loop_wrapper(model):
from tensorrt_llm._torch.speculative.drafting_loops import \
ChainDrafter

return ChainDrafter(spec_config.max_draft_len, model)
else:
drafting_loop_wrapper = None

draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
if spec_config.load_format == "dummy":
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY

draft_model_engine = PyTorchModelEngine(
model_path=spec_config.speculative_model_dir,
pytorch_backend_config=draft_pytorch_backend_config,
@@ -282,6 +298,7 @@ def create_py_executor(
spec_config=draft_spec_config,
checkpoint_loader=executor_config.checkpoint_loader,
is_draft_model=True,
drafting_loop_wrapper=drafting_loop_wrapper,
)
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
draft_model_engine.load_weights_from_target_model(
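Note the division of labor set up here: `draft_spec_config.max_draft_len` is zeroed because the draft model has no draft tokens attached to generation requests when it is invoked autoregressively, while the `ChainDrafter` receives the original `spec_config.max_draft_len` and owns the loop length. A hedged sketch of the same wrapper construction as a standalone helper (the helper name is illustrative):

from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter


def make_drafting_loop_wrapper(max_draft_len: int):
    # Returns the callable handed to PyTorchModelEngine; the engine applies it
    # to the draft model after loading its weights.
    def drafting_loop_wrapper(model):
        return ChainDrafter(max_draft_len, model)

    return drafting_loop_wrapper

# e.g. drafting_loop_wrapper = make_drafting_loop_wrapper(spec_config.max_draft_len)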
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -469,6 +469,11 @@ def add_dummy_requests(
max_num_draft_tokens: int = 0,
use_mrope: bool = False,
max_beam_width: int = 1,
# For capturable drafting loops. During normal inference, the draft model always
# has enough KV cache space to fit all of our draft tokens. During warmup, however,
# we need to make the KV cache manager aware that multiple autoregressive steps will
# occur.
num_extra_decoding_steps: int = 0,
):
beam_width = max_beam_width
requests = []
@@ -502,6 +507,10 @@
self.impl.add_sequence(req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req_id)

for _ in range(num_extra_decoding_steps):
self.impl.add_token(req_id)

if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1
@@ -510,6 +519,7 @@
if prepare_resource:
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)

requests.append(req)
return requests

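The extra `add_token` calls matter during warmup with a fused drafting loop: each dummy generation request must reserve one additional KV slot per autoregressive step the wrapped loop will take. A rough, illustrative sketch of the per-request accounting (this helper does not exist in the code):

def warmup_kv_slots(token_num: int, num_extra_kv_tokens: int,
                    num_extra_decoding_steps: int,
                    max_num_draft_tokens: int) -> int:
    # add_sequence reserves token_num slots, then add_token is called once per
    # extra KV token, once per extra decoding step of the fused drafting loop,
    # and once per draft token attached to the request.
    return (token_num + num_extra_kv_tokens + num_extra_decoding_steps +
            max_num_draft_tokens)

# e.g. token_num=1, num_extra_kv_tokens=0, num_extra_decoding_steps=3,
# max_num_draft_tokens=3 -> 7 KV slots for one seq_len=1 dummy request.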
150 changes: 150 additions & 0 deletions tensorrt_llm/_torch/speculative/drafting_loops.py
@@ -0,0 +1,150 @@
"""
This module contains capturable drafting loops for speculative decoding.
These are torch modules that wrap another draft model. The wrapper invokes
the draft model autoregressively and applies a sampling algorithm to obtain
draft tokens. By structuring the code
like this, we are able to avoid host overhead: the entire drafting process
for speculation can be launched as a single CUDA graph.
"""

from contextlib import contextmanager

import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.speculative.eagle3 import Eagle3SpecMetadata
from tensorrt_llm._torch.speculative.interface import SpecMetadata


@contextmanager
def save_metadata_state(attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata) -> None:
batch_size = attn_metadata.num_seqs

if attn_metadata.is_cuda_graph:
seq_len = attn_metadata._seq_lens[:batch_size].clone()
seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
kv_lens = attn_metadata.kv_lens_cuda.clone()

assert spec_metadata.is_cuda_graph
num_tokens = spec_metadata.num_tokens
if isinstance(spec_metadata, Eagle3SpecMetadata):
read_indices = spec_metadata.hidden_states_read_indices[:
batch_size].clone(
)
write_indices = spec_metadata.hidden_states_write_indices[:
batch_size].clone(
)

try:
yield
finally:
if attn_metadata.is_cuda_graph:
attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
attn_metadata._seq_lens_cuda[:batch_size].copy_(
seq_len_cuda[:batch_size])
attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens[:batch_size])

spec_metadata.num_tokens = num_tokens
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
read_indices)
spec_metadata.hidden_states_write_indices[:batch_size].copy_(
write_indices)

# This restore has to happen even if the spec_metadata is not being used
# for CUDA graphs. It won't be reset by spec_metadata.prepare().
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.is_first_draft = True
spec_metadata.eagle3_resource_manager.is_first_draft = True


def prepare_for_generation(attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata,
last_tokens_idx: torch.Tensor) -> None:
batch_size = attn_metadata.num_seqs
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
attn_metadata.kv_lens_cuda[:batch_size] += 1

attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
attn_metadata.num_contexts = 0

spec_metadata.num_tokens = batch_size

if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.eagle3_resource_manager.is_first_draft = False
spec_metadata.is_first_draft = False

old_write_indices = spec_metadata.hidden_states_write_indices

spec_metadata.hidden_states_read_indices[:batch_size].copy_(
old_write_indices[last_tokens_idx])
spec_metadata.hidden_states_write_indices[:batch_size].copy_(
torch.arange(
batch_size,
dtype=spec_metadata.hidden_states_write_indices.dtype,
device=spec_metadata.hidden_states_write_indices.device))


class ChainDrafter(torch.nn.Module):

def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
super().__init__()
self.draft_model = draft_model
self.config = self.draft_model.config
self.model_config = self.draft_model.model_config
self.max_draft_len = max_draft_len

def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
                spec_metadata: SpecMetadata, **kwargs) -> torch.Tensor:

logits = self.draft_model.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)

new_draft_tokens = [self.sample(logits)]

with save_metadata_state(attn_metadata, spec_metadata):
batch_size = attn_metadata.num_seqs
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
new_position_ids = position_ids[0, last_tokens_idx] + 1

prepare_for_generation(attn_metadata, spec_metadata,
last_tokens_idx)

for i in range(self.max_draft_len - 1):
logits = self.draft_model.forward(
input_ids=new_draft_tokens[-1],
position_ids=new_position_ids,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
new_draft_tokens.append(self.sample(logits))
new_position_ids += 1
attn_metadata.kv_lens_cuda[:batch_size] += 1
if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
spec_metadata.hidden_states_write_indices[:batch_size])

return torch.stack(new_draft_tokens)

def sample(self, logits: torch.Tensor) -> torch.Tensor:
# TODO: inject the sampler here so we can support non-greedy
tokens = torch.argmax(logits, dim=-1)
if hasattr(self.draft_model.model, "d2t"):
d2t = self.draft_model.model.d2t.data
return tokens + d2t[tokens]

return tokens

def load_weights_from_target_model(self,
target_model: torch.nn.Module) -> None:
loader = getattr(self.draft_model, "load_weights_from_target_model",
None)
if callable(loader):
self.draft_model.load_weights_from_target_model(target_model)
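A minimal usage sketch of the wrapper, assuming `draft_model`, `input_ids`, `position_ids`, `attn_metadata`, and `spec_metadata` already exist and were prepared the way the model engine prepares them:

from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter

# Wrap the draft model; max_draft_len draft tokens are produced per call.
drafter = ChainDrafter(max_draft_len=3, draft_model=draft_model)

# One call runs the base forward plus max_draft_len - 1 extra generation steps
# and returns the per-step draft tokens stacked along dim 0. Because there is
# no host-side work between steps, the whole call can be captured in a single
# CUDA graph.
draft_tokens = drafter(input_ids=input_ids,
                       position_ids=position_ids,
                       attn_metadata=attn_metadata,
                       spec_metadata=spec_metadata)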