Merged
26 commits
e79f9d4
[OpenVINO] Add support for Granite-4.0 family (incl. models with MoE …
rkazants Nov 7, 2025
28f2cfa
Update tests/openvino/test_decoder.py
rkazants Nov 7, 2025
695c633
Apply formatting
rkazants Nov 7, 2025
a734beb
Merge remote-tracking branch 'upstream/main' into support_granitemoeh…
rkazants Nov 7, 2025
46238a5
Fix conversion of fully attention-based Granite 4.0
rkazants Nov 9, 2025
e8967e1
Update documentation
rkazants Nov 9, 2025
d0f907e
Add tests
rkazants Nov 9, 2025
66cc1af
Update tests/openvino/utils_tests.py
rkazants Nov 9, 2025
8f3604a
Fix patching for MoE
rkazants Nov 9, 2025
284e6a2
Update tests/openvino/test_decoder.py
rkazants Nov 9, 2025
9f0004b
Support use_cache=True and stateful=False
rkazants Nov 11, 2025
1a57834
Apply formatting
rkazants Nov 11, 2025
17b6b8a
Merge remote-tracking branch 'upstream/main' into support_granitemoeh…
rkazants Dec 7, 2025
adf88a6
Apply code-formatting
rkazants Dec 7, 2025
9ae29e5
Fix syntax error
rkazants Dec 7, 2025
c0a6946
Patch update_causal_mask for GraniteMoeHybrid
rkazants Dec 8, 2025
b766fe5
Add a comment - why patch for update_causal_mask is needed
rkazants Dec 8, 2025
a0de668
Exclude beam search test for GraniteMoeHybrid
rkazants Dec 8, 2025
368f985
Update minimal supported version for GraniteMoeHybrid
rkazants Dec 9, 2025
625db96
Added a comment about lack of support for beam search for hybrid models
rkazants Dec 9, 2025
a4faee5
Simplify constructor for Zamba2HybridCacheDummyGenerator
rkazants Dec 9, 2025
9080055
Simplify Inference class for hybrid models
rkazants Dec 9, 2025
9db39e8
Update optimum/intel/openvino/modeling_decoder.py
rkazants Dec 9, 2025
53d8509
Clarify attributes in config
rkazants Dec 9, 2025
bcdbabb
Clarify attributes in modeling_decoder
rkazants Dec 9, 2025
5d7c2ca
Update optimum/exporters/openvino/__main__.py
rkazants Dec 10, 2025
87 changes: 71 additions & 16 deletions optimum/exporters/openvino/model_configs.py
Expand Up @@ -140,6 +140,7 @@
SanaTextEncoderModelPatcher,
XverseModelPatcher,
Zamba2ModelPatcher,
GraniteMoeHybridModelPatcher,
)


Expand Down Expand Up @@ -4308,20 +4309,32 @@ def __init__(
)

config = normalized_config.config
self.intermediate_size = int(config.mamba_expand * config.hidden_size)
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.mamba_ngroups = config.mamba_ngroups
self.mamba_d_state = config.mamba_d_state
self.mamba_headdim = config.mamba_headdim
self.head_dim = config.attention_head_dim
self.hybrid_layer_ids = config.hybrid_layer_ids
logger.warning(
"The current support for the 'Zamba2' model type is experimental. "
"Performance is not optimal with high memory consumption. "
"Optimizations and improved support will be available in a future OpenVINO release."
)
if config.model_type == "zamba2":
self.intermediate_size = int(config.mamba_expand * config.hidden_size)
self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.mamba_ngroups = config.mamba_ngroups
self.mamba_d_state = config.mamba_d_state
self.mamba_headdim = config.mamba_headdim
self.head_dim = config.attention_head_dim
self.num_hybrid_layers = len(config.hybrid_layer_ids)
logger.warning(
"The current support for the 'Zamba2' model type is experimental. "
"Performance is not optimal with high memory consumption. "
"Optimizations and improved support will be available in a future OpenVINO release."
)
else:
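# non-zamba2 (i.e. granitemoehybrid) configs expose the same quantities under different attribute names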
self.intermediate_size = int(config.mamba_expand * config.hidden_size)
self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.mamba_n_heads
self.mamba_ngroups = config.mamba_n_groups
self.mamba_d_state = config.mamba_d_state
self.mamba_headdim = config.mamba_d_head
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_hybrid_layers = config.layer_types.count("attention")
self.num_layers = config.layer_types.count("mamba")
self.num_attention_heads = config.num_key_value_heads
self.sequence_length = 0

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
past_key_values = []
Expand All @@ -4333,11 +4346,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
)
conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype)
past_key_values.append(conv_state)
ssm_state_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.ssm_state_size)
ssm_state_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.mamba_d_state)
ssm_state = self.random_float_tensor(ssm_state_shape, framework=framework, dtype=float_dtype)
past_key_values.append(ssm_state)

for i in range(len(self.hybrid_layer_ids)):
for i in range(self.num_hybrid_layers):
kv_shape = (self.batch_size, self.num_attention_heads, self.sequence_length, self.head_dim)
k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
Expand Down Expand Up @@ -4386,3 +4399,45 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs


@register_in_tasks_manager("granitemoehybrid", *["text-generation", "text-generation-with-past"], library_name="transformers")
class GraniteMoeHybridOpenVINOConfig(MambaOpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Zamba2DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Zamba2DummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = "4.51.3"
_MODEL_PATCHER = GraniteMoeHybridModelPatcher

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
cache_name_prefix = "cache_params.past"
else:
decoder_sequence_name = "past_sequence_length + sequence_length"
cache_name_prefix = "cache_params.present"

self.num_mamba_layers = self._normalized_config.layer_types.count("mamba")
self.num_attention_layers = self._normalized_config.layer_types.count("attention")
for i in range(self.num_mamba_layers):
# [batch_size, conv_kernel_size - 1, d_model]
inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"}
# [batch_size, d_state, d_model]
inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"}

for i in range(self.num_attention_layers):
inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs
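With this config registered, a Granite-4.0 checkpoint can be exported and run through the usual optimum-intel flow. A minimal sketch, assuming a hypothetical model ID that is not taken from this PR:

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "ibm-granite/granite-4.0-tiny-preview"  # hypothetical checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint to OpenVINO IR via the exporter config above
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
inputs = tokenizer("OpenVINO is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))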
194 changes: 189 additions & 5 deletions optimum/exporters/openvino/model_patcher.py
Expand Up @@ -6537,6 +6537,9 @@ def zamba2_mamba_mixer(
self,
hidden_states,
cache_params=None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
Expand Down Expand Up @@ -6573,6 +6576,9 @@ def segment_sum(input_tensor):
return tensor_segsum

input_states = hidden_states
layer_idx = self.layer_idx
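# GraniteMoeHybrid stores mamba states in compact per-type lists, so remap the global layer index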
if cache_params is not None and hasattr(cache_params, "mamba_layer_idx_mapping"):
layer_idx = cache_params.mamba_layer_idx_mapping[layer_idx]

batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
Expand All @@ -6599,7 +6605,7 @@ def segment_sum(input_tensor):
# 1. Convolution sequence transformation
# 1.1 Convolution sequence transformation for decoding step
if cache_params is not None:
conv_state_dec = cache_params.conv_states[self.layer_idx]
conv_state_dec = cache_params.conv_states[layer_idx]
conv_state_dec = torch.roll(conv_state_dec, shifts=-1, dims=-1)
conv_state_dec[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states

Expand All @@ -6621,11 +6627,11 @@ def segment_sum(input_tensor):
] # [batch, intermediate_size, seq_len]
if attention_mask is not None:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states_prefill = (hidden_states_prefill * attention_mask[:, :, None]).to(dtype)
hidden_states_prefill = (hidden_states_prefill * attention_mask[:, :seq_len, None]).to(dtype)

# Compute final conv state and set into the cache
conv_state = conv_state_prefill * (1.0 - is_decoding) + conv_state_dec * is_decoding
cache_params.conv_states[self.layer_idx].copy_(conv_state)
cache_params.conv_states[layer_idx].copy_(conv_state)
else:
hidden_states_prefill = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
hidden_states_dec = hidden_states_prefill[:, :1]
Expand Down Expand Up @@ -6677,7 +6683,7 @@ def segment_sum(input_tensor):
dBx = dB * hidden_states_dec[..., None]

# State calculation
new_ssm_state_dec = cache_params.ssm_states[self.layer_idx] * dA + dBx
new_ssm_state_dec = cache_params.ssm_states[layer_idx] * dA + dBx

# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
Expand Down Expand Up @@ -6795,7 +6801,7 @@ def segment_sum(input_tensor):
ssm_state = new_ssm_state_prefill * (1.0 - is_decoding) + new_ssm_state_dec * is_decoding

# Set final ssm state into the cache
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
cache_params.ssm_states[layer_idx].copy_(ssm_state)
else:
y = y_prefill

Expand Down Expand Up @@ -6954,3 +6960,181 @@ def __exit__(self, exc_type, exc_value, traceback):
else:
continue
mamba_layer.forward = mamba_layer._orig_forward


def granite_moe_hybrid_parallel_experts(self, inputs, expert_size):
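# Trace-friendly forward for the per-expert linear projections: the data-dependent split of tokens
# per expert is replaced with narrow() at cumulative offsets so the export can be traced.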
def dynamic_split_traceable(x, expert_size):
sizes = torch.tensor(expert_size)
indices = torch.cumsum(sizes, dim=0)
starts = torch.cat([torch.zeros(1, device=sizes.device, dtype=indices.dtype), indices[:-1]])

chunks = []
for i in range(sizes.size(0)):
start = starts[i]
end = indices[i]
chunks.append(x.narrow(0, start, end - start))
return chunks

input_list = dynamic_split_traceable(inputs, expert_size)
output_list = []
for i in range(self.num_experts):
output_list.append(F.linear(input_list[i], self.weight[i]))
results = torch.cat(output_list, dim=0)
return results


class GraniteMoeHybridModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: "PreTrainedModel",
model_kwargs: Optional[Dict[str, Any]] = None,
):
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import HybridMambaAttentionDynamicCache

super().__init__(config, model, model_kwargs)

class GraniteMoeHybridDynamicCacheWrap(HybridMambaAttentionDynamicCache):
def __init__(self, config, batch_size: int, conv_states, ssm_states, key_cache, value_cache):
# Call parent constructor with all required arguments
super().__init__(config=config, batch_size=batch_size)
self.conv_states = conv_states
self.ssm_states = ssm_states
self.key_cache = key_cache
self.value_cache = value_cache
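# map global decoder layer indices to positions inside the per-type cache lists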
self.attention_layer_idx_mapping = {}
self.mamba_layer_idx_mapping = {}
attention_layer_idx = 0
mamba_layer_idx = 0
for i in range(config.num_hidden_layers):
if self.layers_block_type[i] == "attention":
self.attention_layer_idx_mapping[i] = attention_layer_idx
attention_layer_idx += 1
elif self.layers_block_type[i] == "mamba":
self.mamba_layer_idx_mapping[i] = mamba_layer_idx
mamba_layer_idx += 1

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# map layer_idx to key_cache (value_cache) idx
layer_idx = self.attention_layer_idx_mapping[layer_idx]
# Update the cache
if self.key_cache[layer_idx].shape[-1] == 0:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
layer_idx = self.attention_layer_idx_mapping[layer_idx]
return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# fall back to the first attention layer if layer_idx does not correspond to a KV-cache layer
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
layer_idx = self.attention_layer_idx_mapping[layer_idx]
return self.key_cache[layer_idx].shape[-2]

# The patch is needed to expose the KV cache, conv states, and SSM states as explicit model inputs and outputs.
def patched_forward(
input_ids,
attention_mask=None,
cache_params=None,
):
num_mamba_layers = self.real_config._config.layer_types.count("mamba")
num_attention_layers = self.real_config._config.layer_types.count("attention")
use_cache = False
wrapped_cache_params = None
if cache_params is not None:
use_cache = True
conv_states = []
ssm_states = []
key_cache = []
value_cache = []

# decouple ssm_states, conv_states, keys and values from cache_params
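# flat layout: [conv_0, ssm_0, ..., conv_{M-1}, ssm_{M-1}, key_0, value_0, ..., key_{A-1}, value_{A-1}]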
batch_size = cache_params[0].size(0)
for idx in range(num_mamba_layers):
conv_states.append(cache_params[2 * idx])
ssm_states.append(cache_params[2 * idx + 1])

for idx in range(num_attention_layers):
key_cache.append(cache_params[2 * num_mamba_layers + 2 * idx])
value_cache.append(cache_params[2 * num_mamba_layers + 2 * idx + 1])

wrapped_cache_params = GraniteMoeHybridDynamicCacheWrap(
self.real_config._config, batch_size, conv_states, ssm_states, key_cache, value_cache
)

causal_lm_output = self.model_orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=wrapped_cache_params,
use_cache=use_cache,
)
outputs = {
"logits": causal_lm_output.logits,
}

if use_cache:
past_key_values = causal_lm_output.past_key_values
# unwrap GraniteMoeHybridDynamicCacheWrap object
present_key_values = []
for idx in range(num_mamba_layers):
present_key_values.append(past_key_values.conv_states[idx])
present_key_values.append(past_key_values.ssm_states[idx])

for idx in range(num_attention_layers):
present_key_values.append(past_key_values.key_cache[idx])
present_key_values.append(past_key_values.value_cache[idx])

outputs["present_key_values"] = present_key_values

return outputs

self.patched_forward = patched_forward
self.model_orig_forward = self.orig_forward
self.orig_forward = patched_forward

def __enter__(self):
def patch_sparse_moe(sparse_moe_layer):
sparse_moe_layer._orig_forward = sparse_moe_layer.forward
sparse_moe_layer.forward = types.MethodType(granite_moe_hybrid_parallel_experts, sparse_moe_layer)

super().__enter__()
setattr(self._model, self.orig_forward_name, self.patched_forward)

for idx, layer in enumerate(self._model.model.layers):
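# every layer gets the traceable expert projections; mamba layers additionally get the patched mixer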
patch_sparse_moe(layer.block_sparse_moe.input_linear)
patch_sparse_moe(layer.block_sparse_moe.output_linear)
if self.real_config._config.layers_block_type[idx] == "mamba":
mamba_layer = layer.mamba
else:
continue
mamba_layer._orig_forward = mamba_layer.forward
mamba_layer.forward = types.MethodType(zamba2_mamba_mixer, mamba_layer)

def __exit__(self, exc_type, exc_value, traceback):
def unpatch_sparse_moe(sparse_moe_layer):
sparse_moe_layer.forward = sparse_moe_layer._orig_forward

super().__exit__(exc_type, exc_value, traceback)
setattr(self._model, self.orig_forward_name, self.model_orig_forward)
for idx, layer in enumerate(self._model.model.layers):
unpatch_sparse_moe(layer.block_sparse_moe.input_linear)
unpatch_sparse_moe(layer.block_sparse_moe.output_linear)
if self.real_config._config.layers_block_type[idx] == "mamba":
mamba_layer = layer.mamba
else:
continue
mamba_layer.forward = mamba_layer._orig_forward
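Together with add_past_key_values in the exporter config above, the stateless export ends up with flattened per-layer cache tensors. A rough sketch of the resulting input and output names (the layer counts below are made up for illustration):

# hypothetical model with 2 mamba layers and 2 attention layers
num_mamba, num_attention = 2, 2
past_inputs = []
for i in range(num_mamba):
    past_inputs += [f"cache_params.past.conv.{i}", f"cache_params.past.ssm.{i}"]
for i in range(num_attention):
    past_inputs += [f"cache_params.past.key.{i}", f"cache_params.past.value.{i}"]
# outputs follow the same ordering under the "present" prefix
present_outputs = [name.replace(".past.", ".present.") for name in past_inputs]
print(past_inputs)
print(present_outputs)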
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
Expand Up @@ -237,7 +237,7 @@ def get_submodels(model):
"minicpmo",
]

SSM_MODELS = ["mamba", "falcon_mamba", "zamba2"]
SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "granitemoehybrid"]


def save_config(config, save_dir):
4 changes: 2 additions & 2 deletions tests/openvino/test_decoder.py
Expand Up @@ -374,8 +374,8 @@ def test_compare_to_transformers(self, model_arch):
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@pytest.mark.run_slow
@slow
#@pytest.mark.run_slow
#@slow
def test_pipeline(self, model_arch):
set_seed(SEED)
model_kwargs = {}