From b434b499f9e214a713a68d85de092b117ebb6722 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Tue, 1 Jul 2025 10:23:29 +0000 Subject: [PATCH] [Llama4]: Add support for padding pixel_value num_patches to MAX_NUM_PATCHES Signed-off-by: vbaddi --- .../models/llama4/modeling_llama4.py | 3 +++ .../transformers/models/modeling_auto.py | 18 ++++++++++++++++++ QEfficient/utils/constants.py | 1 + 3 files changed, 22 insertions(+) diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 6b30c7804..26cb31201 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -1112,3 +1112,6 @@ def get_inputs_info(self): shape=("max_num_tiles", 3, "img_size", "img_size"), ), ] + + def get_expected_patch_count(self) -> int: + return constants.LLAMA4_NUM_PATCHES # 17 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2f3ee3dc0..8511d1304 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -840,6 +840,24 @@ def kv_offload_generate( if vision_inputs: vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + + if hasattr(self.model, "get_expected_patch_count"): + try: + expected_patches = self.model.get_expected_patch_count() + if vision_inputs["pixel_values"].shape[0] != expected_patches: + logger.info( + f"Padding pixel_values from {vision_inputs['pixel_values'].shape[0]} to {expected_patches} patches" + ) + single_patch = np.expand_dims(vision_inputs["pixel_values"][0], axis=0) + while vision_inputs["pixel_values"].shape[0] < expected_patches: + vision_inputs["pixel_values"] = np.concatenate( + (vision_inputs["pixel_values"], single_patch), axis=0 + ) + except Exception as e: + logger.warning(f"Failed to get expected patch count: {e}. Proceeding with original pixel_values shape.") + else: + logger.debug("Model does not have get_expected_patch_count method. Using original pixel_values shape.") + vision_start = perf_counter() vision_outputs = {} diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 526b01683..9acb47efa 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -92,6 +92,7 @@ def get_models_dir(): # Llama4 Constants LLAMA4_ATTENTION_CHUNK_SIZE = 8192 LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 +LLAMA4_NUM_PATCHES = 17 # Gemma3 Constant GEMMA3_MAX_POSITION_EMBEDDINGS = 32768