@@ -10,15 +10,17 @@
 from collections import deque
 from dataclasses import dataclass
 from time import perf_counter
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import transformers
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import padding_check_and_fix
+from QEfficient.utils.constants import Constants
 from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.sampler_utils import validate_sampler_inputs


 @dataclass
@@ -322,6 +324,9 @@ def cloud_ai_100_exec_kv(
     automation=False,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
+    include_sampler: bool = False,
+    return_pdfs: bool = False,
+    sampling_params: Optional[Dict[str, Any]] = None,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +347,15 @@ def cloud_ai_100_exec_kv(
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
+        :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
+        :return_pdfs (bool, default=False): Return probability distributions along with the sampled
+            next tokens. For a Speculative Decoding Target Language Model,
+            `return_pdfs` is always True. Otherwise, `return_pdfs` is True for a Speculative
+            Decoding Draft Language Model and False for a regular model.
+        :sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
+            The dictionary should contain the following keys:
+            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

     Returns:
         :CloudAI100ExecInfo: Object holding execution output and performance details.
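For reviewers, a minimal usage sketch of the new arguments documented above. The qpc path, model card, prompt, and dtypes below are placeholders/assumptions and not part of this change; only the key names and the (batch_size, 1) shapes come from the docstring.

```python
import numpy as np
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

batch_size = 1  # must match the batch size the QPC was compiled for
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.2, dtype=np.float32),
    "presence_penalties": np.zeros((batch_size, 1), dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 40, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.zeros((batch_size, 1), dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model card
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpc",          # placeholder path
    prompt="My name is",
    include_sampler=True,             # sample next tokens on the device
    return_pdfs=False,                # regular model: sampled token ids only
    sampling_params=sampling_params,  # one (batch_size, 1) array per knob
)
```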
@@ -372,6 +386,9 @@ def cloud_ai_100_exec_kv(
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
         is_tlm=is_tlm,
+        include_sampler=include_sampler,
+        return_pdfs=return_pdfs,
+        sampling_params=sampling_params,
     )
     if full_batch_size is None:
         exec_info = [
@@ -411,14 +428,24 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._ctx_len = ctx_len
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
+        self.return_pdfs = return_pdfs
+        self.sampling_params = sampling_params

         # Load QPC
         self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

+        # Validate sampler inputs for On-Device Sampling
+        self.include_sampler = validate_sampler_inputs(
+            session_inputs=set(self._session.input_names), include_sampler=include_sampler
+        )
+
         # Fetch the variables from the QPC
         self._vocab_size = self._fetch_vocab_size()  # Fetch Vocab size
         self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
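The effective `include_sampler` flag is resolved against the QPC's actual input bindings. The helper below is only a guess at the kind of check `validate_sampler_inputs` performs, written to make the constructor's intent concrete; the real implementation lives in `QEfficient.utils.sampler_utils` and may differ in names and error handling.

```python
from typing import Set


def validate_sampler_inputs_sketch(session_inputs: Set[str], include_sampler: bool = False) -> bool:
    """Hypothetical stand-in for validate_sampler_inputs; behaviour is an assumption."""
    sampler_inputs = {
        "last_accepted_output_tokens", "repetition_penalties", "presence_penalties",
        "temperatures", "top_ks", "top_ps", "min_ps", "random_numbers",
    }
    qpc_has_sampler = sampler_inputs.issubset(session_inputs)
    if include_sampler and not qpc_has_sampler:
        # The QPC was compiled without the on-device sampler, so the request cannot be honoured.
        raise ValueError("include_sampler=True, but the QPC exposes no sampler input bindings")
    # Sampling is enabled only when the caller asked for it and the QPC supports it.
    return include_sampler and qpc_has_sampler
```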
@@ -523,10 +550,17 @@ def _fetch_vocab_size(
         Returns:
             vocab_size: The vocabulary size fetched from the session's allowed shapes.
         """
+        key = (
+            "probs"
+            if self.include_sampler and self.return_pdfs
+            else "next_tokens"
+            if self.include_sampler
+            else "logits"
+        )
         if self._session.allowed_shapes:
-            return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
+            return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

-        return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
+        return self._session.bindings[self._session.binding_index_map[key]].dims[2]

     def _fetch_generation_len(self, generation_len, max_gen_len):
         """
@@ -574,6 +608,13 @@ def prepare_decode_inputs(self):
         decode_inputs["position_ids"] = self.decode_pos_ids
         if self.batch_index is not None:
             decode_inputs["batch_index"] = self.batch_index
+        if self.include_sampler:
+            decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
+            for op in Constants.SAMPLER_OPS:
+                if self.batch_index is not None:
+                    decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
+                else:
+                    decode_inputs[op] = self.sampling_params[op]

         if self._prompt_to_lora_id_mapping_decode:
             if self.full_batch_size:
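The per-slot slicing above can be pictured with a small NumPy sketch. The values are made up; the point is that `batch_index.flatten()` picks the rows of each (batch_size, 1) sampling array that belong to the batch slots currently being decoded.

```python
import numpy as np

# Four continuous-batching slots, each with its own temperature (illustrative values).
sampling_params = {"temperatures": np.array([[0.7], [1.0], [0.2], [0.9]], dtype=np.float32)}

batch_index = np.array([[2], [0]], dtype=np.int64)  # slots 2 and 0 are active this step
sliced = sampling_params["temperatures"][batch_index.flatten()]
print(sliced)  # [[0.2], [0.7]] -- rows selected to match the active slots
```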
@@ -589,21 +630,24 @@ def prepare_decode_inputs(self):

     def _fetch_next_token_id(self, outputs):
         """
-        Fetches the next token ID from the model's output logits.
-        The method identifies the token with the highest probability using argmax along the last dimension.
+        Fetches the next token ID from the model's output.
+
         Args:
-            outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
+            outputs (dict): A dictionary containing the model's output.

         Returns:
             numpy.ndarray: An array of the next token IDs for each sequence in the batch.
         """
-        logits = outputs["logits"]
-        if len(logits.shape) == 2:
-            logits = np.expand_dims(logits, 1)
-
-        # Get output token
-        next_token_id = logits.argmax(2)
-        return next_token_id
+        if self.include_sampler:
+            if self.return_pdfs:
+                return outputs["probs"].argmax(2)
+            else:
+                return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
+        else:
+            logits = outputs["logits"]
+            if len(logits.shape) == 2:
+                logits = np.expand_dims(logits, 1)
+            return logits.argmax(2)

     def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
         """
@@ -673,6 +717,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

             _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

+    def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
+        """
+        Sets the sizes of the output buffers.
+
+        Args:
+            batch_size (int): The batch size.
+        """
+        if self.include_sampler:
+            if self.return_pdfs:
+                probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+                self._session.set_buffers({"probs": probs_out_placeholder})
+            next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
+            self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
+        else:
+            logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+            self._session.set_buffers({"logits": logits_out_placeholder})
+
     def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         """
         Runs prefill for a given prompt and generation length.
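For readers skimming the new helper: the placeholders it registers have the following shapes. The sizes below are made up for illustration; the dtypes are taken from the diff.

```python
import numpy as np

batch_size, sequence_length, vocab_size = 2, 1, 32000  # illustrative sizes only

# include_sampler and return_pdfs: both sampler outputs are registered
probs_placeholder = np.zeros((batch_size, sequence_length, vocab_size), dtype=np.float32)
next_tokens_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)

# include_sampler without return_pdfs: only next_tokens is registered
# sampler disabled: only the usual logits buffer is registered
logits_placeholder = np.zeros((batch_size, sequence_length, vocab_size), dtype=np.float32)
```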
@@ -702,9 +763,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
         max_gen_len = self._ctx_len - position_ids.max()
         generation_len = self._fetch_generation_len(generation_len, max_gen_len)

-        # Set the prefill logic buffer
-        logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
-        self._session.set_buffers({"logits": logits_out_placeholder})
+        # Set the prefill output buffers
+        self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

         inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
         inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +774,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             inputs["batch_index"] = decode_batch_id
         if self.is_tlm:
             inputs["num_logits_to_keep"] = np.zeros((1, 1))
+        if self.include_sampler:
+            inputs["last_accepted_output_tokens"] = inputs["input_ids"]
+            for op in Constants.SAMPLER_OPS:
+                if decode_batch_id is not None:
+                    inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+                else:
+                    inputs[op] = self.sampling_params[op]

         if self._prompt_to_lora_id_mapping_prefill:
             if self.full_batch_size:
@@ -732,6 +799,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             chunk_inputs["position_ids"] = inputs["position_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
             ]
+            if self.include_sampler:
+                chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
             outputs = self._session.run(chunk_inputs)
             if self._write_io_dir is not None:
                 write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +822,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

         """

-        # Set logits placeholder for decode
-        logits_out_placeholder = np.zeros(
-            (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
+        # Set output placeholders for decode
+        self._set_output_buffers(
+            batch_size=self.full_batch_size,
+            sequence_length=self._decode_seq_len,
         )
-        self._session.set_buffers({"logits": logits_out_placeholder})
+

         # Generate flag for tracking progress for each batch ID
         current_decode_ongoing = np.full((self.full_batch_size, 1), True)
@@ -775,10 +845,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
             outputs = self._session.run(decode_inputs)

             # Prepare inputs for next iteration
-            logits = outputs["logits"]
-            if len(logits.shape) == 2:
-                logits = np.expand_dims(logits, 1)
-            next_token_id = logits.argmax(2)
+            next_token_id = self._fetch_next_token_id(outputs)

             for decode_batch_id in range(self.full_batch_size):
                 if (
@@ -800,7 +867,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                         self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
                         generated_id_current_index[decode_batch_id] = 1

-                        self._session.set_buffers({"logits": logits_out_placeholder})
+                        self._set_output_buffers(
+                            batch_size=self.full_batch_size,
+                            sequence_length=self._decode_seq_len,
+                        )
                         decode_pause_time += perf_counter() - start

                         if self._prompt_to_lora_id_mapping_decode:
@@ -817,6 +887,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                     self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
                         next_token_id[decode_batch_id, -1]
                     )
+                    if self.include_sampler:
+                        decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

                     generated_id_current_index[decode_batch_id] += 1

@@ -852,10 +924,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
                 self._write_io_dir = None

             # Prepare inputs for next iteration
-            decode_inputs["input_ids"] = outputs["logits"].argmax(2)
+            decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
             decode_inputs["position_ids"][:, -1] += 1
             self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
+            if self.include_sampler:
+                decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

             if finished_sequences.all():
                 break
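The loop-exit bookkeeping above is unchanged by this diff; a tiny self-contained NumPy sketch of how `finished_sequences` accumulates EOS hits (token ids are made up).

```python
import numpy as np

eos_token_id = 2
input_ids = np.array([[5], [2], [7]])              # latest token per sequence in the batch
finished_sequences = np.zeros((3, 1), dtype=bool)

finished_sequences |= input_ids == eos_token_id
print(finished_sequences.ravel())  # [False  True False]
print(finished_sequences.all())    # False -> keep decoding
```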
@@ -905,9 +979,22 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: bool = False,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._qaic_model = QEffTextGenerationBase(
-            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
+            tokenizer=tokenizer,
+            qpc_path=qpc_path,
+            full_batch_size=full_batch_size,
+            ctx_len=ctx_len,
+            device_id=device_id,
+            enable_debug_logs=enable_debug_logs,
+            write_io_dir=write_io_dir,
+            is_tlm=is_tlm,
+            include_sampler=include_sampler,
+            return_pdfs=return_pdfs,
+            sampling_params=sampling_params,
         )
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
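And the equivalent when driving the higher-level wrapper directly: a hedged sketch that reuses the `tokenizer` and `sampling_params` objects from the earlier example; the qpc path and ctx_len are placeholders.

```python
from QEfficient.generation.text_generation_inference import TextGeneration

text_gen = TextGeneration(
    tokenizer=tokenizer,              # same tokenizer object as in the earlier sketch
    qpc_path="/path/to/qpc",          # placeholder path
    ctx_len=2048,                     # placeholder context length
    include_sampler=True,
    return_pdfs=False,
    sampling_params=sampling_params,  # same (batch_size, 1) arrays as before
)
```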