From 21e230198dfd5a797b0bc9f2796467391ebebdbb Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 9 Jul 2025 09:41:56 +0000 Subject: [PATCH 1/5] Adding CB in Llama4 Signed-off-by: Mohit Soni --- .../generation/text_generation_inference.py | 6 +- .../models/llama4/modeling_llama4.py | 95 +++++++++++++------ .../transformers/models/modeling_auto.py | 59 +++++++++--- 3 files changed, 118 insertions(+), 42 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index fd7ef03ff..5ac70bf58 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -444,7 +444,11 @@ def __init__( self._set_tokenizer_params() # set tokenizer params # Skip inputs/outputs self._session.skip_buffers( - [x for x in self._session.input_names + self._session.output_names if x.startswith("past_")] + [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) def _set_tokenizer_params(self): diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 4b957ebec..13ef38e6b 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -831,7 +831,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -841,7 +849,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -888,6 +900,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) @@ -936,28 +951,42 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + + lang_prefill = { + "batch_size": 1 if 
continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -966,18 +995,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. 
if int((i + 1) % 4 != 0): @@ -1040,7 +1073,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1085,10 +1118,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1097,6 +1134,8 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 42898381d..0b4f069cf 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -545,6 +545,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching, **kwargs, ): if kwargs.pop("full_batch_size", None): @@ -553,6 +554,7 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @property @@ -592,8 +594,8 @@ def export( export_dir: Optional[str] = None, **kwargs, ) -> str: - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -630,14 +632,20 @@ def compile( skip_lang: Optional[bool] = False, **compiler_options, ) -> str: - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." 
) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) @@ -647,6 +655,9 @@ def compile( ctx_len=ctx_len, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -715,6 +726,8 @@ def compile( def generate( self, inputs: torch.Tensor, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, @@ -732,6 +745,14 @@ def generate( """ if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + if tokenizer and prompts: + return QEfficient.cloud_ai_100_exec_kv( + tokenizer, + self.lang_model.qpc_path, + prompt=prompts, + device_id=device_ids, + generation_len=generation_len, + ) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len @@ -1259,15 +1280,21 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs): if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs) else: return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + **kwargs, + ): """Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100. 
Args: @@ -1284,12 +1311,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - if kwargs.pop("continuous_batching", None): - NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if continuous_batching and not kv_offload: + NotImplementedError("Continuous batching is not supported for kv_offload = False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText} From c279bff6af1c639c7c273e3c3072921f4358733c Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 9 Jul 2025 09:46:37 +0000 Subject: [PATCH 2/5] Updating Hybrid_Chunked_Cache Cache for CB Signed-off-by: Mohit Soni --- QEfficient/transformers/cache_utils.py | 29 ++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..f87e008ee 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -443,6 +443,7 @@ def update( else: position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) # Update the position_ids to handle the sliding window @@ -460,10 +461,22 @@ def update( valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], position_ids, value_states + ) k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather @@ -483,8 +496,12 @@ def update( final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, final_indices) + v_out = 
CtxGatherFuncCB.apply(v_out, batch_index, final_indices) + else: + k_out = CtxGatherFunc.apply(k_out, final_indices) + v_out = CtxGatherFunc.apply(v_out, final_indices) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out From f4a447d5b19c7f46a3448727c99cba146124503d Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 17 Sep 2025 16:51:51 +0000 Subject: [PATCH 3/5] Updated CB for llama4 for 1 image with multiple prompt Signed-off-by: Asmita Goswami --- .../generation/text_generation_inference.py | 157 ++++++++++++++++++ .../models/llama4/modeling_llama4.py | 53 +++++- .../transformers/models/modeling_auto.py | 118 ++++++++++++- examples/llama4_CB_example.py | 90 ++++++++++ 4 files changed, 410 insertions(+), 8 deletions(-) create mode 100644 examples/llama4_CB_example.py diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 5ac70bf58..a861d98ff 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -826,6 +826,163 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): return decode_pause_time + def run_vision_language_continuous_batching_decode(self, prompt_queue, generation_len, shared_vision_embeddings=None): + """ + Runs continuous batching decode for vision language models with shared vision embeddings. + + Method sets up the initial conditions for decoding and preparing the decode inputs. Then enters a loop that continues as long as there are prompts in the queue or any decoding is ongoing. In each iteration of the loop, it runs the session with the current decode inputs, prepares the inputs for the next iteration and updates the decode inputs. If a prompt has been fully decoded, it runs prefill for the next prompt in the queue if available. + + Args: + prompt_queue (deque): The queue of prompts to be decoded. + generation_len (int): The generation length. + shared_vision_embeddings (np.array, optional): Shared vision embeddings for vision-language models. Defaults to None. + + """ + # Set logits placeholder for decode + logits_out_placeholder = np.zeros( + (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + ) + self._session.set_buffers({"logits": logits_out_placeholder}) + + # Set shared vision embeddings if provided + if shared_vision_embeddings is not None: + self._session.set_buffers(shared_vision_embeddings) + + # Generate flag for tracking progress for each batch ID + current_decode_ongoing = np.full((self.full_batch_size, 1), True) + + # Generate an array for maintaining the tokens generated in each batch ID + generated_id_current_index = np.ones((self.full_batch_size, 1), np.int64) + + # Generate a batch ID map for mapping the batch ID if input > full_batch_size. + # This ID map will be used for storing all generated tokens + batch_id_map = {i: i for i in range(self.full_batch_size)} + decode_pause_time = 0 + + # Prepare decode inputs. 
+ decode_inputs = self.prepare_decode_inputs() + + while prompt_queue or current_decode_ongoing.any(): + outputs = self._session.run(decode_inputs) + + # Prepare inputs for next iteration + logits = outputs["logits"] + if len(logits.shape) == 2: + logits = np.expand_dims(logits, 1) + next_token_id = logits.argmax(2) + + for decode_batch_id in range(self.full_batch_size): + if ( + next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id + or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] + ): + if prompt_queue: + start = perf_counter() + # run prefill for next prompt input. + outputs, position_ids, generation_len = self.run_vision_language_prefill( + prompt_queue.popleft(), + generation_len, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + shared_vision_embeddings=shared_vision_embeddings, + ) + + new_token_id = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + batch_id_map[decode_batch_id] = max(batch_id_map.values()) + 1 + self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1) + generated_id_current_index[decode_batch_id] = 1 + + self._session.set_buffers({"logits": logits_out_placeholder}) + + # Re-set shared vision embeddings for consistency + if shared_vision_embeddings: + self._session.set_buffers(shared_vision_embeddings) + + decode_pause_time += perf_counter() - start + + if self._prompt_to_lora_id_mapping_decode: + decode_inputs["lora_ids"][decode_batch_id] = self._prompt_to_lora_id_mapping_decode[ + batch_id_map[decode_batch_id] + ] + + else: + current_decode_ongoing[decode_batch_id] = False + else: + # If the generated sequence is valid and within generation len prepare for next decode + decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1] + decode_inputs["position_ids"][decode_batch_id, -1] += 1 + self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( + next_token_id[decode_batch_id, -1] + ) + + generated_id_current_index[decode_batch_id] += 1 + + return decode_pause_time + + def run_vision_language_prefill(self, prompt, generation_len, decode_batch_id=None, shared_vision_embeddings=None): + """ + Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length. + Args: + prompt (str): The prompt for which to run prefill. + generation_len (int): Max allowed length for generating tokens. The decoding process will be terminated when generation length is reached. + decode_batch_id (np.ndarray, optional): The decode batch ID for continuous batching. Defaults to None. + """ + # Run prefill + inputs = self.tokenizer(prompt, return_tensors="np", padding=True) + position_ids = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float + padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len + + # Initialize variables specific to request + # Calculate the max generation length. 
+ max_gen_len = self._ctx_len - position_ids.max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Set the prefill logic buffer + logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"logits": logits_out_placeholder}) + + # Set shared vision embeddings if provided + if shared_vision_embeddings is not None: + self._session.set_buffers(shared_vision_embeddings) + + inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + + if decode_batch_id is not None: + inputs["batch_index"] = decode_batch_id + if self.is_tlm: + inputs["num_logits_to_keep"] = np.zeros((1, 1)) + + if self._prompt_to_lora_id_mapping_prefill: + if self.full_batch_size: + inputs["lora_ids"] = np.array( + self._prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64 + ).reshape(1, 1) + else: + batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] + inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + chunk_inputs["position_ids"] = inputs["position_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + outputs = self._session.run(chunk_inputs) + if self._write_io_dir is not None: + write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + return ( + outputs, + position_ids, + generation_len, + ) + + def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None): """ Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length. 
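Note on the two generation helpers added above: they implement the usual continuous-batching control flow for the shared-image case. The image is encoded once and its buffers are set on the language session for every decode slot, each prompt is prefilled into a free slot, and the decode loop refills a slot from the prompt queue as soon as that sequence emits EOS or reaches its generation length. Below is a minimal, framework-free sketch of that slot-refill loop only; it is illustrative, not part of the patch, and prefill_slot/next_token are hypothetical stand-ins for the real QAIC session calls.

# Illustrative-only sketch of the slot-refill control flow used by
# run_vision_language_continuous_batching_decode above. Plain Python, no
# QAICInferenceSession or model; the shared image would be encoded once
# before this loop and its buffers set on the language session (omitted here).
from collections import deque

def toy_continuous_batching(prompt_queue, full_batch_size, generation_len, eos=0):
    slots = [None] * full_batch_size      # per-slot prompt id (None = free slot)
    produced = [0] * full_batch_size      # tokens generated so far per slot
    outputs = {}                          # prompt id -> generated token ids

    def prefill_slot(slot_id):
        # Stand-in for run_vision_language_prefill(): claim the next prompt.
        prompt_id = prompt_queue.popleft()
        slots[slot_id] = prompt_id
        produced[slot_id] = 0
        outputs[prompt_id] = []

    def next_token(prompt_id, step):
        # Stand-in for one decode step; emits EOS after a few tokens.
        return eos if step >= 3 + prompt_id % 2 else 100 + prompt_id

    # Initial prefill: fill every slot that has a prompt available.
    for slot_id in range(full_batch_size):
        if prompt_queue:
            prefill_slot(slot_id)

    # Decode loop: whenever a slot finishes, refill it from the prompt queue.
    while any(s is not None for s in slots):
        for slot_id, prompt_id in enumerate(slots):
            if prompt_id is None:
                continue
            tok = next_token(prompt_id, produced[slot_id])
            outputs[prompt_id].append(tok)
            produced[slot_id] += 1
            if tok == eos or produced[slot_id] >= generation_len:
                slots[slot_id] = None
                if prompt_queue:          # reuse the freed slot immediately
                    prefill_slot(slot_id)
    return outputs

print(toy_continuous_batching(deque(range(6)), full_batch_size=2, generation_len=8))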
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 13ef38e6b..3d872f5b7 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -841,11 +841,41 @@ def forward( batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) + batch_size = None + + # Handle CB case with multiple prompts sharing same image + if batch_index is not None and batch_index.numel() > 1: + # For CB with multiple prompts sharing same image, reuse vision embeds accross batches + batch_size = input_ids.shape[0] + + #Expanfd vision_embeds to match batch size if needed + if vision_embeds.shape[0] == 1 and batch_size > 1: + vision_embeds = vision_embeds.expand(batch_size, -1, -1) + selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 - indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + # indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + # indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + # image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + + #Handle batch aware image indexing for CB + if batch_size is not None: + # For CB, use per-batch image indices + batch_image_idx = image_idx.expand_as(selected[:, :1]) + indices1 = torch.where(indices1 != -1, indices1 + batch_image_idx, indices1) + else: + # For non-CB, use global image indices + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + + # Handle vision embeddings indexing for batch processing + if vision_embeds.dim() == 3 and vision_embeds.shape[0] == input_ids.shape[0]: + # Batch wise vision embeddings + image_features_expanded = vision_embeds[indices0, indices1] + else: + # Single vision embeddings for all batches/ single image shared accross all batches + image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( @@ -855,8 +885,15 @@ def forward( batch_index=batch_index, use_cache=True, ) - next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + # Update image_idx to point to the next available vision_embeds index - handle batch case + if batch_index is not None and indices1.numel() > 0: + # For CB, update image_idx per batch + next_idx = (indices1.max(dim=1, keepdim=True)[0] + 1).unsqueeze(1) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + else: + next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -946,7 +983,7 @@ def get_specializations( vision = [ { - "batch_size": batch_size, + "batch_size": 1, # To process image only once for all batch_sizes(prompts) in continuous batching "max_num_tiles": max_num_tiles, "img_size": img_size, } @@ -963,7 +1000,9 @@ def get_specializations( "chunk_ctx_len": chunk_ctx_len, } if continuous_batching: - 
lang_prefill["full_batch_size"] = kv_cache_batch_size + lang_prefill["full_batch_size"] = full_batch_size or kv_cache_batch_size + # Enable multi-prompt support with shared vision embeddings + lang_prefill["shared_vision"] = 1 else: lang_prefill["batch_size"] = kv_cache_batch_size if full_batch_size: @@ -980,7 +1019,7 @@ def get_specializations( "chunk_ctx_len": chunk_ctx_len, } if continuous_batching: - lang_decode["full_batch_size"] = kv_cache_batch_size + lang_decode["full_batch_size"] = full_batch_size or kv_cache_batch_size else: lang_decode["batch_size"] = kv_cache_batch_size diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0b4f069cf..7357047a7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -57,6 +57,8 @@ ) from QEfficient.utils.logging_utils import logger +from QEfficient.generation.text_generation_inference import TextGeneration + class QEFFTransformersBase(QEFFBaseModel): """ @@ -745,6 +747,17 @@ def generate( """ if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + + # Handle CB for multiple prompts with same image + if self.continuous_batching and tokenizer and prompts and len(prompts) > 1: + return self.continuous_batching_multi_prompt_generate( + inputs=inputs, + tokenizer=tokenizer, + prompts=prompts, + device_ids=device_ids, + generation_len=generation_len, + streamer=streamer, + ) if tokenizer and prompts: return QEfficient.cloud_ai_100_exec_kv( tokenizer, @@ -758,6 +771,109 @@ def generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) + def continuous_batching_multi_prompt_generate( + self, + inputs: torch.Tensor, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], + prompts: List[str], + device_ids: List[int] = None, + generation_len: Optional[int] = None, + streamer: Optional[TextStreamer] = None, + ): + """ + Optimized continuous batching generate function for multiple prompts with same image. + This method processes a single image with multiple text prompts in a continuous batching manner, by: + 1. Running the vision encoder once for the shared image. + 2. Using continuous batching for multiplt prompts in the language decoder. + 3. Sharing vision embeddings across all prompts to save memory and computation. 
+ """ + if not self.lang_model.qpc_path: + raise TypeError("Please run compile API for language model first!") + + try: + vision_session = None + lang_session = None + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + + # Get compilation dimensions + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Process vision inputs once for all prompts + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].to(torch.float16).cpu().numpy() + vision_start = perf_counter() + vision_outputs = {} + if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() + + # Deactivate vision session after use + if self.vision_model.qpc_path: + vision_session.deactivate() + + # Text generation instance for continuous batching + text_generator = TextGeneration( + tokenizer=tokenizer, + qpc_path=self.lang_model.qpc_path, + device_id=device_ids, + ctx_len=ctx_len, + enable_debug_logs=False, + full_batch_size=fbs, + ) + + # Prepare prompts for CB + # Each prompt processed with same vision embeddings + tokenized_prompts = [] + for prompt in prompts: + tokenized_prompt = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) + tokenized_prompts.append(tokenized_prompt) + + # Run CB with shared vision embeddings + lang_session.activate() + + if vision_outputs: + lang_session.set_buffers(vision_outputs) + + # Execute continuous batching generate + exec_info = text_generator.generate( + prompt=prompts, + generation_len=generation_len, + streamer=streamer is not None, + ) + + print("Vision encoding time (s): ", vision_end - vision_start) + return exec_info + except Exception as e: + print(f"Error in continuous batching: {str(e)}") + raise + finally: + # Clean up + if vision_session: + try: + vision_session.deactivate() + except: + pass + if lang_session: + try: + lang_session.deactivate() + except: + pass + def kv_offload_generate( self, inputs: List[str] = None, @@ -829,7 +945,7 @@ def kv_offload_generate( } if vision_inputs: - vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].to(torch.float16).cpu().numpy() vision_start = perf_counter() vision_outputs = {} diff --git a/examples/llama4_CB_example.py b/examples/llama4_CB_example.py new file mode 100644 index 000000000..88fa21adb --- /dev/null +++ b/examples/llama4_CB_example.py @@ -0,0 +1,90 @@ +import torch +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) 
+processor = AutoProcessor.from_pretrained(model_id) + +qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + +prompts = [ + "Can you describe the image in detail?", + # "What are the objects in the image?", + # "What is the main subject of the image?", + # "What colors are predominant in the image?", +] + +all_inputs = [] +for prompt in prompts: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": prompt}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + all_inputs.append(inputs) + + +output = qeff_model.generate(inputs=all_inputs[0], tokenizer=tokenizer, device_ids = [0,1,2,3], prompts=prompts, generation_len=100) + +if hasattr(output, 'generated_texts'): + for i, (prompt, response) in enumerate(zip(prompts, output.generated_texts)): + print(f"Prompt {i+1}: {prompt}") + print(f"Response {i+1}: {response}") + print("-" * 30) +else: + print("Generated IDs:", output.generated_ids) + decoded_responses = tokenizer.batch_decode(output.generated_ids, skip_special_tokens=True) + for i, (prompt, response) in enumerate(zip(prompts, decoded_responses)): + print(f"Prompt {i+1}: {prompt}") + print(f"Response {i+1}: {response}") + print("-" * 30) + +# print(output.generated_ids) +# print(tokenizer.batch_decode(output.generated_ids)) +print(output) +print() From f0be4b84765f9640694ca3cb3ec74b3c2b1df040 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 17 Sep 2025 17:00:11 +0000 Subject: [PATCH 4/5] Updated CB for llama4 for 1 image with multiple prompt Signed-off-by: Asmita Goswami --- .../generation/text_generation_inference.py | 9 +- .../models/llama4/modeling_llama4.py | 6 +- .../transformers/models/modeling_auto.py | 141 ++++++++---------- examples/llama4_CB_example.py | 20 +-- 4 files changed, 82 insertions(+), 94 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index a861d98ff..16bcf9e79 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -826,7 +826,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): return decode_pause_time - def run_vision_language_continuous_batching_decode(self, prompt_queue, generation_len, shared_vision_embeddings=None): + def run_vision_language_continuous_batching_decode( + self, prompt_queue, generation_len, shared_vision_embeddings=None + ): """ Runs continuous batching decode for vision language models with shared vision embeddings. @@ -919,7 +921,9 @@ def run_vision_language_continuous_batching_decode(self, prompt_queue, generatio return decode_pause_time - def run_vision_language_prefill(self, prompt, generation_len, decode_batch_id=None, shared_vision_embeddings=None): + def run_vision_language_prefill( + self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None, shared_vision_embeddings=None + ): """ Default method for running decode. 
Executes the decoding process for a given set of inputs and a specified generation length. Args: @@ -982,7 +986,6 @@ def run_vision_language_prefill(self, prompt, generation_len, decode_batch_id=No generation_len, ) - def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None): """ Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length. diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 3d872f5b7..ae9cd3a01 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -848,7 +848,7 @@ def forward( # For CB with multiple prompts sharing same image, reuse vision embeds accross batches batch_size = input_ids.shape[0] - #Expanfd vision_embeds to match batch size if needed + # Expanfd vision_embeds to match batch size if needed if vision_embeds.shape[0] == 1 and batch_size > 1: vision_embeds = vision_embeds.expand(batch_size, -1, -1) @@ -858,7 +858,7 @@ def forward( # indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) # image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] - #Handle batch aware image indexing for CB + # Handle batch aware image indexing for CB if batch_size is not None: # For CB, use per-batch image indices batch_image_idx = image_idx.expand_as(selected[:, :1]) @@ -983,7 +983,7 @@ def get_specializations( vision = [ { - "batch_size": 1, # To process image only once for all batch_sizes(prompts) in continuous batching + "batch_size": 1, # To process image only once for all batch_sizes(prompts) in continuous batching "max_num_tiles": max_num_tiles, "img_size": img_size, } diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 7357047a7..99e7e07cf 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -31,6 +31,7 @@ from QEfficient.generation.text_generation_inference import ( CloudAI100ExecInfoNew, PerfMetrics, + TextGeneration, calculate_latency, get_compilation_dims, ) @@ -57,8 +58,6 @@ ) from QEfficient.utils.logging_utils import logger -from QEfficient.generation.text_generation_inference import TextGeneration - class QEFFTransformersBase(QEFFBaseModel): """ @@ -790,89 +789,73 @@ def continuous_batching_multi_prompt_generate( if not self.lang_model.qpc_path: raise TypeError("Please run compile API for language model first!") - try: - vision_session = None - lang_session = None - if self.vision_model.qpc_path: - vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) - - lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) - - # Get compilation dimensions - batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) - - # Skip inputs/outputs - lang_session.skip_buffers( - [ - x - for x in lang_session.input_names + lang_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] - ) + vision_session = None + lang_session = None + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) - # Process vision inputs once for all prompts - vision_inputs = { - k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} - } - if vision_inputs: - 
vision_inputs["pixel_values"] = vision_inputs["pixel_values"].to(torch.float16).cpu().numpy() - vision_start = perf_counter() - vision_outputs = {} - if vision_inputs: - vision_outputs = vision_session.run(vision_inputs) - vision_end = perf_counter() - - # Deactivate vision session after use - if self.vision_model.qpc_path: - vision_session.deactivate() - - # Text generation instance for continuous batching - text_generator = TextGeneration( - tokenizer=tokenizer, - qpc_path=self.lang_model.qpc_path, - device_id=device_ids, - ctx_len=ctx_len, - enable_debug_logs=False, - full_batch_size=fbs, - ) + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + + # Get compilation dimensions + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) - # Prepare prompts for CB - # Each prompt processed with same vision embeddings - tokenized_prompts = [] - for prompt in prompts: - tokenized_prompt = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) - tokenized_prompts.append(tokenized_prompt) + # Process vision inputs once for all prompts + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].to(torch.float16).cpu().numpy() + vision_start = perf_counter() + vision_outputs = {} + if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() + + # Deactivate vision session after use + if self.vision_model.qpc_path: + vision_session.deactivate() - # Run CB with shared vision embeddings - lang_session.activate() + # Text generation instance for continuous batching + text_generator = TextGeneration( + tokenizer=tokenizer, + qpc_path=self.lang_model.qpc_path, + device_id=device_ids, + ctx_len=ctx_len, + enable_debug_logs=False, + full_batch_size=fbs, + ) - if vision_outputs: - lang_session.set_buffers(vision_outputs) + # Prepare prompts for CB + # Each prompt processed with same vision embeddings + tokenized_prompts = [] + for prompt in prompts: + tokenized_prompt = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) + tokenized_prompts.append(tokenized_prompt) - # Execute continuous batching generate - exec_info = text_generator.generate( - prompt=prompts, - generation_len=generation_len, - streamer=streamer is not None, - ) + # Run CB with shared vision embeddings + lang_session.activate() + + if vision_outputs: + lang_session.set_buffers(vision_outputs) + + # Execute continuous batching generate + exec_info = text_generator.generate( + prompt=prompts, + generation_len=generation_len, + streamer=streamer is not None, + ) - print("Vision encoding time (s): ", vision_end - vision_start) - return exec_info - except Exception as e: - print(f"Error in continuous batching: {str(e)}") - raise - finally: - # Clean up - if vision_session: - try: - vision_session.deactivate() - except: - pass - if lang_session: - try: - lang_session.deactivate() - except: - pass + print("Vision encoding time (s): ", vision_end - vision_start) + return exec_info def kv_offload_generate( self, diff --git a/examples/llama4_CB_example.py b/examples/llama4_CB_example.py index 88fa21adb..3fec53fdc 100644 --- a/examples/llama4_CB_example.py +++ 
b/examples/llama4_CB_example.py @@ -1,6 +1,6 @@ import torch import transformers -from transformers import AutoConfig, AutoProcessor, TextStreamer +from transformers import AutoConfig, AutoProcessor from QEfficient import QEFFAutoModelForImageTextToText @@ -36,8 +36,8 @@ ) image_url = ( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" - ) + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" +) prompts = [ "Can you describe the image in detail?", @@ -69,19 +69,21 @@ all_inputs.append(inputs) -output = qeff_model.generate(inputs=all_inputs[0], tokenizer=tokenizer, device_ids = [0,1,2,3], prompts=prompts, generation_len=100) +output = qeff_model.generate( + inputs=all_inputs[0], tokenizer=tokenizer, device_ids=[0, 1, 2, 3], prompts=prompts, generation_len=100 +) -if hasattr(output, 'generated_texts'): +if hasattr(output, "generated_texts"): for i, (prompt, response) in enumerate(zip(prompts, output.generated_texts)): - print(f"Prompt {i+1}: {prompt}") - print(f"Response {i+1}: {response}") + print(f"Prompt {i + 1}: {prompt}") + print(f"Response {i + 1}: {response}") print("-" * 30) else: print("Generated IDs:", output.generated_ids) decoded_responses = tokenizer.batch_decode(output.generated_ids, skip_special_tokens=True) for i, (prompt, response) in enumerate(zip(prompts, decoded_responses)): - print(f"Prompt {i+1}: {prompt}") - print(f"Response {i+1}: {response}") + print(f"Prompt {i + 1}: {prompt}") + print(f"Response {i + 1}: {response}") print("-" * 30) # print(output.generated_ids) From 0bff092ad1deb242451f01f1f82e6bf3aa9eea74 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 17 Sep 2025 17:06:24 +0000 Subject: [PATCH 5/5] Updated CB for llama4 for 1 image with multiple prompt Signed-off-by: Asmita Goswami --- examples/llama4_CB_example.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/llama4_CB_example.py b/examples/llama4_CB_example.py index 3fec53fdc..578581a3b 100644 --- a/examples/llama4_CB_example.py +++ b/examples/llama4_CB_example.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + import torch import transformers from transformers import AutoConfig, AutoProcessor @@ -41,9 +48,9 @@ prompts = [ "Can you describe the image in detail?", - # "What are the objects in the image?", - # "What is the main subject of the image?", - # "What colors are predominant in the image?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", ] all_inputs = []
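For reference, the batch-indexed cache update introduced in PATCH 2 boils down to: when batch_index is passed, each sequence scatters its new key/value states into, and gathers context from, the cache row selected by its batch_index. The sketch below is a rough pure-PyTorch illustration of that indexing only; it is not the QEfficient CtxScatterFuncCB/CtxGatherFuncCB custom ops (those are written to export cleanly to ONNX and also handle the sliding-window and invalid-index masking shown in the diff), and padded positions (-1) are simply skipped here instead of being mapped to an out-of-range sentinel.

# Rough pure-PyTorch illustration of batch-indexed KV scatter/gather semantics.
# NOT the QEfficient custom ops; shapes assumed:
#   cache:        [full_batch_size, heads, ctx_len, head_dim]
#   batch_index:  [bs, 1]   -> cache row (decode slot) owned by each sequence
#   position_ids: [bs, seq] -> write positions, padded entries are -1
#   states:       [bs, heads, seq, head_dim]
import torch

def scatter_cb(cache, batch_index, position_ids, states):
    for b in range(states.shape[0]):
        row = int(batch_index[b, 0])          # this sequence's cache row
        valid = position_ids[b] >= 0          # skip padded (-1) positions
        cache[row][:, position_ids[b][valid]] = states[b][:, valid]
    return cache

def gather_cb(cache, batch_index, ctx_indices):
    # ctx_indices: [bs, k] positions to read back from each sequence's row.
    return torch.stack(
        [cache[int(b)][:, idx] for b, idx in zip(batch_index[:, 0], ctx_indices)]
    )  # -> [bs, heads, k, head_dim]

if __name__ == "__main__":
    fbs, heads, ctx_len, head_dim, bs, seq = 4, 2, 8, 3, 2, 5
    cache = torch.zeros(fbs, heads, ctx_len, head_dim)
    batch_index = torch.tensor([[1], [3]])                  # slots 1 and 3 in use
    position_ids = torch.tensor([[0, 1, 2, 3, -1],
                                 [0, 1, 2, -1, -1]])        # -1 = padding
    states = torch.randn(bs, heads, seq, head_dim)
    scatter_cb(cache, batch_index, position_ids, states)
    out = gather_cb(cache, batch_index, torch.arange(ctx_len).repeat(bs, 1))
    print(out.shape)                                        # torch.Size([2, 2, 8, 3])

With full_batch_size cache rows compiled into the QPC, this per-row indexing is what lets several prompts share one compiled language model while keeping their KV histories separate, which is the basis for the shared-image continuous batching added in the later patches.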