diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..99afd527a8 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -512,6 +512,15 @@ from keras_hub.src.models.qwen3.qwen3_tokenizer import ( Qwen3Tokenizer as Qwen3Tokenizer, ) +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import ( + Qwen3MoeBackbone as Qwen3MoeBackbone, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm import ( + Qwen3MoeCausalLM as Qwen3MoeCausalLM, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import ( + Qwen3MoeCausalLMPreprocessor as Qwen3MoeCausalLMPreprocessor, +) from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( QwenMoeBackbone as QwenMoeBackbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 5bf0186287..b155d0e6e1 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -78,6 +78,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import ( + Qwen3MoeTokenizer as Qwen3MoeTokenizer, +) from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( QwenMoeTokenizer as QwenMoeTokenizer, ) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py new file mode 100644 index 0000000000..a5442e8da0 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -0,0 +1,371 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import fused_attention_op_available + + +class Qwen3MoeAttention(keras.layers.Layer): + """A multi-head attention layer for Qwen3Moe models + This attention implementation supports grouped-query attention (GQA) where + the number of key-value heads can be less than the number of query heads. + + Args: + num_query_heads: int. Number of query heads. + num_key_value_heads: int. Number of key/value heads (for GQA). + head_dim: int. The dimension of each attention head. + rope_max_wavelength: int. Maximum wavelength for RoPE (Rotary Position + Embedding). + rope_scaling_factor: float. Scaling factor for RoPE, used for extending + context length. + kernel_initializer: Initializer for the kernel weights. + dropout: float. Dropout rate for attention weights. + layer_norm_epsilon: float. The epsilon value for layer normalization. + sliding_window_size: int. Size of the sliding window for attention. + **kwargs: Additional keyword arguments to pass to the Layer. 
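A minimal usage sketch (shapes are illustrative, not a preset configuration): with 8 query heads and 2 key/value heads, every 4 query heads share one key/value head.

```python
import numpy as np

from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import (
    Qwen3MoeAttention,
)

attention = Qwen3MoeAttention(
    num_query_heads=8,
    num_key_value_heads=2,
    head_dim=16,
)
# The hidden size (64 here) is inferred from the input's last dimension.
hidden_states = np.random.uniform(size=(2, 10, 64)).astype("float32")
output = attention(hidden_states)  # -> shape (2, 10, 64)
```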
+ """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + head_dim=None, + rope_max_wavelength=10000, + rope_scaling_factor=1, + kernel_initializer="glorot_uniform", + dropout=0.0, + layer_norm_epsilon=1e-6, + sliding_window_size=None, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.dropout = dropout + + self.layer_norm_epsilon = layer_norm_epsilon + + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + self.rope_scaling_factor = rope_scaling_factor + self.sliding_window_size = sliding_window_size + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + hidden_dim = inputs_shape[-1] + if not self.head_dim: + self.head_dim = hidden_dim // self.num_query_heads + + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self._query_dense.build(inputs_shape) + + self._query_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + head_dim=self.head_dim, + name="query_dense_layernorm", + ) + self._query_dense_layer_norm.build(inputs_shape) + + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + self.head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self._key_dense.build(inputs_shape) + + self._key_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + head_dim=self.head_dim, + name="key_dense_layernorm", + ) + self._key_dense_layer_norm.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + self.head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self._value_dense.build(inputs_shape) + + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self._output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + dtype=self.dtype_policy, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Applies attention mechanism to the input hidden states. + + Args: + hidden_states: Input tensor of shape [batch_size, seq_length, + hidden_size]. 
+ attention_mask: Mask tensor of shape [batch_size, seq_length, + seq_length]. + cache: Optional cached key and value tensors. + cache_update_index: Index at which to update the cache. + training: Boolean indicating whether in training mode. + + Returns: + attention_output: Output tensor after applying attention. + cache: Updated cache tensors (if cache is provided). + """ + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + + query = self._query_dense(hidden_states) + query = self._query_dense_layer_norm(query) + + # Compute RoPE for queries + query = self.rotary_embedding_layer(query, start_index=start_index) + + def _compute_key_value(x): + key = self._key_dense(x) + key = self._key_dense_layer_norm(key) + key = self.rotary_embedding_layer(key, start_index=start_index) + + value = self._value_dense(x) + + return key, value + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key, value = _compute_key_value(hidden_states) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, + key, + value, + attention_mask, + cache_update_index=cache_update_index, + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + """Applies softmax with optional masking. + + Args: + attention_scores: Attention score tensor. + attention_mask: Optional mask tensor. + + Returns: + Masked softmax attention weights. + """ + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention( + self, query, key, value, attention_mask=None, cache_update_index=None + ): + """Computes attention using query, key, and value tensors. + Uses Flash Attention when available for better performance. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + attention_mask: Optional mask tensor. + cache_update_index: Index for sliding window computation. + + Returns: + attention_output: Output tensor after applying attention. + """ + if fused_attention_op_available(): + # Use `dot_product_attention` with Flash Attention support if + # available. 
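            # The fused op expects a boolean mask that broadcasts over the
            # head axis, so the [batch, query_len, key_len] mask below gets
            # an explicit head dimension and a bool cast before the call.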
+ if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + if self.sliding_window_size: + attention_mask = self._mask_sliding_window( + attention_mask, + cache_update_index=cache_update_index + if cache_update_index is not None + else 0, + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + def _mask_sliding_window( + self, + attention_mask, + cache_update_index=0, + ): + """Creates and combines a sliding window mask with the attention mask. + + Args: + attention_mask: Original attention mask. + cache_update_index: Starting index for the sliding window. + + Returns: + Combined attention mask with sliding window constraints. + """ + _, query_len, key_len = ops.shape(attention_mask) + # Compute the sliding window for square attention. + all_ones = ops.ones((key_len, key_len), "bool") + if keras.config.backend() == "tensorflow": + # TODO: trui/tril has issues with dynamic shape on the tensorflow + # backend. We should fix, but use `band_part` for now. + import tensorflow as tf + + band_size = ops.minimum(key_len, self.sliding_window_size - 1) + band_size = ops.cast(band_size, "int32") + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) + else: + sliding_mask = ops.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * ops.tril(all_ones, self.sliding_window_size - 1) + # Slice the window for short queries during generation. 
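        # Worked example (illustrative values): with `sliding_window_size=3`
        # and `key_len=5`, the square mask has ones within two positions of
        # the diagonal. For a single-token generation step with
        # `cache_update_index=4`, the slice below keeps only the row
        # [0, 0, 1, 1, 1], so the current token attends to itself and its
        # two predecessors.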
+ start = (cache_update_index, 0) + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) + sliding_mask = ops.expand_dims(sliding_mask, 0) + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py new file mode 100644 index 0000000000..baeddd7673 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py @@ -0,0 +1,365 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import ( + Qwen3MoeTransformerDecoder, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm + + +def _qwen3_moe_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export( + "keras_hub.models.Qwen3MoeBackbone", +) +class Qwen3MoeBackbone(Backbone): + """Qwen3 MoE core network with hyperparameters. + + This backbone implements the base Transformer network for the Qwen MoE + model. It includes embedding lookups and transformer layers with a Mixture + of Experts (MoE) architecture, where each layer uses a sparse set of experts + for efficient computation. This backbone outputs the final hidden states for + each token, not generative predictions over the vocabulary space. For higher + -level object for text generation, see `keras_hub.models.Qwen3MoeCausalLM`. + + The default constructor gives a fully customizable, randomly initialized + Qwen MoE model with any number of layers, heads, and embedding dimensions. + To load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end of + each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + the feedforward network for each transformer. + moe_intermediate_dim: int. The intermediate dimension for each expert + in the MoE feedforward network. + num_experts: int. The number of experts in each MoE layer. + top_k: int. The number of top experts to select for each token in the + MoE layer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the transformer encoder. + sliding_window_size: int. Size of the sliding local window. Defaults to + 4096. + max_sequence_length: int. 
The maximum sequence length supported by the + model. Defaults to 4096. + dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use for + the model's computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Example: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Qwen MoE decoder. + model = keras_hub.models.Qwen3MoeBackbone.from_preset("qwen3_moe_a2_7b") + model(input_data) + + # Randomly initialized Qwen MoE decoder with custom config. + model = keras_hub.models.Qwen3MoeBackbone( + vocabulary_size=151936, + num_layers=28, + num_query_heads=16, + num_key_value_heads=8, + hidden_dim=2048, + intermediate_dim=4096, + moe_intermediate_dim=128, + num_experts=60, + top_k=4, + head_dim=128, + max_sequence_length=4096, + ) + model(input_data) + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + moe_intermediate_dim, + num_experts, + head_dim=None, + top_k=4, + norm_top_k_prob=False, + decoder_sparse_step=1, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, + dtype=None, + tie_word_embeddings=False, + sliding_window_size=32768, + router_aux_loss_coefficient=0.001, + mlp_only_layers=None, + training=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=_qwen3_moe_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + + if not mlp_only_layers: + mlp_only_layers = [] + + self.transformer_layers = [] + for i in range(num_layers): + is_sparse_mlp = ( + (i not in mlp_only_layers) + and num_experts > 0 + and (i + 1) % decoder_sparse_step == 0 + ) + layer = Qwen3MoeTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + moe_intermediate_dim=moe_intermediate_dim, + head_dim=head_dim, + num_experts=num_experts, + top_k=top_k, + norm_top_k_prob=norm_top_k_prob, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_qwen3_moe_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + sliding_window_size=sliding_window_size, + router_aux_loss_coefficient=router_aux_loss_coefficient, + is_sparse_mlp=is_sparse_mlp, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = Qwen3MoeLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer( + x, decoder_padding_mask=padding_mask_input, training=training + ) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + 
self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.moe_intermediate_dim = moe_intermediate_dim + self.head_dim = head_dim + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_heads = num_key_value_heads + self.rope_scaling_factor = rope_scaling_factor + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.tie_word_embeddings = tie_word_embeddings + self.sliding_window_size = sliding_window_size + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.decoder_sparse_step = decoder_sparse_step + self.mlp_only_layers = mlp_only_layers + self.router_aux_loss_coefficient = router_aux_loss_coefficient + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "head_dim": self.head_dim, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "num_key_value_heads": self.num_key_value_heads, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "tie_word_embeddings": self.tie_word_embeddings, + "sliding_window_size": self.sliding_window_size, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "decoder_sparse_step": self.decoder_sparse_step, + "mlp_only_layers": self.mlp_only_layers, + "router_aux_loss_coefficient": self.router_aux_loss_coefficient, + } + ) + return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the Qwen3Moe + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model + # parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) + layout_map = Qwen3MoeBackbone.get_layout_map( + mesh, + model_parallel_dim_name="model", + ) + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + with distribution.scope(): + qwen3_moe_model = keras_hub.models.Qwen3MoeBackbone.from_preset() + ``` + + To see how the layout map was applied, load the model then run + (for one decoder block): + ``` + embedding_layer = qwen3_moe_model.backbone.get_layer("token_embedding") + decoder_block_1 = qwen3_moe_model.backbone.get_layer( + 'transformer_layer_0' + ) + for variable in embedding_layer.weights + decoder_block_1.weights: + print( + f'{variable.path:<58} {str(variable.shape):<16} ' + f'{str(variable.value.sharding.spec)}' + ) + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. 
+ """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel + # (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel + # (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kerne (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected " + f"`keras.distribution.Device`, got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py new file mode 100644 index 0000000000..cdfd5440e9 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_backbone_test.py @@ -0,0 +1,69 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3MoeBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 20, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 16, + "intermediate_dim": 32, + "head_dim": 2, + "moe_intermediate_dim": 16, + "num_experts": 4, + "top_k": 2, + "norm_top_k_prob": True, + "decoder_sparse_step": 1, + "layer_norm_epsilon": 1e-6, + "rope_max_wavelength": 10000, + "rope_scaling_factor": 1.0, + "dropout": 0.0, + "sliding_window_size": 4096, + "router_aux_loss_coefficient": 0.01, + "tie_word_embeddings": False, + "mlp_only_layers": [], + "dtype": "float32", # Explicitly set dtype to avoid mixed precision + } + self.input_data = { + "token_ids": ops.ones((2, 7), dtype="int32"), + "padding_mask": ops.ones((2, 7), 
dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen3MoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 16), + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3MoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = Qwen3MoeBackbone(**self.init_kwargs) + expected_params = 7768 + self.assertEqual(model.count_params(), expected_params) + expected_layers = 6 + self.assertEqual(len(model.layers), expected_layers) + + def test_auxiliary_loss(self): + model = Qwen3MoeBackbone(**self.init_kwargs) + _ = model(self.input_data, training=True) + self.assertTrue( + len(model.losses) > 0, "Auxiliary losses should be present" + ) + for loss in model.losses: + self.assertGreater(loss, 0.0, "Auxiliary loss should be positive") diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py new file mode 100644 index 0000000000..198e3af697 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py @@ -0,0 +1,357 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import ( + Qwen3MoeCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export( + "keras_hub.models.Qwen3MoeCausalLM", +) +class Qwen3MoeCausalLM(CausalLM): + """An end-to-end Qwen3 MoE model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on plain + text input, or to autoregressively generate plain text similar to the data + used for training. This task can be used for pre-training or fine-tuning a + Qwen3 MoE model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. + By default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()`, and `generate()`. This is done by + default when creating the model with `from_preset()`. + + The Qwen3 MoE architecture leverages a Mixture of Experts (MoE) design, + where each transformer layer uses a sparse set of experts to process tokens + efficiently, making it suitable for large-scale language tasks with + optimized computational resources. + + Args: + backbone: A `keras_hub.models.Qwen3MoeBackbone` instance. + preprocessor: A `keras_hub.models.Qwen3MoeCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset( + "qwen3_moe_3b_en" + ) + qwen3_moe_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. 
+ qwen3_moe_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset( + "qwen3_moe_3b_en" + ) + qwen3_moe_lm.compile(sampler="top_k") + qwen3_moe_lm.generate("I want to say", max_length=30) + + qwen3_moe_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2)) + qwen3_moe_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for " Qwen3 is". + "token_ids": np.array([[2, 12345, 678, 0, 0, 0, 0]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + } + + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset( + "qwen3_moe_a2_7b", + preprocessor=None, + ) + qwen3_moe_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset( + "qwen3_moe_3b_en" + ) + qwen3_moe_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` with LoRA fine-tuning enabled. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset( + "qwen3_moe_3b_en" + ) + qwen3_moe_lm.backbone.enable_lora(rank=4) + qwen3_moe_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + # Token ids for " Qwen3 is a language model" + "token_ids": np.array([[2, 12345, 678, 543, 9876, 1, 0, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2), + } + y = np.array([[12345, 678, 543, 9876, 1, 0, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 0, 0, 0]] * 2) + + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset( + "qwen3_moe_a2_7b", + preprocessor=None, + ) + qwen3_moe_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. + ```python + tokenizer = keras_hub.models.Qwen3MoeTokenizer( + proto="qwen3_moe_vocab.spm", + ) + preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_hub.models.Qwen3MoeBackbone( + vocabulary_size=151936, + num_layers=28, + num_query_heads=16, + num_key_value_heads=8, + hidden_dim=2048, + intermediate_dim=4096, + moe_intermediate_dim=128, + num_experts=60, + top_k=4, + max_sequence_length=4096, + ) + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + qwen3_moe_lm.fit(x=features, batch_size=2) + ``` + """ + + backbone_cls = Qwen3MoeBackbone + preprocessor_cls = Qwen3MoeCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `Qwen3MoeCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. 
Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.head_dim + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. 
+ if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py new file mode 100644 index 0000000000..f6f52f1e79 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import Qwen3MoeTokenizer + + +@keras_hub_export( + "keras_hub.models.Qwen3MoeCausalLMPreprocessor", +) +class Qwen3MoeCausalLMPreprocessor(CausalLMPreprocessor): + backbone_cls = Qwen3MoeBackbone + tokenizer_cls = Qwen3MoeTokenizer diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..180c5f64ed --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py @@ -0,0 +1,68 @@ +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import ( + Qwen3MoeCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import Qwen3MoeTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class 
Qwen3MoeCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|im_end|>", "<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = Qwen3MoeTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=Qwen3MoeCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 4, 2, 5, 6, 7, 7]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[3, 4, 2, 5, 6, 7, 7, 7]], + [[1, 1, 1, 1, 1, 0, 0, 0]], + ), + ) + + def test_with_start_end_token(self): + input_data = ["airplane at airport"] * 4 + preprocessor = Qwen3MoeCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=True, + add_end_token=True, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 7, 7]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4) + self.assertAllEqual(y, [[3, 4, 2, 5, 6, 7, 7, 7]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "airplane at airport" + preprocessor = Qwen3MoeCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 7, 7, 7]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 3, 4, 2, 5, 7, 7, 7], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + preprocessor = Qwen3MoeCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py new file mode 100644 index 0000000000..d342c1e165 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py @@ -0,0 +1,130 @@ +import os +from unittest.mock import patch + +os.environ["KERAS_BACKEND"] = "jax" + +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm import Qwen3MoeCausalLM +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import ( + Qwen3MoeCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import Qwen3MoeTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3MoeCausalLMTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab += ["<|im_end|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.preprocessor = Qwen3MoeCausalLMPreprocessor( + Qwen3MoeTokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=7, + ) + self.backbone = 
Qwen3MoeBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + moe_intermediate_dim=4, + head_dim=2, + # shared_expert_intermediate_dim=16, + num_experts=4, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ([" airplane at airport", " airplane at airport"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=Qwen3MoeCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 7, 8), + ) + + def test_generate(self): + causal_lm = Qwen3MoeCausalLM(**self.init_kwargs) + # String input. + prompt = " airplane at airport" + output = causal_lm.generate(" airplane at airport") + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_generate_strip_prompt(self): + causal_lm = Qwen3MoeCausalLM(**self.init_kwargs) + prompt = " airplane at airport" + output = causal_lm.generate(prompt, strip_prompt=True) + self.assertFalse(output.startswith(prompt)) + + def test_early_stopping(self): + causal_lm = Qwen3MoeCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = [" airplane at airport", " airplane"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = Qwen3MoeCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate(" airplane at airport") + first_fn = causal_lm.generate_function + causal_lm.generate(" airplane at airport") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
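        # `compile()` resets the cached `generate_function` to `None`; it is
        # rebuilt lazily on the next `generate()` call.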
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3MoeCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen3MoeCausalLM.presets: + self.run_preset_test( + cls=Qwen3MoeCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py new file mode 100644 index 0000000000..fd4e59193a --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py @@ -0,0 +1,672 @@ +import keras +from keras import ops + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import Qwen3MoeAttention +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.utils.keras_utils import clone_initializer + + +def compute_load_balancing_loss( + router_logits, num_experts, top_k, attention_mask=None +): + """ + Compute the load balancing auxiliary loss for a single MoE layer. + + Args: + router_logits: Tensor of shape (batch_size * seq_len, num_experts). + num_experts: Integer, total number of experts. + top_k: Integer, number of experts to select per token. + attention_mask: Tensor of shape (batch_size, seq_len, seq_len), + optional mask for padding. + + Returns: + Scalar tensor representing the auxiliary loss. + """ + # Compute routing probabilities + routing_weights = ops.softmax( + router_logits, axis=-1 + ) # Shape: (batch_size * seq_len, num_experts) + + # Get top-k experts + _, selected_experts = ops.top_k( + routing_weights, k=top_k + ) # Shape: (batch_size * seq_len, top_k) + + # Create one-hot encoding for selected experts + expert_mask = ops.one_hot( + selected_experts, num_experts + ) # Shape: (batch_size * seq_len, top_k, num_experts) + + if attention_mask is not None: + # Convert attention mask to (batch_size, seq_len) + batch_size, seq_len, _ = ops.shape(attention_mask) + flat_mask = ops.any(attention_mask, axis=-1) + flat_mask = ops.reshape( + flat_mask, (-1,) + ) # Shape: (batch_size * seq_len,) + # Expand mask for broadcasting + expert_attention_mask = ops.expand_dims( + flat_mask, axis=-1 + ) # Shape: (batch_size * seq_len, 1) + expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32") + + # Compute masked means + tokens_per_expert = ops.sum( + expert_mask * expert_attention_mask[:, None, :], axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.sum( + routing_weights * expert_attention_mask, axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask, axis=0), 1e-9 + ) # Shape: (num_experts,) + else: + # Unmasked means + tokens_per_expert = ops.mean( + expert_mask, axis=0 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.mean( + routing_weights, axis=0 + ) # Shape: (num_experts,) + + # Average over top_k dimension if necessary + tokens_per_expert = ops.mean( + tokens_per_expert, axis=0 + ) # Shape: (num_experts,) + + # Compute the loss + overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert) + return overall_loss * num_experts + + +class Qwen3MoeMLP(keras.layers.Layer): + 
"""A feedforward network layer for a Transformer model. + + This layer implements the gated linear unit (GLU) variant of a + feedforward network, which is a common setup in modern Transformers. + It consists of three dense layers: a gate layer, an intermediate layer, + and an output layer. The output is computed as + `output_dense(activation(gate_dense(x)) * intermediate_dense(x))`. + + Args: + intermediate_dim (int): The size of the intermediate (hidden) layer. + hidden_dim (int): The size of the input and output layers. + activation_fn (str, optional): The activation function to use. + Defaults to "silu". + layer_norm_epsilon (float, optional): Epsilon for layer normalization. + Defaults to 1e-6. + kernel_initializer (str, optional): The initializer for the kernel + weights. Defaults to "glorot_uniform". + """ + + def __init__( + self, + intermediate_dim, + hidden_dim, + activation_fn="silu", + layer_norm_epsilon=1e-6, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.activation_fn = activation_fn + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, decoder_sequence_shape): + # Feedforward layers. + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + self.activation = keras.activations.get(self.activation_fn) + self.built = True + + def call(self, x): + gate_output = self._feedforward_gate_dense(x) + + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. + # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + + x = self._feedforward_intermediate_dense(x) + + return self._feedforward_output_dense(ops.multiply(x, gate_output)) + + +class Qwen3MoeExperts(keras.layers.Layer): + """A layer that contains a bank of feedforward experts for MoE. + + This layer implements the expert part of a Mixture-of-Experts (MoE) model. 
+ It creates a set of 'expert' feedforward networks that are computed in a + batched manner for efficiency. The weights for all experts are stored in + a single tensor, and computations are performed using `einsum` to process + all experts simultaneously. + + Args: + num_experts (int): The total number of experts in the layer. + hidden_dim (int): The dimension of the input and output of each expert. + intermediate_dim (int): The intermediate dimension of each expert's + feedforward network. + activation_fn (str, optional): The activation function to use within + each expert. Defaults to "silu". + kernel_initializer (str, optional): The initializer for the kernel + weights. Defaults to "glorot_uniform". + """ + + def __init__( + self, + num_experts, + hidden_dim, + intermediate_dim, + activation_fn="silu", + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.activation = keras.activations.get(activation_fn) + self.kernel_initializer = kernel_initializer + + def build(self, _): + self._expert_feedforward_gate_dense = self.add_weight( + shape=( + self.num_experts, + self.hidden_dim, + 2 * self.intermediate_dim, + ), + initializer=self.kernel_initializer, + trainable=True, + dtype=self.variable_dtype, + name="expert_feedforward_gate_dense", + ) + + self._expert_feedforward_output_dense = self.add_weight( + shape=(self.num_experts, self.intermediate_dim, self.hidden_dim), + initializer=self.kernel_initializer, + trainable=True, + dtype=self.variable_dtype, + name="expert_feedforward_output_dense", + ) + + self.built = True + + def call(self, hidden_states): + gate_up = ops.einsum( + "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense + ) + gate, up = ops.split(gate_up, 2, axis=-1) + hidden = up * self.activation(gate) + out = ops.einsum( + "eti,eih->eth", hidden, self._expert_feedforward_output_dense + ) + return out + + +class Qwen3SparseMoeBlock(keras.layers.Layer): + """A sparse Mixture-of-Experts (MoE) block. + + This block implements the full MoE logic. It contains a 'router' that + learns to send each input token to a subset of 'experts'. The final output + is a weighted combination of the outputs from the selected experts. + It also computes a load-balancing auxiliary loss during training to + encourage the router to distribute tokens evenly across all experts. + + Args: + hidden_dim (int): The dimension of the input and output tensors. + moe_intermediate_dim (int): The intermediate dimension of each expert. + num_experts (int): The total number of experts available. + top_k (int): The number of experts to route each token to. + norm_top_k_prob (bool): If True, normalize the probabilities of the + top-k experts. + kernel_initializer (str, optional): The initializer for kernel weights. + Defaults to "glorot_uniform". + layer_norm_epsilon (float, optional): Epsilon for layer normalization. + Defaults to 1e-6. + router_aux_loss_coefficient (float, optional): The coefficient for the + load-balancing auxiliary loss. Defaults to 0.01. 
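A minimal usage sketch (shapes are illustrative, not a preset configuration). The block returns the combined expert output together with the raw router logits; with `training=True` it also registers the scaled load-balancing loss via `add_loss()`:

```python
import numpy as np

# `Qwen3SparseMoeBlock` is defined in this module (qwen3_moe_decoder.py).
moe_block = Qwen3SparseMoeBlock(
    hidden_dim=8,
    moe_intermediate_dim=16,
    num_experts=4,
    top_k=2,
    norm_top_k_prob=True,
)
hidden_states = np.random.uniform(size=(2, 6, 8)).astype("float32")
outputs, router_logits = moe_block(hidden_states)
# outputs: (2, 6, 8); router_logits: (2 * 6, 4), one row per token.
```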
+ """ + + def __init__( + self, + hidden_dim, + moe_intermediate_dim, + num_experts, + top_k, + norm_top_k_prob, + kernel_initializer="glorot_uniform", + layer_norm_epsilon=1e-6, + router_aux_loss_coefficient=0.01, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = moe_intermediate_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + self.router_aux_loss_coefficient = router_aux_loss_coefficient + + def build(self, decoder_sequence_shape): + self._sparse_feedforward_gate_dense = keras.layers.Dense( + self.num_experts, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="sparse_feedforward_gate_dense", + dtype=self.dtype_policy, + ) + self._sparse_feedforward_gate_dense.build(decoder_sequence_shape) + + # NOTE: Experts are implemented as a single layer to enable efficient + # batched computation. Implementing each expert individually is + # currently avoided due to the lack of `ragged_dot` support in the + # Keras ops API, which would make individual implementations unstable + # and prone to bugs. + self.expert_bank = Qwen3MoeExperts( + num_experts=self.num_experts, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + kernel_initializer=self.kernel_initializer, + name="experts", + dtype=self.dtype_policy, + ) + self.expert_bank.build(decoder_sequence_shape) + + self.built = True + + def call(self, hidden_states, attention_mask=None, training=None): + batch_size, seq_len, _ = ops.shape(hidden_states) + hidden_states_flattened = ops.reshape( + hidden_states, (-1, self.hidden_dim) + ) + + router_logits = self._sparse_feedforward_gate_dense( + hidden_states_flattened + ) + router_probs = ops.softmax(router_logits, axis=-1) + + top_p, top_i = ops.top_k(router_probs, k=self.top_k) + if self.norm_top_k_prob: + top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True) + + one_hot = ops.one_hot(top_i, self.num_experts) + one_hot = ops.cast(one_hot, top_p.dtype) + routing_full = ops.sum(one_hot * top_p[..., None], axis=1) + routing_full = ops.transpose(routing_full, (1, 0)) + routing_full = ops.cast(routing_full, hidden_states_flattened.dtype) + + expert_out = self.expert_bank(hidden_states_flattened) + + weighted_out = expert_out * routing_full[:, :, None] + expert_contribution = ops.sum(weighted_out, axis=0) + + out = ops.reshape( + expert_contribution, (batch_size, seq_len, self.hidden_dim) + ) + + # Compute and add auxiliary loss during training + if training: + aux_loss = compute_load_balancing_loss( + router_logits=router_logits, + num_experts=self.num_experts, + top_k=self.top_k, + attention_mask=attention_mask, + ) + self.add_loss(self.router_aux_loss_coefficient * aux_loss) + + return out, router_logits + + +class Qwen3MoeTransformerDecoder(keras.layers.Layer): + """A Transformer decoder layer for the Qwen3 Moe backbone. + + This layer implements a Transformer decoder block that includes + self-attention with optional sliding window attention and a + Mixture-of-Experts (MoE) feed-forward network. + + Args: + intermediate_dim: Output dimension of the first dense layer in the + feed-forward network (for non-MoE layers). + num_query_heads: Number of query attention heads. + num_key_value_heads: Number of key/value attention heads (for GQA). + moe_intermediate_dim: The intermediate dimension for each expert in the + MoE layer. 
+ num_experts: The total number of experts in the MoE layer. + top_k: The number of experts to which each token is routed. + norm_top_k_prob: If True, normalize the top-k probabilities. + head_dim: The dimension of each attention head. If None, it is + inferred from other dimensions. + is_sparse_mlp: If True, uses a sparse MLP. + rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position + Embedding). + rope_scaling_factor: Scaling factor for RoPE, used for extending + context length. + activation: Activation function to use in the feed-forward network. + layer_norm_epsilon: Small float added to variance to avoid dividing + by zero in layer norm. + kernel_initializer: Initializer for the kernel weights. + dropout: Dropout rate for attention and hidden layers. + sliding_window_size: Size of the sliding window for attention when + enabled. + router_aux_loss_coefficient: The coefficient for the router's auxiliary + loss, used for load balancing. + **kwargs: Additional keyword arguments to pass to the Layer. + """ + + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + moe_intermediate_dim, + num_experts, + top_k, + norm_top_k_prob, + head_dim=None, + is_sparse_mlp=False, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + activation="silu", + layer_norm_epsilon=1e-6, + kernel_initializer="glorot_uniform", + dropout=0, + sliding_window_size=4096, + router_aux_loss_coefficient=0.001, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.dropout = dropout + self.sliding_window_size = sliding_window_size + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.moe_intermediate_dim = moe_intermediate_dim + self.head_dim = head_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.is_sparse_mlp = is_sparse_mlp + self.router_aux_loss_coefficient = router_aux_loss_coefficient + self.supports_masking = True + + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape + self.hidden_dim = decoder_sequence_shape[-1] + + # Self attention layer. + self._self_attention_layer = Qwen3MoeAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + rope_max_wavelength=self.rope_max_wavelength, + head_dim=self.head_dim, + rope_scaling_factor=self.rope_scaling_factor, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + sliding_window_size=self.sliding_window_size, + dtype=self.dtype_policy, + name="self_attention", + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._self_attention_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layernorm", + ) + + self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + + # Feedforward layers. 
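+        # When `is_sparse_mlp` is True, tokens are routed through the sparse
+        # Mixture-of-Experts block below; otherwise the layer falls back to
+        # the dense gated MLP used for non-sparse decoder layers.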
+ if self.is_sparse_mlp: + self.mlp = Qwen3SparseMoeBlock( + hidden_dim=self.hidden_dim, + moe_intermediate_dim=self.moe_intermediate_dim, + num_experts=self.num_experts, + top_k=self.top_k, + norm_top_k_prob=self.norm_top_k_prob, + router_aux_loss_coefficient=self.router_aux_loss_coefficient, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + ) + self.mlp.build(decoder_sequence_shape) + else: + self.mlp = Qwen3MoeMLP( + intermediate_dim=self.intermediate_dim, + hidden_dim=self.hidden_dim, + dtype=self.dtype_policy, + ) + self.mlp.build(decoder_sequence_shape) + + self._feedforward_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="feedforward_layernorm", + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + training=None, + ): + """Forward pass for the decoder layer. + + Args: + decoder_sequence: Input tensor of shape [batch_size, seq_length, + hidden_size]. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors for + self-attention. + self_attention_cache_update_index: Index at which to update the + cache. + training: Boolean indicating whether in training mode. + + Returns: + decoder_output: Output tensor after applying transformer decoder + block. + self_attention_cache: Updated cache tensors (if cache is provided). + """ + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + residual = decoder_sequence + + x = self._self_attention_layernorm(decoder_sequence) + + # Self attention block. + x = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + + if self_attention_cache is not None: + x, self_attention_cache = x + + x = self._self_attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self._feedforward_layernorm(x) + if isinstance(self.mlp, Qwen3SparseMoeBlock): + x = self.mlp( + x, training=training, attention_mask=self_attention_mask + ) + else: + x = self.mlp(x) + + if isinstance(x, tuple): + x, _ = x + + x = ops.cast(x, ops.dtype(residual)) + decoder_output = x + residual + + output = (decoder_output,) + + if self_attention_cache is not None: + output += (self_attention_cache,) + + return output[0] if len(output) == 1 else output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + """Computes the self-attention mask combining causal, padding and + attention masks. + + Args: + decoder_sequence: Input tensor. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors. + self_attention_cache_update_index: Index at which to update the + cache. + + Returns: + Combined attention mask tensor. 
+ """ + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + """Computes the output shape of the layer. + + Args: + decoder_sequence_shape: Shape of the decoder sequence input. + + Returns: + Output shape, which is the same as the input shape. + """ + return decoder_sequence_shape + + def get_config(self): + """Returns the config of the layer. + + Returns: + Dictionary containing the parameters used to initialize this layer. + """ + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "num_key_value_heads": self.num_key_value_heads, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "router_aux_loss_coefficient": self.router_aux_loss_coefficient, + "head_dim": self.head_dim, + "is_sparse_mlp": self.is_sparse_mlp, + "activation": keras.activations.serialize(self.activation), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + } + ) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py new file mode 100644 index 0000000000..0a5fc23b6e --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py @@ -0,0 +1,45 @@ +import keras +from keras import ops + + +class Qwen3MoeLayerNorm(keras.layers.Layer): + """A normalization layer for Qwen that implements RMS normalization. + + Args: + head_dim: int. The dimension of each attention head, used for per-head + normalization. Defaults to `None`. + epsilon: float. A small float added to variance to avoid dividing by + zero. Defaults to `1e-6`. 
+ """ + + def __init__(self, head_dim=None, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.head_dim = head_dim + self.epsilon = epsilon + + def build(self, input_shape): + if self.head_dim: + dim = self.head_dim + else: + dim = input_shape[-1] + + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + input_dtype = x.dtype + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x * self.scale, input_dtype) + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon, "head_dim": self.head_dim}) + return config diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py new file mode 100644 index 0000000000..8c62f12db5 --- /dev/null +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py @@ -0,0 +1,48 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + "keras_hub.tokenizers.Qwen3MoeTokenizer", +) +class Qwen3MoeTokenizer(BytePairTokenizer): + """Tokenizer for Qwen Moe model. + + This tokenizer implements byte-pair encoding (BPE) for Qwen models, + handling special tokens like BOS (beginning of sequence) and EOS (end of + sequence). + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + bos_token: Beginning of sequence token. Defaults to None. + eos_token: End of sequence token. Defaults to "<|endoftext|>". + misc_special_tokens: Set of additional special tokens. Defaults to + empty set. 
+ """ + + backbone_cls = Qwen3MoeBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Add EOS token + eos_token = "<|im_end|>" + self._add_special_token(eos_token, "end_token") + + pad_token = "<|endoftext|>" + self._add_special_token(pad_token, "pad_token") + + self.start_token_id = None + self.start_token = None + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_moe.py b/keras_hub/src/utils/transformers/convert_qwen3_moe.py new file mode 100644 index 0000000000..4fa80635ec --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_moe.py @@ -0,0 +1,216 @@ +import numpy as np + +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = Qwen3MoeBackbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "hidden_dim": transformers_config["hidden_size"], + "head_dim": transformers_config["head_dim"], + "num_layers": transformers_config["num_hidden_layers"], + "num_query_heads": transformers_config["num_attention_heads"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "moe_intermediate_dim": transformers_config["moe_intermediate_size"], + "num_experts": transformers_config["num_experts"], + "top_k": transformers_config["num_experts_per_tok"], + "norm_top_k_prob": transformers_config["norm_topk_prob"], + "decoder_sparse_step": transformers_config["decoder_sparse_step"], + "layer_norm_epsilon": transformers_config["rms_norm_eps"], + "rope_max_wavelength": transformers_config["rope_theta"], + "sliding_window_size": transformers_config["sliding_window"], + "router_aux_loss_coefficient": transformers_config[ + "router_aux_loss_coef" + ], + "tie_word_embeddings": transformers_config.get( + "tie_word_embeddings", False + ), + } + + +def convert_weights(backbone, loader, transformers_config): + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + # rearrange_pattern="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention layers + + ## Query + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale, + hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight", + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + 
keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale, + hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight", + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Output + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=transpose_and_reshape, + ) + + # MLP layers + if ( + (i not in backbone.mlp_only_layers) + and backbone.num_experts > 0 + and ((i + 1) % backbone.decoder_sparse_step == 0) + ): + # MoE layers + loader.port_weight( + keras_variable=decoder_layer.mlp._sparse_feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # Batched experts: gate_up_proj and down_proj + gate_up_proj_list = [] + down_proj_list = [] + for expert_idx in range(backbone.num_experts): + # Load gate_proj and up_proj for each expert + gate_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight" + ) + up_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight" + ) + # Transpose to (hidden_dim, intermediate_dim) + gate_proj = np.transpose(gate_proj, axes=(1, 0)) + up_proj = np.transpose(up_proj, axes=(1, 0)) + # Concatenate gate_proj and up_proj along the last dimension + gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1) + gate_up_proj_list.append(gate_up_proj) + + # Load down_proj for each expert + down_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_proj = np.transpose( + down_proj, axes=(1, 0) + ) # (intermediate_dim, hidden_dim) + down_proj_list.append(down_proj) + + # Stack the lists to create batched weights + gate_up_proj_batched = np.stack( + gate_up_proj_list, axis=0 + ) # (num_experts, hidden_dim, 2 * intermediate_dim) + down_proj_batched = np.stack( + down_proj_list, axis=0 + ) # (num_experts, intermediate_dim, hidden_dim) + + # Assign batched weights to expert_bank + decoder_layer.mlp.expert_bank._expert_feedforward_gate_dense.assign( + gate_up_proj_batched + ) + decoder_layer.mlp.expert_bank._expert_feedforward_output_dense.assign( + down_proj_batched + ) + else: + loader.port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final 
normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + merges = [" ".join(item) for item in merges] + + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_moe_test.py b/keras_hub/src/utils/transformers/convert_qwen3_moe_test.py new file mode 100644 index 0000000000..3af141ac67 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_moe_test.py @@ -0,0 +1,37 @@ +import keras +import pytest + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone +from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm import Qwen3MoeCausalLM +from keras_hub.src.tests.test_case import TestCase + + +# NOTE: This test is valid and should pass locally. It is skipped only on +# TensorFlow GPU CI because of ResourceExhaustedError (OOM). Revisit once +# TensorFlow GPU CI runs without hitting OOM. +@pytest.mark.skipif( + keras.backend.backend() == "tensorflow", + reason="TensorFlow GPU CI OOM (ResourceExhaustedError)", +) +class TestTask(TestCase): + @pytest.mark.extra_large + def test_convert_tiny_preset(self): + model = Qwen3MoeCausalLM.from_preset("hf://Qwen/Qwen3-30B-A3B") + prompt = "What is the capital of France?" 
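+        # Smoke test: a short generation exercises the converted tokenizer,
+        # backbone, and LM head end to end.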
+        model.generate([prompt], max_length=15)
+
+    @pytest.mark.extra_large
+    def test_class_detection(self):
+        preset_name = "hf://Qwen/Qwen3-30B-A3B"
+        model = CausalLM.from_preset(
+            preset_name,
+            load_weights=False,
+        )
+        self.assertIsInstance(model, Qwen3MoeCausalLM)
+        model = Backbone.from_preset(
+            preset_name,
+            load_weights=False,
+        )
+        self.assertIsInstance(model, Qwen3MoeBackbone)
diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py
index bfca6e7bc5..d808a943be 100644
--- a/keras_hub/src/utils/transformers/preset_loader.py
+++ b/keras_hub/src/utils/transformers/preset_loader.py
@@ -18,6 +18,7 @@
 from keras_hub.src.utils.transformers import convert_pali_gemma
 from keras_hub.src.utils.transformers import convert_qwen
 from keras_hub.src.utils.transformers import convert_qwen3
+from keras_hub.src.utils.transformers import convert_qwen3_moe
 from keras_hub.src.utils.transformers import convert_qwen_moe
 from keras_hub.src.utils.transformers import convert_t5gemma
 from keras_hub.src.utils.transformers import convert_vit
@@ -61,6 +62,8 @@ def __init__(self, preset, config):
             self.converter = convert_mixtral
         elif model_type == "qwen2_moe":
             self.converter = convert_qwen_moe
+        elif model_type == "qwen3_moe":
+            self.converter = convert_qwen3_moe
         elif model_type == "qwen3":
             self.converter = convert_qwen3
         elif model_type == "t5gemma":
diff --git a/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py
new file mode 100644
index 0000000000..6e2c846103
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py
@@ -0,0 +1,162 @@
+import os
+import traceback
+
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Hide any CUDA devices
+
+import numpy as np
+import torch
+from absl import app
+from absl import flags
+
+device = torch.device("cpu")
+# Force PyTorch to use CPU
+torch.set_default_device(device)
+
+from keras import ops  # noqa: E402
+from transformers import AutoModelForCausalLM  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+
+import keras_hub  # noqa: E402
+
+PRESET_MAP = {
+    "qwen3_moe_30b_a3b_en": "Qwen/Qwen3-30B-A3B",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
+)
+
+
+def test_model(
+    keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer
+):
+    # First, check that the parameter counts match
+    keras_hub_params = keras_hub_model.count_params()
+    hf_params = hf_model.num_parameters()
+    assert keras_hub_params == hf_params
+
+    # Test the outputs of both models
+    hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to(
+        device
+    )
+    hf_outputs = hf_model(**hf_inputs)
+    hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy()
+
+    keras_hub_preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor(
+        keras_hub_tokenizer
+    )
+    keras_hub_inputs = keras_hub_preprocessor(
+        ["What is Keras?"], sequence_length=5
+    )[0]
+    keras_hub_inputs = {k: v.to(device) for k, v in keras_hub_inputs.items()}
+
+    keras_hub_output = keras_hub_model(keras_hub_inputs)
+    keras_hub_logits = keras_hub_model.token_embedding(
+        keras_hub_output, reverse=True
+    )
+    keras_hub_logits = ops.convert_to_numpy(keras_hub_logits)
+
+    # High tolerance since bfloat16 is used as the default dtype for Qwen
+    try:
+        np.testing.assert_allclose(
+            keras_hub_logits, hf_output_logits, atol=1e-3
+        )
+    except
AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def validate_output( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer +): + input_str = "What is Keras?" + length = 32 + + # KerasHub + preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( + keras_hub_tokenizer + ) + qwen_moe_lm = keras_hub.models.Qwen3MoeCausalLM( + backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy" + ) + + keras_output = qwen_moe_lm.generate([input_str], max_length=length) + keras_output = keras_output[0] + print("🔶 KerasHub output:", keras_output) + + # Transformers + hf_inputs = hf_tokenizer([input_str], return_tensors="pt").to(device) + outputs = hf_model.generate( + **hf_inputs, + max_length=length, # Match KerasHub's max_length + do_sample=True, # Enable sampling (default in KerasHub for generate) + pad_token_id=hf_tokenizer.pad_token_id, + ) + print("HF Token outputs = ", outputs) + hf_generated_text = hf_tokenizer.batch_decode( + outputs, skip_special_tokens=True + )[0] + print("🔶 Huggingface output:", hf_generated_text) + + +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Load the Huggingface model === + hf_model = AutoModelForCausalLM.from_pretrained( + hf_preset, + device_map=device, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + + keras_hub_model = keras_hub.models.Qwen3MoeBackbone.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_tokenizer = keras_hub.tokenizers.Qwen3MoeTokenizer.from_preset( + f"hf://{hf_preset}" + ) + + print("\n-> Huggingface model and tokenizer loaded") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) + + # == Validate model.generate output == + validate_output( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer + ) + print("\n-> Tests passed!") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)
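For reference, a minimal usage sketch of the new Qwen3 MoE preset support, assembled from the calls exercised by the tests and the conversion script above (the `hf://Qwen/Qwen3-30B-A3B` handle is the same checkpoint used there; loading it requires substantial memory):

    import keras_hub

    # `from_preset` with an `hf://` handle runs the qwen3_moe converter
    # registered in preset_loader.py, building the tokenizer, preprocessor,
    # and backbone from the Hugging Face checkpoint.
    causal_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset(
        "hf://Qwen/Qwen3-30B-A3B"
    )
    causal_lm.generate(["What is the capital of France?"], max_length=15)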