Changes from all commits

56 commits
f0d3696
Added falcon model converter
mehtamansi29 Jan 9, 2025
21df61e
Added falcon model converter -1
mehtamansi29 Jan 10, 2025
bc4b4f7
Falcon converter changes
mehtamansi29 Apr 28, 2025
060e95c
Falcon converter changes_1
mehtamansi29 Apr 28, 2025
496f3e7
transformer config changes of falcon converter
mehtamansi29 Apr 29, 2025
9dd0e61
transformer config changes of falcon converter
mehtamansi29 Apr 29, 2025
b990401
transformer config changes of falcon converter_3
mehtamansi29 Jul 21, 2025
8f2284c
transformer config changes of falcon converter_4
mehtamansi29 Jul 21, 2025
3642f1e
transformer config changes of falcon converter_6
mehtamansi29 Jul 21, 2025
6da4ced
transformer config changes of falcon converter_7
mehtamansi29 Jul 21, 2025
a8ea36f
transformer config changes of falcon converter_8
mehtamansi29 Jul 21, 2025
cea948d
transformer config changes of falcon converter_9
mehtamansi29 Jul 21, 2025
60078c5
transformer config changes of falcon converter_11
mehtamansi29 Jul 24, 2025
d7a5c31
Merge remote-tracking branch 'upstream/master' into patch-1
mehtamansi29 Jul 25, 2025
c7d4a9c
intermediate_dim change
mehtamansi29 Jul 31, 2025
152c19e
intermediate_dim change_1
mehtamansi29 Jul 31, 2025
3bc83bd
backbone_config change
mehtamansi29 Jul 31, 2025
d3cbdec
transformer config intermediate_dim
mehtamansi29 Jul 31, 2025
50e6d06
attention layer weights changes
mehtamansi29 Jul 31, 2025
89bac89
attention layer indention change
mehtamansi29 Jul 31, 2025
164e6cc
transformers_config changes
mehtamansi29 Jul 31, 2025
7873b3c
transformer config changes
mehtamansi29 Jul 31, 2025
5f174d4
transformer config
mehtamansi29 Jul 31, 2025
559ee01
num_key_value_heads change
mehtamansi29 Jul 31, 2025
5047254
remove keyvalue head from transformer config
mehtamansi29 Jul 31, 2025
af2c647
intermediate_dim in transformer c
mehtamansi29 Jul 31, 2025
3aaa529
change head dim
mehtamansi29 Jul 31, 2025
13c04d7
hidden dim changes
mehtamansi29 Jul 31, 2025
8cc06a6
convert_falcon_changes
mehtamansi29 Aug 4, 2025
21e4473
attention layer change
mehtamansi29 Aug 4, 2025
6aa4244
attention layer changes_1
mehtamansi29 Aug 4, 2025
1ce3837
falcon converter changesa
mehtamansi29 Aug 4, 2025
496eeeb
preset_loader precommit run
mehtamansi29 Aug 4, 2025
9ccc46a
Merge branch 'keras-team:master' into patch-1
mehtamansi29 Aug 6, 2025
b64cd4c
backbone and casual_lm test
mehtamansi29 Aug 6, 2025
fba4aba
loading issue for falcon1b
mehtamansi29 Aug 19, 2025
f3c5041
loading issue for falcon1b_1
mehtamansi29 Aug 19, 2025
9b860c1
loading issue for falcon1b_1
mehtamansi29 Aug 19, 2025
41289d8
resolving conflict
mehtamansi29 Aug 19, 2025
2284520
convert_falcon file changes
mehtamansi29 Aug 19, 2025
2933774
convert_falcon chanes_1
mehtamansi29 Aug 19, 2025
b3ba59a
update for 7b parameters
mehtamansi29 Sep 1, 2025
f0fb361
7b parameters mismatch update
mehtamansi29 Sep 2, 2025
88a91a1
Merge remote-tracking branch 'upstream/master' into patch-1
mehtamansi29 Sep 2, 2025
6b26899
Resolve 7b paramter disperancies
mehtamansi29 Sep 2, 2025
e7b39bb
resolve 7b parameter disperancies_1
mehtamansi29 Sep 3, 2025
8e7520b
resolve 7b parameter disperancies_2
mehtamansi29 Sep 3, 2025
6d5ae8c
Revert "Resolve 7b paramter disperancies"
mehtamansi29 Sep 3, 2025
5948b6e
Revert "Resolve 7b paramter disperancies"
mehtamansi29 Sep 3, 2025
79951dd
falcon_transformer_decoder changes
mehtamansi29 Sep 3, 2025
a9eed7c
falcon_transformer_decoder changes_1
mehtamansi29 Sep 3, 2025
0243caf
changes for 7b mismatch parameters
mehtamansi29 Sep 9, 2025
06bd348
Chnages based on comments
mehtamansi29 Sep 11, 2025
a588b76
convert_falcon file changes
mehtamansi29 Sep 23, 2025
92aa32b
layernorm bias term changes
mehtamansi29 Sep 23, 2025
9345630
converter file change
mehtamansi29 Sep 23, 2025
17 changes: 12 additions & 5 deletions keras_hub/src/models/falcon/falcon_attention.py
@@ -9,11 +9,15 @@ def __init__(
self,
num_heads,
attention_dropout_rate,
num_kv_heads,
use_bias=True,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention_dropout_rate = attention_dropout_rate
self.num_kv_heads = num_kv_heads
self.use_bias = use_bias

def build(self, inputs_shape):
# Einsum variables:
@@ -28,31 +32,33 @@ def build(self, inputs_shape):

self.head_dim = hidden_dim // self.num_heads

bias_axes = "nh" if self.use_bias else None

# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

self.query_dense = keras.layers.EinsumDense(
equation="bqm,mnh->bqnh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
bias_axes=bias_axes,
dtype=self.dtype_policy,
name="query_dense",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mnh->bknh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
output_shape=(None, self.num_kv_heads, self.head_dim),
bias_axes=bias_axes,
dtype=self.dtype_policy,
name="key_dense",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mnh->bknh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
output_shape=(None, self.num_kv_heads, self.head_dim),
bias_axes=bias_axes,
dtype=self.dtype_policy,
name="value_dense",
)
@@ -67,6 +73,7 @@ def build(self, inputs_shape):
self.output_dense = keras.layers.Dense(
hidden_dim,
dtype=self.dtype_policy,
use_bias=self.use_bias,
name="output_dense",
)
self.output_dense.build(inputs_shape)
13 changes: 11 additions & 2 deletions keras_hub/src/models/falcon/falcon_backbone.py
@@ -44,8 +44,10 @@ class FalconBackbone(Backbone):
}

# Pretrained Falcon decoder.
# TODO: Update the preset.
model = keras_hub.models.FalconBackbone.from_preset("falcon_preset")
model = keras_hub.models.FalconBackbone.from_preset("falcon-7b-instruct")
model(input_data)

model = keras_hub.models.FalconBackbone.from_preset("falcon-rw-1b")
model(input_data)

# Randomly initialized Falcon decoder with a custom config.
@@ -70,13 +72,16 @@ def __init__(
num_layers,
num_attention_heads,
hidden_dim,
num_kv_heads,
intermediate_dim,
layer_norm_epsilon=1e-5,
attention_dropout_rate=0,
feedforward_dropout_rate=0,
dtype=None,
**kwargs,
):
use_bias = True if hidden_dim == 2048 else False
Collaborator:
If use_bias is solely based on hidden_dim, then there is no need for a use_bias argument.
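A minimal sketch (not part of this diff) of the refactor the comment suggests: derive the bias setting inside the layer from the hidden dimension instead of threading a use_bias argument through the constructors. The class and helper names below are hypothetical, and the 2048 threshold simply mirrors the rule used in FalconBackbone above.

import keras


def falcon_uses_bias(hidden_dim):
    # Illustrative rule mirroring the PR: only the 2048-wide
    # (falcon-rw-1b style) configuration carries biases.
    return hidden_dim == 2048


class HypotheticalFalconMLP(keras.layers.Layer):
    """Toy layer that infers its bias setting in build(), so no
    use_bias constructor argument is needed."""

    def __init__(self, intermediate_dim, **kwargs):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim

    def build(self, inputs_shape):
        hidden_dim = inputs_shape[-1]
        use_bias = falcon_uses_bias(hidden_dim)
        self.dense_h_to_4h = keras.layers.Dense(
            self.intermediate_dim,
            activation="gelu",
            use_bias=use_bias,
            name="dense_h_to_4h",
        )
        self.dense_h_to_4h.build(inputs_shape)
        self.dense_4h_to_h = keras.layers.Dense(
            hidden_dim, use_bias=use_bias, name="dense_4h_to_h"
        )
        self.dense_4h_to_h.build((None, None, self.intermediate_dim))
        super().build(inputs_shape)

    def call(self, x):
        return self.dense_4h_to_h(self.dense_h_to_4h(x))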


# === Layers ===
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
Expand All @@ -92,7 +97,9 @@ def __init__(
intermediate_dim=intermediate_dim,
attention_dropout_rate=attention_dropout_rate,
feedforward_dropout_rate=feedforward_dropout_rate,
num_kv_heads=num_kv_heads,
dtype=dtype,
use_bias=use_bias,
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)
@@ -134,6 +141,7 @@ def __init__(
self.intermediate_dim = intermediate_dim
self.attention_dropout_rate = attention_dropout_rate
self.feedforward_dropout_rate = feedforward_dropout_rate
self.num_kv_heads = num_kv_heads
self.layer_norm_epsilon = layer_norm_epsilon

def get_config(self):
@@ -146,6 +154,7 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"attention_dropout_rate": self.attention_dropout_rate,
"num_kv_heads": self.num_kv_heads,
"feedforward_dropout_rate": self.feedforward_dropout_rate,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
1 change: 1 addition & 0 deletions keras_hub/src/models/falcon/falcon_backbone_test.py
@@ -12,6 +12,7 @@ def setUp(self):
"num_layers": 2,
"num_attention_heads": 8,
"hidden_dim": 16,
"num_kv_heads": 1,
"intermediate_dim": 32,
}
self.input_data = {
1 change: 1 addition & 0 deletions keras_hub/src/models/falcon/falcon_causal_lm_test.py
@@ -31,6 +31,7 @@ def setUp(self):
vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
num_layers=2,
num_attention_heads=2,
num_kv_heads=1,
hidden_dim=4,
intermediate_dim=16,
)
14 changes: 12 additions & 2 deletions keras_hub/src/models/falcon/falcon_transformer_decoder.py
@@ -17,6 +17,8 @@ def __init__(
self,
num_attention_heads,
intermediate_dim,
num_kv_heads,
use_bias=False,
layer_norm_epsilon=1e-5,
attention_dropout_rate=0,
feedforward_dropout_rate=0,
@@ -28,11 +30,15 @@
self.layer_norm_epsilon = layer_norm_epsilon
self.attention_dropout_rate = attention_dropout_rate
self.feedforward_dropout_rate = feedforward_dropout_rate
self.num_kv_heads = num_kv_heads
self.use_bias = use_bias

def build(self, decoder_sequence_shape):
self.hidden_dim = decoder_sequence_shape[-1]
self.input_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
center=True if self.use_bias else False,
scale=True,
dtype=self.dtype_policy,
name="input_layernorm",
)
@@ -43,7 +49,9 @@ def build(self, decoder_sequence_shape):
self.attention_layer = FalconAttention(
num_heads=self.num_attention_heads,
attention_dropout_rate=self.attention_dropout_rate,
num_kv_heads=self.num_kv_heads,
dtype=self.dtype_policy,
use_bias=self.use_bias,
name="attention",
)
self.attention_layer.build(
@@ -58,6 +66,8 @@

self.post_attention_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
center=True if self.use_bias else False,
scale=True,
dtype=self.dtype_policy,
name="post_attention_layernorm",
)
@@ -69,15 +79,15 @@
self.dense_h_to_4h = keras.layers.Dense(
self.intermediate_dim,
activation=keras.activations.gelu,
use_bias=True,
use_bias=self.use_bias,
dtype=self.dtype_policy,
name="dense_h_to_4h",
)
self.dense_h_to_4h.build(decoder_sequence_shape)

self.dense_4h_to_h = keras.layers.Dense(
self.hidden_dim,
use_bias=True,
use_bias=self.use_bias,
dtype=self.dtype_policy,
name="dense_4h_to_h",
)
1 change: 1 addition & 0 deletions keras_hub/src/utils/preset_utils.py
@@ -60,6 +60,7 @@
HF_TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"
SAFETENSOR_FILE = "model.safetensors"
PYTORCH_BIN_FILE = "pytorch_model.bin"
Collaborator:
Does this file name hold for all the different models that ship .bin files?

Collaborator Author:
Yes. Older uploads of popular models (BERT, GPT-2, and T5) still use .bin files, so the PYTORCH_BIN_FILE name covers all of them.
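For context, a minimal sketch (hypothetical helper, not code from this PR) of how such a constant is typically used: prefer the safetensors file when it exists and fall back to the legacy PyTorch .bin checkpoint that older uploads ship.

import os

SAFETENSOR_FILE = "model.safetensors"
PYTORCH_BIN_FILE = "pytorch_model.bin"


def pick_weights_file(checkpoint_dir):
    # Newer checkpoints publish safetensors; older ones (early BERT,
    # GPT-2, T5 era uploads) only provide pytorch_model.bin.
    safetensors_path = os.path.join(checkpoint_dir, SAFETENSOR_FILE)
    if os.path.exists(safetensors_path):
        return safetensors_path
    return os.path.join(checkpoint_dir, PYTORCH_BIN_FILE)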


# Global state for preset registry.
BUILTIN_PRESETS = {}
124 changes: 124 additions & 0 deletions keras_hub/src/utils/transformers/convert_falcon.py
@@ -0,0 +1,124 @@
import numpy as np

from keras_hub.src.models.falcon import FalconBackbone
from keras_hub.src.utils.preset_utils import load_json

backbone_cls = FalconBackbone


def convert_backbone_config(transformers_config):
if transformers_config.get("multi_query", False):
num_kv_heads = 1
else:
num_kv_heads = transformers_config.get(
"num_kv_heads", transformers_config["num_attention_heads"]
)
return {
"vocabulary_size": transformers_config["vocab_size"],
"num_layers": transformers_config["num_hidden_layers"],
"hidden_dim": transformers_config["hidden_size"],
"num_attention_heads": transformers_config["num_attention_heads"],
"head_dim": transformers_config["hidden_size"]
// transformers_config["num_attention_heads"],
"intermediate_dim": transformers_config.get(
"ffn_hidden_size", 4 * transformers_config["hidden_size"]
),
"num_kv_heads": num_kv_heads,
"use_bias": transformers_config.get("use_bias", True),
}


def convert_weights(backbone, loader, transformers_config):
hidden_dim = transformers_config["hidden_size"]
num_attention_heads = transformers_config["num_attention_heads"]
head_dim = hidden_dim // num_attention_heads
if transformers_config.get("multi_query", False):
num_kv_heads = 1
else:
num_kv_heads = transformers_config.get(
"num_kv_heads", num_attention_heads
)

# Embeddings
loader.port_weight(
keras_variable=backbone.get_layer("token_embedding").embeddings,
hf_weight_key="word_embeddings.weight",
)

for i in range(backbone.num_layers):
decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

# Norm layer
loader.port_weight(
keras_variable=decoder_layer.input_layernorm.gamma,
hf_weight_key=f"h.{i}.input_layernorm.weight",
)

if decoder_layer.input_layernorm.beta is not None:
loader.port_weight(
keras_variable=decoder_layer.input_layernorm.beta,
hf_weight_key=f"h.{i}.input_layernorm.bias",
)
# Attention layers
loader.port_weight(
keras_variable=decoder_layer.attention_layer.output_dense.kernel,
hf_weight_key=f"h.{i}.self_attention.dense.weight",
)

# Load the combined QKV weight
hf_qkv_tensor = loader.get_tensor(
f"h.{i}.self_attention.query_key_value.weight"
)

if hf_qkv_tensor.shape[0] != hidden_dim:
hf_qkv_tensor = np.transpose(hf_qkv_tensor)

query_output_dim = num_attention_heads * head_dim
kv_output_dim = num_kv_heads * head_dim
query_kernel = hf_qkv_tensor[:, :query_output_dim]
key_kernel = hf_qkv_tensor[
:, query_output_dim : query_output_dim + kv_output_dim
]
value_kernel = hf_qkv_tensor[:, query_output_dim + kv_output_dim :]
query_kernel = query_kernel.reshape(
hidden_dim, num_attention_heads, head_dim
)
key_kernel = key_kernel.reshape(hidden_dim, num_kv_heads, head_dim)
value_kernel = value_kernel.reshape(hidden_dim, num_kv_heads, head_dim)
decoder_layer.attention_layer.query_dense.kernel.assign(query_kernel)
decoder_layer.attention_layer.key_dense.kernel.assign(key_kernel)
decoder_layer.attention_layer.value_dense.kernel.assign(value_kernel)

# MLP dense layers
loader.port_weight(
keras_variable=decoder_layer.dense_h_to_4h.kernel,
hf_weight_key=f"h.{i}.mlp.dense_h_to_4h.weight",
hook_fn=lambda x, y: np.transpose(x),
)

loader.port_weight(
keras_variable=decoder_layer.dense_4h_to_h.kernel,
hf_weight_key=f"h.{i}.mlp.dense_4h_to_h.weight",
hook_fn=lambda x, y: np.transpose(x),
)

if hasattr(backbone, "final_layernorm"):
loader.port_weight(
keras_variable=backbone.final_layernorm.gamma,
hf_weight_key="ln_f.weight",
)
loader.port_weight(
keras_variable=backbone.final_layernorm.beta,
hf_weight_key="ln_f.bias",
)


def convert_tokenizer(cls, preset, **kwargs):
tokenizer_data = load_json(preset, "tokenizer.json")
vocab = tokenizer_data["model"]["vocab"]
merges = tokenizer_data["model"].get("merges", None)
tokenizer_kwargs = {"vocabulary": vocab}
if merges is not None:
tokenizer_kwargs["merges"] = merges
tokenizer_kwargs.update(kwargs)
return cls(**tokenizer_kwargs)
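To make the fused-QKV split in convert_weights concrete, here is a standalone sketch with illustrative multi-query sizes (hidden_dim 2048, 32 query heads, 1 key/value head). The slicing and reshaping mirror the converter above; the numbers are for illustration only and are not tied to a specific checkpoint.

import numpy as np

hidden_dim = 2048
num_attention_heads = 32
num_kv_heads = 1  # multi-query attention: one shared key/value head
head_dim = hidden_dim // num_attention_heads  # 64

# Fused kernel: query columns for every head, then one key and one value block.
qkv_width = (num_attention_heads + 2 * num_kv_heads) * head_dim  # 2176
hf_qkv_tensor = np.zeros((hidden_dim, qkv_width), dtype=np.float32)

query_output_dim = num_attention_heads * head_dim  # 2048
kv_output_dim = num_kv_heads * head_dim  # 64

query_kernel = hf_qkv_tensor[:, :query_output_dim].reshape(
    hidden_dim, num_attention_heads, head_dim
)
key_kernel = hf_qkv_tensor[
    :, query_output_dim : query_output_dim + kv_output_dim
].reshape(hidden_dim, num_kv_heads, head_dim)
value_kernel = hf_qkv_tensor[:, query_output_dim + kv_output_dim :].reshape(
    hidden_dim, num_kv_heads, head_dim
)

print(query_kernel.shape)  # (2048, 32, 64)
print(key_kernel.shape)  # (2048, 1, 64)
print(value_kernel.shape)  # (2048, 1, 64)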
23 changes: 23 additions & 0 deletions keras_hub/src/utils/transformers/convert_falcon_test.py
@@ -0,0 +1,23 @@
import pytest

from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone
from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM
from keras_hub.src.tests.test_case import TestCase


class TestTask(TestCase):
@pytest.mark.large
def test_convert_tiny_preset(self):
model = FalconCausalLM.from_preset("hf://tiiuae/falcon-rw-1b")
prompt = "What is your favorite condiment?"
model.generate([prompt], max_length=15)

@pytest.mark.large
def test_class_detection(self):
model = FalconCausalLM.from_preset("hf://tiiuae/falcon-rw-1b")
self.assertIsInstance(model, FalconCausalLM)
model = FalconBackbone.from_preset(
"hf://tiiuae/falcon-1b",
load_weights=False,
)
self.assertIsInstance(model, FalconBackbone)
5 changes: 5 additions & 0 deletions keras_hub/src/utils/transformers/preset_loader.py
@@ -10,6 +10,7 @@
from keras_hub.src.utils.transformers import convert_dinov2
from keras_hub.src.utils.transformers import convert_distilbert
from keras_hub.src.utils.transformers import convert_esm
from keras_hub.src.utils.transformers import convert_falcon
from keras_hub.src.utils.transformers import convert_gemma
from keras_hub.src.utils.transformers import convert_gpt2
from keras_hub.src.utils.transformers import convert_llama3
@@ -55,6 +56,8 @@ def __init__(self, preset, config):
self.converter = convert_pali_gemma
elif model_type == "vit":
self.converter = convert_vit
elif model_type == "falcon":
self.converter = convert_falcon
elif model_type == "qwen2":
self.converter = convert_qwen
elif model_type == "mixtral":
@@ -76,6 +79,8 @@ def check_backbone_class(self):

def load_backbone(self, cls, load_weights, **kwargs):
keras_config = self.converter.convert_backbone_config(self.config)
if "num_kv_heads" in keras_config:
kwargs["num_kv_heads"] = keras_config.pop("num_kv_heads")
backbone = cls(**{**keras_config, **kwargs})
if load_weights:
jax_memory_cleanup(backbone)