2 changes: 1 addition & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -356,7 +356,7 @@ def cloud_ai_100_exec_kv(
             Decoding Draft Language Model and `return_pdfs`=False for regular model.
         sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
             The dictionary should contain the following keys:
-            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `repetition_penalties`, `frequency_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
             `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

     Returns:
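For reference, a minimal sketch of a sampling_params dict satisfying the contract documented above: every key present, every value a numpy array of shape (batch_size, 1). The penalty and temperature values mirror the example flags used elsewhere in this PR; the dtype and range of random_numbers are assumptions, since this diff does not specify them.

import numpy as np

batch_size = 2
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.9, dtype=np.float32),
    "frequency_penalties": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "presence_penalties": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.67, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 512, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.89, dtype=np.float32),
    "min_ps": np.full((batch_size, 1), 0.6, dtype=np.float32),
    # Assumed: per-sequence uniform randoms consumed by the on-device sampler.
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}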
7 changes: 6 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -2364,13 +2364,18 @@ def get_sampling_inputs_and_outputs(
         dynamic_axes["repetition_penalties"] = {0: "batch_size"}

         example_inputs["past_presence_penalty_buffer"] = torch.zeros(
-            (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
+            (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.int32
         )
         dynamic_axes["past_presence_penalty_buffer"] = {
             0: "full_batch_size" if self.continuous_batching else "batch_size",
         }
         output_names.append("past_presence_penalty_buffer_RetainedState")

+        example_inputs["frequency_penalties"] = (
+            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES
+        )
+        dynamic_axes["frequency_penalties"] = {0: "batch_size"}
+
         example_inputs["presence_penalties"] = (
             torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
         )
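The added frequency_penalties input follows the same export pattern as the other sampling parameters: a (bs, 1) float example tensor seeded from a constant, with the first axis marked dynamic. A standalone sketch of that pattern, with names mirroring the diff (bs is the example batch size used only to trace the export graph):

import torch

ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES = 0.5  # mirrors QEfficient/utils/constants.py
bs = 1

example_inputs, dynamic_axes = {}, {}

# Seed the example input with the constant, then mark axis 0 as dynamic so the
# compiled model accepts any runtime batch size under the symbol "batch_size".
example_inputs["frequency_penalties"] = (
    torch.zeros((bs, 1), dtype=torch.float) + ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES
)
dynamic_axes["frequency_penalties"] = {0: "batch_size"}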
36 changes: 24 additions & 12 deletions QEfficient/transformers/sampler/sampler.py
@@ -80,20 +80,19 @@ def decode_path(
     )

     # Update retained states
-    scatter_values = torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool)
     past_repetition_penalty_buffer = CtxScatterFuncCB3D.apply(
         past_repetition_penalty_buffer,
         batch_index,
         last_accepted_output_tokens,
-        scatter_values,
+        torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool),
     )
+    gather_values = past_presence_penalty_buffer[batch_index, last_accepted_output_tokens]

[Review comment, Contributor Author] For improved performance, I intended to use CtxGatherFuncCB3D here, but it doesn't work: last_accepted_output_tokens is a tensor of shape (batch_size, seq_len), whereas the function expects shape (batch_size, 1). Please let me know if there is a workaround. (A plain-PyTorch sketch of this gather-then-scatter update follows this file's diff.)

     past_presence_penalty_buffer = CtxScatterFuncCB3D.apply(
         past_presence_penalty_buffer,
         batch_index,
         last_accepted_output_tokens,
-        scatter_values,
+        gather_values + 1,
     )
-    # TODO: For frequency retain state, first gather and then scatter
     return past_repetition_penalty_buffer, past_presence_penalty_buffer


@@ -116,6 +115,7 @@ def sampler_forward(
     past_repetition_penalty_buffer: Optional[torch.Tensor] = None,
     repetition_penalties: Optional[torch.Tensor] = None,
     past_presence_penalty_buffer: Optional[torch.Tensor] = None,
+    frequency_penalties: Optional[torch.Tensor] = None,
     presence_penalties: Optional[torch.Tensor] = None,
     temperatures: Optional[torch.Tensor] = None,
     top_ks: Optional[torch.Tensor] = None,
@@ -141,8 +141,13 @@
             new tokens, while values < 1 encourage the model to repeat tokens.

         past_presence_penalty_buffer (`torch.Tensor`, *optional*):
-            RetainedState buffer used as a mask to apply presence penalty to the output
-            generated so far.
+            RetainedState buffer used as a mask to apply frequency and presence penalties to
+            the output generated so far.
+
+        frequency_penalties (`torch.Tensor`, *optional*):
+            Sampling parameter that penalizes new tokens based on their frequency in the
+            generated text so far. Values > 0 encourage the model to use new tokens, while
+            values < 0 encourage the model to repeat tokens.

         presence_penalties (`torch.Tensor`, *optional*):
             Sampling parameter that penalizes new tokens based on whether they appear in the
@@ -243,17 +248,24 @@ def sampler_forward(
         repetition_penalties_mask = torch.where(past_repetition_penalty_buffer_selected, repetition_penalties, 1.0)
         logits *= repetition_penalties_mask ** (-torch.sign(logits))

+    if (frequency_penalties != 0.0).any() or (presence_penalties != 0.0).any():
+        past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped].repeat(
+            spec_length, 1
+        )  # (batch_size * spec_length, vocab_size)
+
+    # Frequency Penalty
+    if (frequency_penalties != 0.0).any():
+        frequency_penalties = frequency_penalties.repeat(
+            spec_length, 1
+        )  # (batch_size, 1) -> (batch_size * spec_length, 1)
+        logits -= frequency_penalties * past_presence_penalty_buffer_selected
+
     # Presence Penalty
     if (presence_penalties != 0.0).any():
         presence_penalties = presence_penalties.repeat(
             spec_length, 1
         )  # (batch_size, 1) -> (batch_size * spec_length, 1)
-        past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped].repeat(
-            spec_length, 1
-        )  # (batch_size * spec_length, vocab_size)
-        logits -= presence_penalties * past_presence_penalty_buffer_selected
-
-    # TODO: Frequency Penalty
+        logits -= presence_penalties * (past_presence_penalty_buffer_selected > 0)

     # Temperature Scaling
     temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
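Taken together, the changes above repurpose past_presence_penalty_buffer from a boolean "seen" mask into an int32 per-token occurrence counter, so a single retained state drives both the frequency penalty (which scales with the count) and the presence penalty (which fires once per seen token). Below is a minimal plain-PyTorch sketch of those semantics; the names are illustrative, not the QEfficient API, and index_put_ stands in for the custom CtxScatterFuncCB3D gather-then-scatter:

import torch

def update_token_counts(counts, batch_index, accepted_tokens):
    # counts: (full_batch_size, vocab_size) int32 occurrence counts
    # batch_index: (batch_size, 1) row ids; accepted_tokens: (batch_size, seq_len)
    rows = batch_index.expand_as(accepted_tokens)
    # accumulate=True also handles a token id repeating within one step; the
    # diff's gather-then-scatter instead reads the current count and writes count + 1.
    counts.index_put_(
        (rows, accepted_tokens),
        torch.ones_like(accepted_tokens, dtype=counts.dtype),
        accumulate=True,
    )
    return counts

def apply_count_penalties(logits, counts, frequency_penalties, presence_penalties):
    # logits: (batch, vocab); counts: (batch, vocab); penalties: (batch, 1)
    logits = logits - frequency_penalties * counts               # grows with each repeat
    logits = logits - presence_penalties * (counts > 0).float()  # flat, once per token id
    return logits

# Tiny usage check: token 3 in row 0 appears twice, so with both penalties at 0.8
# its logit drops by 0.8 * 2 + 0.8 = 2.4.
counts = torch.zeros(2, 8, dtype=torch.int32)
update_token_counts(counts, torch.tensor([[0], [1]]), torch.tensor([[3, 3], [5, 6]]))
logits = apply_count_penalties(
    torch.zeros(2, 8), counts, torch.full((2, 1), 0.8), torch.full((2, 1), 0.8)
)
print(logits[0, 3].item())  # -2.4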
2 changes: 2 additions & 0 deletions QEfficient/utils/constants.py
@@ -79,6 +79,7 @@ def get_models_dir():
 QEFF_MODELS_DIR = get_models_dir()

 ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5
+ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES = 0.5
 ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5
 ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80
 ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512
@@ -139,6 +140,7 @@ class Constants:
     MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS
     SAMPLER_OPS = {
         "repetition_penalties",
+        "frequency_penalties",
         "presence_penalties",
         "temperatures",
         "top_ks",
12 changes: 11 additions & 1 deletion examples/on_device_sampling.py
@@ -30,8 +30,8 @@ def main(args, **kwargs):
     max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
     sampling_params = {
         "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
+        "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
         "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
-        # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
         "temperatures": np.array(args.temperature, dtype=np.float32).repeat(bs).reshape(-1, 1),
         "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1),
         "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -108,6 +108,7 @@ def main(args, **kwargs):
         --mxfp6-matmul \
         --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
         --repetition-penalty 1.9 \
+        --frequency-penalty 0.8 \
         --presence-penalty 0.8 \
         --temperature 0.67 \
         --top-k 54720 \
@@ -128,6 +129,7 @@ def main(args, **kwargs):
         --mxfp6-matmul \
         --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
         --repetition-penalty 1.9 \
+        --frequency-penalty 0.8 \
         --presence-penalty 0.8 \
         --temperature 0.67 \
         --top-k 54720 \
@@ -208,6 +210,14 @@ def main(args, **kwargs):
         "prompt and the generated text so far. Values > 1 encourage the model to use new tokens, "
         "while values < 1 encourage the model to repeat tokens.",
     )
+    sampling_group.add_argument(
+        "--frequency-penalty",
+        type=float,
+        default=None,
+        help="Sampling parameter that penalizes new tokens based on their frequency in the "
+        "generated text so far. Values > 0 encourage the model to use new tokens, while values < "
+        "0 encourage the model to repeat tokens.",
+    )
     sampling_group.add_argument(
         "--presence-penalty",
         type=float,
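As a quick worked example of the help text above, using the penalty application from sampler.py: the frequency penalty is charged once per prior occurrence of a token, while the presence penalty is charged once overall. The numbers here are hypothetical, chosen to match the example flags:

# Token seen 3 times so far, raw logit 5.0, both penalties set to 0.8.
count, logit = 3, 5.0
logit -= 0.8 * count        # frequency: 5.0 - 2.4 -> 2.6 (scales with repetitions)
logit -= 0.8 * (count > 0)  # presence:  2.6 - 0.8 -> 1.8 (flat, once per token)
print(logit)                # 1.8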
92 changes: 46 additions & 46 deletions tests/transformers/sampler/test_sampler.py
@@ -205,8 +205,8 @@ def test_greedy_sampling(
         return_pdfs=False,
         sampling_params={
             "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
             "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
@@ -298,10 +298,10 @@ def test_random_sampling(
         include_sampler=True,
         return_pdfs=False,
         sampling_params={
-            "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "repetition_penalties": np.array(1.9, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "frequency_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "presence_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "temperatures": np.array(0.67, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
             "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
@@ -319,56 +319,56 @@ def test_random_sampling(

     # Compare generated texts
     golden_texts = {
-        "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both",
-        "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
+        "w_sampler": " Kelsey and I am a 20 year old college student. My major in school right now,",
+        "wo_sampler": " Kaitlyn and I am a 20 year old college student. I am a junior at the",
     }
     golden_ids = {
         "w_sampler": [
             [
-                21380,
-                322,
-                590,
-                25448,
-                2927,
-                29892,
-                19963,
-                2654,
-                29879,
-                470,
-                3708,
-                2701,
-                313,
-                29902,
+                735,
+                93567,
+                323,
+                358,
+                1097,
+                264,
+                220,
                 508,
-                30010,
-                29873,
-                505,
-                963,
-                1716,
+                1060,
+                2362,
+                7926,
+                5575,
+                13,
+                3092,
+                3682,
+                304,
+                2978,
+                1314,
+                1457,
+                11,
             ]
         ],
         "wo_sampler": [
             [
-                2259,
-                7075,
-                322,
-                306,
-                626,
-                263,
-                7047,
-                22055,
-                29889,
-                306,
-                505,
-                1063,
-                1985,
-                297,
-                278,
-                13661,
-                363,
-                278,
-                4940,
-                29871,
+                735,
+                1339,
+                18499,
+                323,
+                358,
+                1097,
+                264,
+                220,
+                508,
+                1060,
+                2362,
+                7926,
+                5575,
+                13,
+                358,
+                1097,
+                264,
+                27144,
+                520,
+                279,
             ]
         ],
     }