
Commit b7775c9

quic-sanising, sanising, quic-dhirajku, and quic-rishinr authored
Unit Tests for On Device Sampling (#463)
This PR adds the following Unit Tests for On Device Sampling:

1. `test_sampler_transform`: Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM` model to enable the sampling of next tokens on the device (instead of the host) and returns the next tokens and/or probability distributions.
2. `test_greedy_sampling`: Test greedy sampling with QPC compiled with and without On Device Sampling.
3. `test_random_sampling`: Test random sampling with QPC compiled with and without On Device Sampling.

---------

Signed-off-by: quic-sanising <[email protected]>
Signed-off-by: sanising <[email protected]>
Signed-off-by: Dhiraj Kumar Sah <[email protected]>
Signed-off-by: Rishin Raj <[email protected]>
Co-authored-by: sanising <[email protected]>
Co-authored-by: Dhiraj Kumar Sah <[email protected]>
Co-authored-by: Rishin Raj <[email protected]>
1 parent b9a8e7c commit b7775c9
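
The greedy-sampling test described above boils down to running the same prompts through a QPC compiled with On Device Sampling and one compiled without it, then comparing the generated tokens. Below is a minimal sketch of that flow using the `cloud_ai_100_exec_kv` API extended in this commit; the model card, QPC paths, prompt, and dtypes are placeholders rather than the actual test fixtures, and it assumes `CloudAI100ExecInfo.generated_ids` holds the generated token IDs.

import numpy as np
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model card
batch_size = 1

# Greedy behaviour expressed through the sampler inputs: no penalties,
# temperature 0 and top_k 1, so the device-side sampler reduces to argmax.
# Dtypes are assumptions; the docstring only fixes the (batch_size, 1) shape.
greedy_sampling_params = {
    "repetition_penalties": np.ones((batch_size, 1), dtype=np.float32),
    "presence_penalties": np.zeros((batch_size, 1), dtype=np.float32),
    "temperatures": np.zeros((batch_size, 1), dtype=np.float32),
    "top_ks": np.ones((batch_size, 1), dtype=np.int32),
    "top_ps": np.ones((batch_size, 1), dtype=np.float32),
    "min_ps": np.zeros((batch_size, 1), dtype=np.float32),
    "random_numbers": np.zeros((batch_size, 1), dtype=np.float32),
}

# QPC compiled WITH On Device Sampling (placeholder path).
exec_info_device = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path="/path/to/qpc_with_sampler",
    prompt=["Hello, my name is"],
    generation_len=20,
    include_sampler=True,
    return_pdfs=False,
    sampling_params=greedy_sampling_params,
)

# QPC compiled WITHOUT On Device Sampling: argmax happens on the host instead.
exec_info_host = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path="/path/to/qpc_without_sampler",
    prompt=["Hello, my name is"],
    generation_len=20,
)

# With greedy settings, both runs are expected to produce identical token IDs.
assert (exec_info_device.generated_ids[0] == exec_info_host.generated_ids[0]).all()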

File tree

11 files changed: +847 -31 lines


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
-  rev: v0.5.2
+  rev: v0.12.7
   hooks:
     # Run the linter.
     - id: ruff

QEfficient/cloud/infer.py

Lines changed: 0 additions & 1 deletion
@@ -235,7 +235,6 @@ def main(
         tokenizer,
         prompts=prompt,
         device_id=device_group,
-        prompt=prompt,
         prompts_txt_file_path=prompts_txt_file_path,
         generation_len=generation_len,
     )

QEfficient/generation/text_generation_inference.py

Lines changed: 114 additions & 27 deletions
@@ -10,15 +10,17 @@
 from collections import deque
 from dataclasses import dataclass
 from time import perf_counter
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import transformers
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import padding_check_and_fix
+from QEfficient.utils.constants import Constants
 from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.sampler_utils import validate_sampler_inputs


 @dataclass
@@ -322,6 +324,9 @@ def cloud_ai_100_exec_kv(
     automation=False,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
+    include_sampler: bool = False,
+    return_pdfs: bool = False,
+    sampling_params: Optional[Dict[str, Any]] = None,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +347,15 @@ def cloud_ai_100_exec_kv(
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
+        :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
+        :return_pdfs (bool, default=False): Return probability distributions along with sampled
+            next tokens. For Speculative Decoding Target Language Model,
+            `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
+            Decoding Draft Language Model and `return_pdfs`=False for regular model.
+        :sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
+            The dictionary should contain the following keys:
+            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

     Returns:
         :CloudAI100ExecInfo: Object holding execution output and performance details.
@@ -372,6 +386,9 @@
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
         is_tlm=is_tlm,
+        include_sampler=include_sampler,
+        return_pdfs=return_pdfs,
+        sampling_params=sampling_params,
     )
     if full_batch_size is None:
         exec_info = [
@@ -411,14 +428,24 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._ctx_len = ctx_len
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
+        self.return_pdfs = return_pdfs
+        self.sampling_params = sampling_params

         # Load QPC
         self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

+        # Validate sampler inputs for On-Device Sampling
+        self.include_sampler = validate_sampler_inputs(
+            session_inputs=set(self._session.input_names), include_sampler=include_sampler
+        )
+
         # Fetch the variables from the QPC
         self._vocab_size = self._fetch_vocab_size()  # Fetch Vocab size
         self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
@@ -523,10 +550,17 @@ def _fetch_vocab_size(
         Returns:
             vocab_size: The vocabulary size fetched from the session's allowed shapes.
         """
+        key = (
+            "probs"
+            if self.include_sampler and self.return_pdfs
+            else "next_tokens"
+            if self.include_sampler
+            else "logits"
+        )
         if self._session.allowed_shapes:
-            return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
+            return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

-        return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
+        return self._session.bindings[self._session.binding_index_map[key]].dims[2]

     def _fetch_generation_len(self, generation_len, max_gen_len):
         """
@@ -574,6 +608,13 @@ def prepare_decode_inputs(self):
         decode_inputs["position_ids"] = self.decode_pos_ids
         if self.batch_index is not None:
             decode_inputs["batch_index"] = self.batch_index
+        if self.include_sampler:
+            decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
+            for op in Constants.SAMPLER_OPS:
+                if self.batch_index is not None:
+                    decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
+                else:
+                    decode_inputs[op] = self.sampling_params[op]

         if self._prompt_to_lora_id_mapping_decode:
             if self.full_batch_size:
@@ -589,21 +630,24 @@

     def _fetch_next_token_id(self, outputs):
         """
-        Fetches the next token ID from the model's output logits.
-        The method identifies the token with the highest probability using argmax along the last dimension.
+        Fetches the next token ID from the model's output.
+
         Args:
-            outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
+            outputs (dict): A dictionary containing the model's output.

         Returns:
            numpy.ndarray: An array of the next token IDs for each sequence in the batch.
         """
-        logits = outputs["logits"]
-        if len(logits.shape) == 2:
-            logits = np.expand_dims(logits, 1)
-
-        # Get output token
-        next_token_id = logits.argmax(2)
-        return next_token_id
+        if self.include_sampler:
+            if self.return_pdfs:
+                return outputs["probs"].argmax(2)
+            else:
+                return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
+        else:
+            logits = outputs["logits"]
+            if len(logits.shape) == 2:
+                logits = np.expand_dims(logits, 1)
+            return logits.argmax(2)

     def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
         """
@@ -673,6 +717,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

            _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

+    def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
+        """
+        Sets the sizes of the output buffers.
+
+        Args:
+            batch_size (int): The batch size.
+        """
+        if self.include_sampler:
+            if self.return_pdfs:
+                probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+                self._session.set_buffers({"probs": probs_out_placeholder})
+            next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
+            self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
+        else:
+            logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+            self._session.set_buffers({"logits": logits_out_placeholder})
+
     def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         """
         Runs prefill for a given prompt and generation length.
@@ -702,9 +763,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         max_gen_len = self._ctx_len - position_ids.max()
         generation_len = self._fetch_generation_len(generation_len, max_gen_len)

-        # Set the prefill logic buffer
-        logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
-        self._session.set_buffers({"logits": logits_out_placeholder})
+        # Set the prefill output buffers
+        self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

         inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
         inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +774,13 @@
             inputs["batch_index"] = decode_batch_id
         if self.is_tlm:
             inputs["num_logits_to_keep"] = np.zeros((1, 1))
+        if self.include_sampler:
+            inputs["last_accepted_output_tokens"] = inputs["input_ids"]
+            for op in Constants.SAMPLER_OPS:
+                if decode_batch_id is not None:
+                    inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+                else:
+                    inputs[op] = self.sampling_params[op]

         if self._prompt_to_lora_id_mapping_prefill:
             if self.full_batch_size:
@@ -732,6 +799,8 @@
             chunk_inputs["position_ids"] = inputs["position_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
             ]
+            if self.include_sampler:
+                chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
             outputs = self._session.run(chunk_inputs)
             if self._write_io_dir is not None:
                 write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +822,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

         """

-        # Set logits placeholder for decode
-        logits_out_placeholder = np.zeros(
-            (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
+        # Set output placeholders for decode
+        self._set_output_buffers(
+            batch_size=self.full_batch_size,
+            sequence_length=self._decode_seq_len,
         )
-        self._session.set_buffers({"logits": logits_out_placeholder})
+
         # Generate flag for tracking progress for each batch ID
         current_decode_ongoing = np.full((self.full_batch_size, 1), True)

@@ -775,10 +845,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
             outputs = self._session.run(decode_inputs)

             # Prepare inputs for next iteration
-            logits = outputs["logits"]
-            if len(logits.shape) == 2:
-                logits = np.expand_dims(logits, 1)
-            next_token_id = logits.argmax(2)
+            next_token_id = self._fetch_next_token_id(outputs)

             for decode_batch_id in range(self.full_batch_size):
                 if (
@@ -800,7 +867,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                        self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
                        generated_id_current_index[decode_batch_id] = 1

-                        self._session.set_buffers({"logits": logits_out_placeholder})
+                        self._set_output_buffers(
+                            batch_size=self.full_batch_size,
+                            sequence_length=self._decode_seq_len,
+                        )
                        decode_pause_time += perf_counter() - start

                        if self._prompt_to_lora_id_mapping_decode:
@@ -817,6 +887,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                    self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
                        next_token_id[decode_batch_id, -1]
                    )
+                    if self.include_sampler:
+                        decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

                    generated_id_current_index[decode_batch_id] += 1

@@ -852,10 +924,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
                self._write_io_dir = None

            # Prepare inputs for next iteration
-            decode_inputs["input_ids"] = outputs["logits"].argmax(2)
+            decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
            decode_inputs["position_ids"][:, -1] += 1
            self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
            finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
+            if self.include_sampler:
+                decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

            if finished_sequences.all():
                break
@@ -905,9 +979,22 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: bool = False,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._qaic_model = QEffTextGenerationBase(
-            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
+            tokenizer=tokenizer,
+            qpc_path=qpc_path,
+            full_batch_size=full_batch_size,
+            ctx_len=ctx_len,
+            device_id=device_id,
+            enable_debug_logs=enable_debug_logs,
+            write_io_dir=write_io_dir,
+            is_tlm=is_tlm,
+            include_sampler=include_sampler,
+            return_pdfs=return_pdfs,
+            sampling_params=sampling_params,
         )
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
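
The `sampling_params` contract documented in the `cloud_ai_100_exec_kv` docstring above (one numpy array of shape (batch_size, 1) per key) is easiest to see with a concrete dictionary. A minimal sketch for a batch of two prompts follows; the specific values and dtypes are illustrative assumptions, since the docstring only fixes the keys and the shape.

import numpy as np

batch_size = 2  # one row of parameters per sequence in the batch

sampling_params = {
    # Penalty of 1.0 / 0.0 means "no penalty applied" for that sequence.
    "repetition_penalties": np.array([[1.0], [1.2]], dtype=np.float32),
    "presence_penalties": np.array([[0.0], [0.5]], dtype=np.float32),
    # Row 0 decodes greedily (temperature 0); row 1 samples with temperature 0.7.
    "temperatures": np.array([[0.0], [0.7]], dtype=np.float32),
    "top_ks": np.array([[1], [40]], dtype=np.int32),
    "top_ps": np.array([[1.0], [0.95]], dtype=np.float32),
    "min_ps": np.array([[0.0], [0.0]], dtype=np.float32),
    # Per-sequence random numbers consumed by the device-side sampler.
    "random_numbers": np.random.uniform(size=(batch_size, 1)).astype(np.float32),
}

# Every value follows the documented contract: a numpy array of shape (batch_size, 1).
assert all(v.shape == (batch_size, 1) for v in sampling_params.values())

During decode, `prepare_decode_inputs` then copies these arrays into the session inputs, selecting rows by `batch_index` when continuous batching is active, as the hunk above shows.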

QEfficient/transformers/models/modeling_auto.py

Lines changed: 4 additions & 1 deletion
@@ -1365,7 +1365,7 @@ def __init__(
         )
         # Set use_cache=True to get KV values as output during ONNX export
         model.config.use_cache = True
-        super().__init__(model, **kwargs)
+        super().__init__(model, qaic_config=qaic_config, **kwargs)
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
         self.model.qaic_config = qaic_config
@@ -1379,6 +1379,8 @@
        # are done. The role of the sampler is to just add nodes at the output of the
        # previous transform function.
        self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs)
+        # TODO : Update in qaic_config isn't updated in the hash due to SpDTransforms. Need to move
+        # SpDTransforms to PytorchTransforms.
        if self.is_tlm:
            self.model.qaic_config["return_pdfs"] = True

@@ -1841,6 +1843,7 @@ def generate(
                device_id=device_id,
                generation_len=generation_len,
                is_tlm=self.is_tlm,
+                **kwargs,
            )
        else:
            raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
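
On the modeling side, the change above forwards `qaic_config` into the base class and lets `SamplerTransform` decide whether to graft sampling nodes onto the model. Below is a rough sketch of how a caller might opt in when loading a model; only the `return_pdfs` key is visible in this diff, so the `include_sampler` key and the model card should be read as assumptions about the expected `qaic_config` contents, not as the confirmed API.

# A minimal sketch of wiring qaic_config through QEFFAutoModelForCausalLM.
# Assumption: qaic_config accepts an "include_sampler" flag consumed by
# SamplerTransform; "return_pdfs" is the only key shown in this diff.
from QEfficient import QEFFAutoModelForCausalLM

qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model card
    continuous_batching=False,
    qaic_config={
        "include_sampler": True,  # assumed key: request on-device sampling nodes
        "return_pdfs": False,     # key shown in the diff; True forces probability outputs
    },
)

The `**kwargs` pass-through added to `generate()` above then lets callers forward runtime arguments such as `sampling_params` down to `cloud_ai_100_exec_kv`.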

QEfficient/transformers/sampler/sampler.py

Lines changed: 3 additions & 0 deletions
@@ -193,6 +193,9 @@ def sampler_forward(
     batch_size, spec_length, vocab_size = logits.shape
     logits = logits.reshape(-1, vocab_size)  # Reshape tensor to 2D

+    if batch_index is None:  # Regular model execution
+        batch_index = torch.arange(batch_size).view(-1, 1)
+
     batch_index_reshaped = batch_index.view(-1)
     # Prefill
     past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(

QEfficient/utils/constants.py

Lines changed: 11 additions & 0 deletions
@@ -41,6 +41,7 @@
     "pretrained_model_name_or_path",
     "attn_implementation",
     "_attn_implementation",
+    "qaic_config",
 ]

 # Minimum value for causal mask
@@ -130,6 +131,16 @@ class Constants:
     MAX_RETRIES = 10  # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
     NUM_SPECULATIVE_TOKENS = 2
     MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS
+    SAMPLER_OPS = {
+        "repetition_penalties",
+        "presence_penalties",
+        "temperatures",
+        "top_ks",
+        "top_ps",
+        "min_ps",
+        "random_numbers",
+    }
+    SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"}
     SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml"  # This xml file is parsed to find out the SDK apps version.
     SDK_PLATFORM_XML = (
         "/opt/qti-aic/versions/platform.xml"  # This xml file is parsed to find out the SDK platform version.
