Skip to content

Commit 783670a

Browse files
committed
Updates
Signed-off-by: Ann Kuruvilla <[email protected]>
1 parent 55f16f6 commit 783670a

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,17 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val
648648
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
649649
return outputs.logits, pixel_values, image_idx, outputs.past_key_values
650650

651+
def get_npi_file(self, model_name: str, **compiler_options):
652+
if model_name == "google/gemma-3-4b-it":
653+
compiler_options["node_precision_info"] = constants.DEFAULT_GEMMA3_4B_NODE_PRECISION_INFO
654+
elif model_name == "google/gemma-3-27b-it":
655+
compiler_options["node_precision_info"] = constants.DEFAULT_GEMMA3_27B_NODE_PRECISION_INFO
656+
else:
657+
raise ValueError(
658+
f"For Model {self.pretrained_model_name_or_path} default NPI file is not supported/added. Please use one of the following: google/gemma-3-4b-it, google/gemma-3-27b-it"
659+
)
660+
return compiler_options
661+
651662
def get_specializations(
652663
self,
653664
batch_size: int,
@@ -694,18 +705,6 @@ def get_specializations(
694705
]
695706
specializations = {}
696707

697-
# Default node precision file added for Gemma3:AI-100
698-
# if the user provides a custom node precision file, it will override the default one
699-
if "node_precision_info" not in compiler_options:
700-
if self.pretrained_model_name_or_path == "google/gemma-3-4b-it":
701-
compiler_options["node_precision_info"] = constants.DEFAULT_GEMMA3_4B_NODE_PRECISION_INFO
702-
elif self.pretrained_model_name_or_path == "google/gemma-3-27b-it":
703-
compiler_options["node_precision_info"] = constants.DEFAULT_GEMMA3_27B_NODE_PRECISION_INFO
704-
else:
705-
raise ValueError(
706-
f"For Model {self.pretrained_model_name_or_path} default NPI file is not supported/added. Please use one of the following: google/gemma-3-4b-it, google/gemma-3-27b-it"
707-
)
708-
709708
if kv_offload:
710709
specializations["vision"] = vision
711710
specializations["lang"] = lang

QEfficient/transformers/models/modeling_auto.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,9 @@ def compile(
681681
**compiler_options,
682682
)
683683

684+
if hasattr(self.model, "get_npi_file"):
685+
compiler_options = self.model.get_npi_file(self.model.pretrained_model_name_or_path, **compiler_options)
686+
684687
custom_io_vision = {}
685688
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
686689
custom_io_vision["pixel_values"] = "float16"
@@ -1030,6 +1033,9 @@ def compile(
10301033
**compiler_options,
10311034
)
10321035

1036+
if hasattr(self.model, "get_npi_file"):
1037+
self.model.get_npi_file(self.pretrained_model_name_or_path)
1038+
10331039
custom_io = {}
10341040
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
10351041
# inputs

QEfficient/utils/constants.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
import importlib.resources as pkg_resources
98
import os
109
from dataclasses import dataclass
1110

@@ -103,8 +102,8 @@ def get_models_dir():
103102
LLAMA4_MAX_POSITION_EMBEDDINGS = 65536
104103

105104
# Gemma3 Constant
106-
DEFAULT_GEMMA3_4B_NODE_PRECISION_INFO = str(pkg_resources.path(gemma3, "fp32_nodes_gemma3_4b_mm.yaml"))
107-
DEFAULT_GEMMA3_27B_NODE_PRECISION_INFO = str(pkg_resources.path(gemma3, "fp32_nodes_gemma3_27b_mm.yaml"))
105+
DEFAULT_GEMMA3_4B_NODE_PRECISION_INFO = "QEfficient/transformers/models/gemma3/fp32_nodes_gemma3_4b_mm.yaml"
106+
DEFAULT_GEMMA3_27B_NODE_PRECISION_INFO = "QEfficient/transformers/models/gemma3/fp32_nodes_gemma3_27b_mm.yaml"
108107

109108

110109
class Constants:

examples/gemma3_example/gemma3_mm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
# pass HF_TOKEN if gated model
2323
# For running the model in the single QPC approach use kv_offload=False. For the Dual QPC approach use kv_offload=True ###
2424
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
25-
model_id, config=config, attn_implementation="eager", kv_offload=True
25+
model_id, config=config, attn_implementation="eager", kv_offload=False
2626
)
2727

2828
### use skip_vision=True if you want to run only text, or False otherwise ###
29-
skip_vision = True
29+
skip_vision = False
3030

3131
if skip_vision:
3232
## Only Text ##

0 commit comments

Comments
 (0)