-
Notifications
You must be signed in to change notification settings - Fork 1.7k
[TRTLLM-5627] feat: Implement pytorch sampler for MTP #6245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a7e7147
4b280e0
2282633
bc8a15f
b66befe
3be14a6
0ee79b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from abc import ABC, abstractmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from collections.abc import Iterable | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Literal | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Literal, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -151,6 +151,118 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return next_tokens, softmax | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def forward_native( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
k: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
PyTorch-native implementation of top-k and top-p sampling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
The logits tensor may be updated in-place. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits = apply_top_k_top_p(logits, k, p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probs = logits.softmax(dim=-1, dtype=torch.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return random_sample(probs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def random_sample(probs: torch.Tensor, ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Randomly sample from the probabilities. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
We use this function instead of torch.multinomial because torch.multinomial | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
causes CPU-GPU synchronization. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
q = torch.empty_like(probs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
q.exponential_() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return probs.div_(q).argmax(dim=-1).view(-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def apply_min_p( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
min_p: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Filters logits using adaptive probability thresholding. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Convert logits to probability distribution | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probability_values = torch.nn.functional.softmax(logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Calculate maximum probabilities per sequence | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Reshape min_p for broadcasting | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Identify valid tokens using threshold comparison | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
valid_token_mask = probability_values >= adjusted_min_p | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Apply mask using boolean indexing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits[~valid_token_mask] = -float('inf') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return logits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def apply_top_k_top_p( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
k: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Apply top-k and top-p masks to the logits. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
If a top-p is used, this function will sort the logits tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
which can be slow for large batches. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
The logits tensor may be updated in-place. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if k is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Apply top-k. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_k_mask = top_k_mask.clamp(min=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Get all the top_k values. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_k_mask = logits_sort < top_k_mask | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits_sort.masked_fill_(top_k_mask, -float("inf")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if p is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Apply top-p. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probs_sort = logits_sort.softmax(dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# at least one | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_p_mask[:, -1] = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits_sort.masked_fill_(top_p_mask, -float("inf")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Re-sort the probabilities. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return logits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def greedy_sample(logits: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return logits.argmax(dim=-1).view(-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def apply_temperature( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
temp: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Use in-place division to avoid creating a new tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return logits.div_(temp.unsqueeze(dim=1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
top_k: torch.Tensor, top_p: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raw_probs = torch.softmax(logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
greedy_sampled = greedy_sample(logits) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits = apply_temperature(logits, temperatures) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logits = apply_min_p(logits, min_p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
random_sampled = forward_native(logits, top_k, top_p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
next_tokens = torch.where( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
temperatures <= 1e-2, # Match the clamping threshold | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
greedy_sampled, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
random_sampled, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out=greedy_sampled, # Reuse tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
token_probs = torch.gather(raw_probs, dim=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
index=next_tokens.unsqueeze(1)).squeeze(-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_probs = torch.log(token_probs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return next_tokens, log_probs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+246
to
+264
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Unified sampling function is well-designed but has a potential numerical stability issue. The The log probability calculation at Line 262 could produce NaN values if token_probs = torch.gather(raw_probs, dim=1,
index=next_tokens.unsqueeze(1)).squeeze(-1)
- log_probs = torch.log(token_probs)
+ log_probs = torch.log(token_probs.clamp(min=1e-8)) This prevents log(0) = -inf issues that could propagate through the system. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def greedy_search_sampling_batch(logits): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
next_tokens = torch.argmax(logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
softmax = torch.softmax(logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.