[None][feat] Implement advanced sampling for one model path mtp/eagle #6245
@@ -235,6 +235,113 @@ def is_generation_model(self) -> bool:
        return False


def forward_native(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
    logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
Collaborator
I think we can technically skip this softmax.

Collaborator
I think I am also wrong here.

Collaborator
I think @IzzyPutterman is right: every time logits/probs are masked, it is sufficient to renormalize the probs so that they sum to one, which is much cheaper than computing softmax. This is probably also why https://docs.flashinfer.ai/api/sampling.html uses function names like [...]. Note that much of this is already worked out in #8581, albeit using [...].
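To make the renormalization suggestion concrete: if the masking is done in probability space (as in FlashInfer's renorm-style APIs) rather than by setting logits to -inf, dividing each row by its sum restores a valid distribution without a second softmax. The sketch below is illustrative only, not code from this PR or from FlashInfer, and the helper name is hypothetical.

```python
import torch


def renorm_masked_probs(probs: torch.Tensor) -> torch.Tensor:
    """Renormalize already-masked probabilities so each row sums to one.

    Assumes masked-out entries were set to 0; cheaper than re-running softmax.
    """
    return probs / probs.sum(dim=-1, keepdim=True)


# Hypothetical usage: mask in probability space, then renormalize.
probs = torch.softmax(torch.randn(2, 8), dim=-1)
probs[probs < 0.05] = 0.0           # stand-in for a top-k/top-p mask
probs = renorm_masked_probs(probs)  # rows sum to 1 again, no extra softmax
```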
    return random_sample(probs)


def random_sample(
    probs: torch.Tensor,
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div_(q).argmax(dim=-1).view(-1)
Collaborator
I have to admit that I am not familiar with this sampling scheme. If you happen to have a literature reference at hand, I would be curious to learn more (perhaps also include a comment stating the name of the method). BTW, [...]

Collaborator
(Disclaimer: I might well have overlooked that [...])

Collaborator
I think I have found the answer to my first question: this uses the "Gumbel max trick" (CodeRabbit even points that out in its review...), after a variable transformation from log-probabilities to probabilities. Including a corresponding remark in the doc-string might be useful for future readers.
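For reference, the identity behind the exponential-noise sampler above: with E_i ~ Exp(1), argmax_i(p_i / E_i) = argmax_i(log p_i - log E_i), and -log E_i is Gumbel(0, 1) distributed, so this is the Gumbel-max trick applied to probabilities and draws exactly from Categorical(p). A small sketch (illustrative only, not part of the PR) showing the two forms agree when given the same noise:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])

# Exponential-noise form used in the PR: argmax(p_i / E_i), E_i ~ Exp(1).
q = torch.empty_like(probs).exponential_()
sample_exp = probs.div(q).argmax(dim=-1)

# Same draw rewritten as the Gumbel-max trick: argmax(log p_i + G_i) with
# G_i = -log(E_i) ~ Gumbel(0, 1); argmax is invariant under log, so the
# chosen token is identical and distributed as Categorical(probs).
sample_gumbel = (probs.log() - q.log()).argmax(dim=-1)

assert torch.equal(sample_exp, sample_gumbel)
```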
def apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Filters logits using adaptive probability thresholding.
    """
    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
Collaborator
This effectively neutralizes the temperature, right? We apply temp and then softmax again, which undoes the scaling.

Collaborator
Or perhaps in [...]

Collaborator
nvm, I'm wrong here, misread something.
    # Calculate maximum probabilities per sequence
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Reshape min_p for broadcasting
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Identify valid tokens using threshold comparison
    valid_token_mask = probability_values >= adjusted_min_p
    # Apply mask using boolean indexing
    logits[~valid_token_mask] = -float("inf")
    return logits


def apply_top_k_top_p(
Collaborator
Note that we already have [...] and some related functions. We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the [...]). Also note that I am working on FlashInfer.sampling based alternatives for those functions. This upcoming PR brings support for [...]. Ideally, this PR could (i) improve the existing sampling routines and (ii) use them via [...]

Collaborator
Cf. #6245 (comment)

Collaborator
Put up #8581 (work in progress!) to give an idea of what to expect.

Collaborator
See also TRTLLM-7723 (and TRTLLM-7152) for scope of ongoing work.

Collaborator
I think I tested flashinfer with cuda graphs and it was breaking a bunch. With the generator objects it's quite annoying in TRTLLM because in warmup we alternate between cuda graph warmup and non-cuda graph warmup, which breaks [...]

Collaborator
Worth a double check ofc, perhaps there is an easy way around it.

Collaborator
@ixlmar I think the current implementation of TopK/TopP only allows all requests to share the same TopK/TopP value, rather than individual requests having different values; please correct me if I'm wrong. The current logic in model_engine.py didn't parse out all the sampling params into GPU tensors for cuda graph; this PR enables that.

Collaborator
@IzzyPutterman The idea of #8581 is to allow choosing between the sampling routines we have today in [...]

Collaborator
@jhaotingc Correct. This was what I meant in the first comment:
"Ideally, this PR could extend [...]"
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor,
Collaborator
As a perf optimization, should we skip the expensive sorting / softmax / cumsum ops for top_p >= 1?

Collaborator
[...]

Collaborator
If top_p is 1, we can skip the expensive sorting / softmax / cumsum ops. In the latest TRT-LLM version, it is already implemented. Please refer to https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/sampling_utils.py#L159-L171.

Collaborator
The skipping is not possible here: in regular decoding the sampling is not captured in a cuda graph, but this path is.
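For illustration, the kind of host-side early exit the linked sampling_utils code is said to perform might look like the sketch below (a hypothetical wrapper, not the TensorRT-LLM implementation). As the last comment points out, it branches on tensor values on the host, so it is only applicable when the sampling step is not captured in a CUDA graph.

```python
from typing import Optional

import torch


def maybe_apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Skip the sort/softmax/cumsum path when top-k/top-p are effectively disabled.

    Illustrative only: the .all() checks synchronize and branch on the host,
    so this cannot live inside a captured CUDA graph.
    """
    vocab_size = logits.size(-1)
    k_disabled = k is None or bool((k >= vocab_size).all())
    p_disabled = p is None or bool((p >= 1.0).all())
    if k_disabled and p_disabled:
        return logits  # nothing to mask
    return apply_top_k_top_p(logits, k, p)  # full masking path from this PR
```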
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if k is not None:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
        top_k_mask = top_k_mask.clamp(min=0)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))
    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))
    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def apply_temperature(
    logits: torch.Tensor,
    temp: torch.Tensor,
) -> torch.Tensor:
    # Use in-place division to avoid creating a new tensor.
    return logits.div_(temp.unsqueeze(dim=1))


@torch.compile(options={"max-autotune": True})
def sampling_batch_spec_dec_one_model(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    top_k: torch.Tensor,
    top_p: torch.Tensor,
    min_p: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    raw_probs = torch.softmax(logits, dim=-1)
    logits = apply_temperature(logits, temperatures)
    logits = apply_min_p(logits, min_p)
    random_sampled = forward_native(logits, top_k, top_p)
    token_probs = torch.gather(raw_probs, dim=1, index=random_sampled.unsqueeze(1)).squeeze(-1)
    log_probs = torch.log(token_probs)
    return random_sampled, log_probs


# Due to tensorrt_llm::runtime::SamplingConfig using vectors, params
# in LlmRequest.sampling_params are either None or single-element lists.
# This helper method simplifies code using such params.
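To connect the pieces, here is a hedged usage sketch of the new entry point, assuming the functions above are in scope and that per-request sampling params have already been parsed into GPU tensors (one value per request) as discussed earlier. The batch values below are made up; a top_k equal to the vocabulary size and a top_p of 1.0 are used to express "no truncation" for a request.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, vocab_size = 4, 32000
logits = torch.randn(batch_size, vocab_size, device=device)

# One entry per request, staged on the GPU so the whole sampling sequence can
# be replayed under a CUDA graph without host-side branching on the params.
temperatures = torch.tensor([1.0, 0.7, 0.9, 1.0], device=device)
top_k = torch.tensor([50, 40, vocab_size, 20], device=device)  # vocab_size keeps all tokens
top_p = torch.tensor([1.0, 0.95, 0.8, 1.0], device=device)     # 1.0 keeps all tokens
min_p = torch.tensor([0.0, 0.05, 0.0, 0.0], device=device)

tokens, log_probs = sampling_batch_spec_dec_one_model(
    logits, temperatures, top_k, top_p, min_p
)
print(tokens.shape, log_probs.shape)  # both have shape (batch_size,)
```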
The resolution of request.sampling_config to sampling strategy has been cleaned up in #8132. See the PR description for the intended semantics. The relevant function is in TensorRT-LLM/tensorrt_llm/_torch/pyexecutor/sampler.py (Line 261 in 3a5845e). The existing function covers various corner cases already (e.g. temperature=0, top_p=1, etc.) and has extensive unit tests. Consider reusing this function here (perhaps make it "public", i.e., rename to something that does not start with _).
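For readers who don't want to chase the link, the sketch below illustrates the kind of corner-case handling the comment refers to (temperature 0 resolved to greedy, top-p of 1 and a full-vocabulary top-k treated as disabled). The function name, signature, and return values are hypothetical; the actual, tested implementation is the one in tensorrt_llm/_torch/pyexecutor/sampler.py referenced above.

```python
from typing import Optional


def resolve_sampling_strategy(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    vocab_size: int,
) -> str:
    """Hypothetical illustration of mapping per-request params to a strategy."""
    if temperature is not None and temperature == 0.0:
        return "greedy"  # temperature 0 degenerates to argmax
    top_k_disabled = top_k is None or top_k <= 0 or top_k >= vocab_size
    top_p_disabled = top_p is None or top_p >= 1.0
    if top_k_disabled and top_p_disabled:
        return "temperature_only"
    if top_k_disabled:
        return "top_p"
    if top_p_disabled:
        return "top_k"
    return "top_k_top_p"
```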
_).