diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 6d04cf573..fb9f612a1 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -356,7 +356,7 @@ def cloud_ai_100_exec_kv(
             Decoding Draft Language Model and `return_pdfs`=False for regular model.
         sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
             The dictionary should contain the following keys:
-            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `repetition_penalties`, `frequency_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
             `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).
 
     Returns:
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 633a0b29d..c7edb2070 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -2364,13 +2364,18 @@ def get_sampling_inputs_and_outputs(
         dynamic_axes["repetition_penalties"] = {0: "batch_size"}
 
         example_inputs["past_presence_penalty_buffer"] = torch.zeros(
-            (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
+            (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.int32
         )
         dynamic_axes["past_presence_penalty_buffer"] = {
             0: "full_batch_size" if self.continuous_batching else "batch_size",
         }
         output_names.append("past_presence_penalty_buffer_RetainedState")
 
+        example_inputs["frequency_penalties"] = (
+            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES
+        )
+        dynamic_axes["frequency_penalties"] = {0: "batch_size"}
+
         example_inputs["presence_penalties"] = (
             torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
         )
diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py
index 96846e712..9689a67c9 100644
--- a/QEfficient/transformers/sampler/sampler.py
+++ b/QEfficient/transformers/sampler/sampler.py
@@ -80,20 +80,19 @@ def decode_path(
     )
 
     # Update retained states
-    scatter_values = torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool)
     past_repetition_penalty_buffer = CtxScatterFuncCB3D.apply(
         past_repetition_penalty_buffer,
         batch_index,
         last_accepted_output_tokens,
-        scatter_values,
+        torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool),
     )
+    gather_values = past_presence_penalty_buffer[batch_index, last_accepted_output_tokens]
     past_presence_penalty_buffer = CtxScatterFuncCB3D.apply(
         past_presence_penalty_buffer,
         batch_index,
         last_accepted_output_tokens,
-        scatter_values,
+        gather_values + 1,
     )
-    # TODO: For frequency retain state, first gather and then scatter
 
     return past_repetition_penalty_buffer, past_presence_penalty_buffer
 
@@ -116,6 +115,7 @@ def sampler_forward(
     past_repetition_penalty_buffer: Optional[torch.Tensor] = None,
     repetition_penalties: Optional[torch.Tensor] = None,
     past_presence_penalty_buffer: Optional[torch.Tensor] = None,
+    frequency_penalties: Optional[torch.Tensor] = None,
     presence_penalties: Optional[torch.Tensor] = None,
     temperatures: Optional[torch.Tensor] = None,
     top_ks: Optional[torch.Tensor] = None,
@@ -141,8 +141,13 @@ def sampler_forward(
         new tokens, while values < 1 encourage the model to repeat tokens.
 
     past_presence_penalty_buffer (`torch.Tensor`, *optional*):
-        RetainedState buffer used as a mask to apply presence penalty to the output
-        generated so far.
+        RetainedState buffer used as a mask to apply frequency and presence penalties to
+        the output generated so far.
+
+    frequency_penalties (`torch.Tensor`, *optional*):
+        Sampling parameter that penalizes new tokens based on their frequency in the
+        generated text so far. Values > 0 encourage the model to use new tokens, while
+        values < 0 encourage the model to repeat tokens.
 
     presence_penalties (`torch.Tensor`, *optional*):
         Sampling parameter that penalizes new tokens based on whether they appear in the
@@ -243,17 +248,24 @@ def sampler_forward(
         repetition_penalties_mask = torch.where(past_repetition_penalty_buffer_selected, repetition_penalties, 1.0)
         logits *= repetition_penalties_mask ** (-torch.sign(logits))
 
+    if (frequency_penalties != 0.0).any() or (presence_penalties != 0.0).any():
+        past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped].repeat(
+            spec_length, 1
+        )  # (batch_size * spec_length, vocab_size)
+
+    # Frequency Penalty
+    if (frequency_penalties != 0.0).any():
+        frequency_penalties = frequency_penalties.repeat(
+            spec_length, 1
+        )  # (batch_size, 1) -> (batch_size * spec_length, 1)
+        logits -= frequency_penalties * past_presence_penalty_buffer_selected
+
     # Presence Penalty
     if (presence_penalties != 0.0).any():
         presence_penalties = presence_penalties.repeat(
             spec_length, 1
         )  # (batch_size, 1) -> (batch_size * spec_length, 1)
-        past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped].repeat(
-            spec_length, 1
-        )  # (batch_size * spec_length, vocab_size)
-        logits -= presence_penalties * past_presence_penalty_buffer_selected
-
-        # TODO: Frequency Penalty
+        logits -= presence_penalties * (past_presence_penalty_buffer_selected > 0)
 
     # Temperature Scaling
     temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 57fba282b..d453ecd2c 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -79,6 +79,7 @@ def get_models_dir():
 QEFF_MODELS_DIR = get_models_dir()
 
 ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5
+ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES = 0.5
 ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5
 ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80
 ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512
@@ -139,6 +140,7 @@ class Constants:
     MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS
     SAMPLER_OPS = {
         "repetition_penalties",
+        "frequency_penalties",
         "presence_penalties",
         "temperatures",
         "top_ks",
diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py
index 00d8c2430..8431d5a83 100644
--- a/examples/on_device_sampling.py
+++ b/examples/on_device_sampling.py
@@ -30,8 +30,8 @@ def main(args, **kwargs):
         max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
         sampling_params = {
             "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
+            "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
             "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
-            # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
             "temperatures": np.array(args.temperature, dtype=np.float32).repeat(bs).reshape(-1, 1),
             "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1),
             "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -108,6 +108,7 @@ def main(args, **kwargs):
        --mxfp6-matmul \
        --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
        --repetition-penalty 1.9 \
+       --frequency-penalty 0.8 \
        --presence-penalty 0.8 \
        --temperature 0.67 \
        --top-k 54720 \
@@ -128,6 +129,7 @@ def main(args, **kwargs):
        --mxfp6-matmul \
        --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
        --repetition-penalty 1.9 \
+       --frequency-penalty 0.8 \
        --presence-penalty 0.8 \
        --temperature 0.67 \
        --top-k 54720 \
@@ -208,6 +210,14 @@ def main(args, **kwargs):
         "prompt and the generated text so far. Values > 1 encourage the model to use new tokens, "
         "while values < 1 encourage the model to repeat tokens.",
     )
+    sampling_group.add_argument(
+        "--frequency-penalty",
+        type=float,
+        default=None,
+        help="Sampling parameter that penalizes new tokens based on their frequency in the "
+        "generated text so far. Values > 0 encourage the model to use new tokens, while values < "
+        "0 encourage the model to repeat tokens.",
+    )
     sampling_group.add_argument(
         "--presence-penalty",
         type=float,
diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py
index 9335e1d91..99e4a53dd 100644
--- a/tests/transformers/sampler/test_sampler.py
+++ b/tests/transformers/sampler/test_sampler.py
@@ -205,8 +205,8 @@ def test_greedy_sampling(
         return_pdfs=False,
         sampling_params={
             "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
             "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
@@ -298,10 +298,10 @@ def test_random_sampling(
         include_sampler=True,
         return_pdfs=False,
         sampling_params={
-            "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "repetition_penalties": np.array(1.9, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "frequency_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "presence_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "temperatures": np.array(0.67, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
             "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
@@ -319,56 +319,56 @@ def test_random_sampling(
 
     # Compare generated texts
     golden_texts = {
-        "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both",
-        "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
+        "w_sampler": " Kelsey and I am a 20 year old college student. My major in school right now,",
+        "wo_sampler": " Kaitlyn and I am a 20 year old college student. I am a junior at the",
     }
     golden_ids = {
         "w_sampler": [
             [
-                21380,
-                322,
-                590,
-                25448,
-                2927,
-                29892,
-                19963,
-                2654,
-                29879,
-                470,
-                3708,
-                2701,
-                313,
-                29902,
+                735,
+                93567,
+                323,
+                358,
+                1097,
+                264,
+                220,
                 508,
+                1060,
+                2362,
+                7926,
+                5575,
+                13,
+                3092,
+                3682,
+                304,
+                2978,
+                1314,
+                1457,
+                11,
             ]
         ],
         "wo_sampler": [
             [
-                2259,
-                7075,
-                322,
-                306,
-                626,
-                263,
-                7047,
-                22055,
-                29889,
-                306,
-                505,
-                1063,
-                1985,
-                297,
-                278,
-                13661,
-                363,
-                278,
-                4940,
-                29871,
+                735,
+                1339,
+                18499,
+                323,
+                358,
+                1097,
+                264,
+                220,
+                508,
+                1060,
+                2362,
+                7926,
+                5575,
+                13,
+                358,
+                1097,
+                264,
+                27144,
+                520,
+                279,
             ]
         ],
     }
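For reviewers, here is a minimal standalone PyTorch sketch of the semantics the sampler changes above introduce. It uses hypothetical helper names rather than the QEfficient API: the retained `past_presence_penalty_buffer` now stores per-token occurrence counts instead of a boolean mask, the frequency penalty subtracts `penalty * count` from the logits, and the presence penalty subtracts a flat `penalty` for any token whose count is non-zero.

```python
import torch


def update_output_counts(output_counts: torch.Tensor, new_token_ids: torch.Tensor) -> torch.Tensor:
    """Increment the per-token occurrence count for each newly accepted token id (hypothetical helper)."""
    ones = torch.ones_like(new_token_ids, dtype=output_counts.dtype)
    return output_counts.scatter_add_(1, new_token_ids, ones)


def apply_output_penalties(
    logits: torch.Tensor,               # (batch_size, vocab_size)
    output_counts: torch.Tensor,        # (batch_size, vocab_size), integer occurrence counts
    frequency_penalties: torch.Tensor,  # (batch_size, 1)
    presence_penalties: torch.Tensor,   # (batch_size, 1)
) -> torch.Tensor:
    # Frequency penalty grows with how often a token has already been generated.
    logits = logits - frequency_penalties * output_counts
    # Presence penalty is a flat offset for any token generated at least once.
    logits = logits - presence_penalties * (output_counts > 0)
    return logits


# Token 1 has been generated three times, token 2 once.
counts = update_output_counts(torch.zeros(1, 8, dtype=torch.long), torch.tensor([[1, 1, 1, 2]]))
penalized = apply_output_penalties(
    torch.zeros(1, 8),
    counts,
    frequency_penalties=torch.tensor([[0.8]]),
    presence_penalties=torch.tensor([[0.8]]),
)
print(penalized)  # token 1: -(3 * 0.8 + 0.8) = -3.2, token 2: -1.6, all others: 0.0
```

This mirrors the patch's gather-then-scatter retained-state update in `decode_path` and the two logits adjustments in `sampler_forward`, but as a plain dense-tensor sketch rather than the CB3D scatter/gather ops used on the QAIC backend.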