7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -336,9 +336,10 @@ def restore_from_spec_dec(self) -> None:
     def update_spec_dec_param(
             self,
             is_spec_decoding_enabled,
-            is_spec_dec_tree,
-            is_spec_dec_dynamic_tree,
-            max_draft_tokens,
+            spec_metadata,
+            spec_tree_manager,
+            max_draft_len,
+            max_total_draft_tokens,
             spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
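The interface hook now separates the draft-tree depth (max_draft_len) from the total number of draft tokens the attention kernels see (max_total_draft_tokens), and takes the spec_metadata and spec_tree_manager objects instead of individual tree flags. As a rough illustration of the two sizes (the tree shape below is a made-up example, not something this PR defines):

    nodes_per_depth = [2, 4, 4]                    # draft nodes at depth 1, 2, 3; the accepted token is not counted
    max_draft_len = len(nodes_per_depth)           # 3 drafting steps
    max_total_draft_tokens = sum(nodes_per_depth)  # 10 draft positions scored by the target model

    # A linear chain is the degenerate tree with one node per depth,
    # so there max_draft_len == max_total_draft_tokens.
    assert len([1, 1, 1]) == sum([1, 1, 1])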
90 changes: 51 additions & 39 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1051,47 +1051,46 @@ def prepare_context_mla_with_cached_kv(self,
     def update_spec_dec_param(
         self,
         is_spec_decoding_enabled,
-        is_spec_dec_tree,
-        is_spec_dec_dynamic_tree,
-        max_draft_tokens,
+        spec_metadata,
+        spec_tree_manager,
+        max_draft_len,
+        max_total_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):

         if spec_decoding_tensor is not None:
-            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
-            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
-            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
+            spec_decoding_tensor.position_offsets
+            spec_decoding_tensor.packed_mask
+            spec_decoding_tensor.generation_lengths
         else:
-            spec_decoding_position_offsets = None
-            spec_decoding_packed_mask = None
-            spec_decoding_generation_lengths = None
+            pass
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100

+        self.is_spec_dec_tree = False if spec_tree_manager is None else True
+        self.is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager.use_dynamic_tree
+
         if get_sm_version() >= 100:
-            if is_spec_dec_tree or is_spec_dec_dynamic_tree:
-                assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
-                assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
+            if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
+                assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
+                assert not self.is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."

         # use_spec_decoding is default to true by default, change in runtime by layers / requests
         self.use_spec_decoding = self.is_spec_decoding_enabled

-        self.is_spec_dec_tree = is_spec_dec_tree
-        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
-
         # Parameters can be fixed and not changed during runtime if the
         if self.is_spec_decoding_enabled:
             self.spec_decoding_position_offsets = torch.empty(
-                [self.max_num_requests, max_draft_tokens + 1],
+                [self.max_num_requests, max_total_draft_tokens + 1],
                 dtype=torch.int,
                 device='cuda',
             )

             self.spec_decoding_packed_mask = torch.empty(
                 [
-                    self.max_num_requests, max_draft_tokens + 1,
-                    math.ceil((max_draft_tokens + 1) / 32)
+                    self.max_num_requests, max_total_draft_tokens + 1,
+                    math.ceil((max_total_draft_tokens + 1) / 32)
                 ],
                 dtype=torch.int,
                 device='cuda',
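With this change, the per-request spec-dec buffers above are sized by max_total_draft_tokens rather than by a per-step draft length. A small sketch of the resulting shapes, using made-up sizes for both dimensions:

    import math

    max_num_requests = 4           # illustrative value
    max_total_draft_tokens = 10    # illustrative value

    position_offsets_shape = (max_num_requests, max_total_draft_tokens + 1)
    packed_mask_shape = (max_num_requests, max_total_draft_tokens + 1,
                         math.ceil((max_total_draft_tokens + 1) / 32))
    generation_lengths_shape = (max_num_requests, )

    print(position_offsets_shape)    # (4, 11)
    print(packed_mask_shape)         # (4, 11, 1): one 32-bit word holds up to 32 mask bits per row
    print(generation_lengths_shape)  # (4,)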
@@ -1103,30 +1102,41 @@ def update_spec_dec_param(
                 device='cuda',
             )

-            if self.is_spec_dec_dynamic_tree:
-                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
-                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
-                self.spec_decoding_position_offsets.copy_(
-                    spec_decoding_position_offsets, non_blocking=True)
-                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
-                                                     non_blocking=True)
-                if spec_decoding_generation_lengths is not None:
-                    self.spec_decoding_generation_lengths.copy_(
-                        spec_decoding_generation_lengths, non_blocking=True)
+            # Prepare the spec-dec mask, position offset and generation length for static tree or dynamic tree.
+            # We only prepare the spec-dec mask, position offset and generation length for the target model here.
+            # For the drafter model, we will prepare them in the drafting loops.
+            is_target_model = not spec_metadata.is_draft_model
+            is_using_tree = self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree
+            if is_target_model and is_using_tree:
+                assert spec_metadata.spec_dec_mode.is_eagle3(
+                ), "Tree decoding is only supported for Eagle3 now"
+                # If is the dynamic tree
+                if self.is_spec_dec_dynamic_tree:
+                    # TODO: add dynamic tree logic
+                    assert False, "Dynamic tree is not supported yet"
+                # If is the static tree
                 else:
-                    self.generate_spec_decoding_generation_length(
-                        max_draft_tokens=max_draft_tokens)
+                    self.spec_decoding_position_offsets[
+                        :,
+                    ].copy_(spec_tree_manager.spec_dec_position_offsets[0, :],
+                            non_blocking=True)
+                    self.spec_decoding_packed_mask[:, :, :].copy_(
+                        spec_tree_manager.spec_dec_packed_mask[0, :, :],
+                        non_blocking=True)
+                    self.spec_decoding_generation_lengths[:].fill_(
+                        spec_tree_manager.max_total_draft_tokens + 1)
             else:
+                # Prepare for the linear-tree.
                 # Populate the mask that won't change during inference phase.
                 self.generate_spec_decoding_position_offsets(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
                 self.generate_spec_decoding_packed_mask(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
                 self.generate_spec_decoding_generation_length(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)

-    def generate_spec_decoding_position_offsets(self, max_draft_tokens):
-        position_offset = torch.arange(max_draft_tokens + 1,
+    def generate_spec_decoding_position_offsets(self, max_total_draft_tokens):
+        position_offset = torch.arange(max_total_draft_tokens + 1,
                                        dtype=torch.int,
                                        device='cpu',
                                        pin_memory=True)
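In the static-tree branch above, a single shared tree layout is read from spec_tree_manager (row 0) and broadcast to every request slot, and the generation length is fixed to max_total_draft_tokens + 1. Assuming the usual layout, where a node's position offset is its depth below the accepted token and the packed mask stores each node's ancestor set (self included) as bits in 32-bit words, a self-contained sketch of how such per-node tensors could be built for a made-up tree (the parents list and variable names here are illustrative, not APIs of spec_tree_manager):

    import torch

    # Made-up static draft tree (node 0 is the accepted root token):
    #        0
    #      /   \
    #     1     2
    #    / \   / \
    #   3   4 5   6
    parents = [-1, 0, 0, 1, 1, 2, 2]      # parent index of each node, -1 for the root

    num_nodes = len(parents)
    position_offsets = torch.zeros(num_nodes, dtype=torch.int)
    packed_mask = torch.zeros(num_nodes, dtype=torch.int)   # one 32-bit word is enough for < 32 nodes

    for i in range(num_nodes):
        bits, depth, node = 0, 0, i
        while node != -1:                 # walk up to the root, setting one bit per ancestor (and self)
            bits |= 1 << node
            node = parents[node]
            depth += 1
        position_offsets[i] = depth - 1   # the root sits at offset 0
        packed_mask[i] = bits

    # Broadcast the shared tree to every request, mirroring the row-0 copies in the branch above.
    max_num_requests = 4                  # illustrative value
    all_offsets = position_offsets.unsqueeze(0).expand(max_num_requests, -1).contiguous()
    all_masks = packed_mask.unsqueeze(0).expand(max_num_requests, -1).contiguous()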
@@ -1135,15 +1145,17 @@ def generate_spec_decoding_position_offsets(self, max_draft_tokens):
         self.spec_decoding_position_offsets.copy_(position_offset,
                                                   non_blocking=True)

-    def generate_spec_decoding_packed_mask(self, max_draft_tokens):
-        dummy_idx = torch.arange(max_draft_tokens + 1)
+    def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
+        # TODO: fix this limitation
+        assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later"
+        dummy_idx = torch.arange(max_total_draft_tokens + 1)
         spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
         self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
                                                       non_blocking=True)

-    def generate_spec_decoding_generation_length(self, max_draft_tokens):
+    def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
         spec_decoding_generation_length = torch.full((self.max_num_requests, ),
-                                                     max_draft_tokens + 1)
+                                                     max_total_draft_tokens + 1)
         self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
             spec_decoding_generation_length, non_blocking=True)

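For the linear-chain path served by the helpers above, the packed mask is a causal lower-triangular mask squeezed into the first 32-bit word of each row, which is why generate_spec_decoding_packed_mask currently asserts max_total_draft_tokens < 32. A worked example with a made-up draft length, using the same arithmetic as the helper but outside the class:

    import torch

    max_total_draft_tokens = 4                # illustrative value; must stay below 32 for the single-word layout
    dummy_idx = torch.arange(max_total_draft_tokens + 1)
    packed = torch.pow(2, dummy_idx + 1) - 1  # 2**(i+1) - 1 sets bits 0..i in row i

    print(packed.tolist())                    # [1, 3, 7, 15, 31]
    for row in packed.tolist():
        print(format(row, '05b'))             # 00001, 00011, 00111, 01111, 11111
    # Row i lets draft position i attend to itself and every earlier draft position,
    # and the matching position offsets are simply torch.arange(max_total_draft_tokens + 1).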