diff --git a/keras_hub/src/models/falcon/falcon_attention.py b/keras_hub/src/models/falcon/falcon_attention.py
index 48db9664ea..048274466a 100644
--- a/keras_hub/src/models/falcon/falcon_attention.py
+++ b/keras_hub/src/models/falcon/falcon_attention.py
@@ -9,11 +9,15 @@ def __init__(
         self,
         num_heads,
         attention_dropout_rate,
+        num_kv_heads,
+        use_bias=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.num_heads = num_heads
         self.attention_dropout_rate = attention_dropout_rate
+        self.num_kv_heads = num_kv_heads
+        self.use_bias = use_bias
 
     def build(self, inputs_shape):
         # Einsum variables:
@@ -28,13 +32,15 @@ def build(self, inputs_shape):
 
         self.head_dim = hidden_dim // self.num_heads
 
+        bias_axes = "nh" if self.use_bias else None
+
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
 
         self.query_dense = keras.layers.EinsumDense(
             equation="bqm,mnh->bqnh",
             output_shape=(None, self.num_heads, self.head_dim),
-            bias_axes="nh",
+            bias_axes=bias_axes,
             dtype=self.dtype_policy,
             name="query_dense",
         )
@@ -42,8 +48,8 @@ def build(self, inputs_shape):
 
         self.key_dense = keras.layers.EinsumDense(
             equation="bkm,mnh->bknh",
-            output_shape=(None, self.num_heads, self.head_dim),
-            bias_axes="nh",
+            output_shape=(None, self.num_kv_heads, self.head_dim),
+            bias_axes=bias_axes,
             dtype=self.dtype_policy,
             name="key_dense",
         )
@@ -51,8 +57,8 @@ def build(self, inputs_shape):
 
         self.value_dense = keras.layers.EinsumDense(
             equation="bkm,mnh->bknh",
-            output_shape=(None, self.num_heads, self.head_dim),
-            bias_axes="nh",
+            output_shape=(None, self.num_kv_heads, self.head_dim),
+            bias_axes=bias_axes,
             dtype=self.dtype_policy,
             name="value_dense",
         )
@@ -67,6 +73,7 @@ def build(self, inputs_shape):
         self.output_dense = keras.layers.Dense(
             hidden_dim,
             dtype=self.dtype_policy,
+            use_bias=self.use_bias,
             name="output_dense",
         )
         self.output_dense.build(inputs_shape)
diff --git a/keras_hub/src/models/falcon/falcon_backbone.py b/keras_hub/src/models/falcon/falcon_backbone.py
index 379e7403fa..11ff75eef4 100644
--- a/keras_hub/src/models/falcon/falcon_backbone.py
+++ b/keras_hub/src/models/falcon/falcon_backbone.py
@@ -44,8 +44,10 @@ class FalconBackbone(Backbone):
     }
 
     # Pretrained Falcon decoder.
-    # TODO: Update the preset.
-    model = keras_hub.models.FalconBackbone.from_preset("falcon_preset")
+    model = keras_hub.models.FalconBackbone.from_preset("falcon_7b_instruct")
+    model(input_data)
+
+    model = keras_hub.models.FalconBackbone.from_preset("falcon_refinedweb_1b_en")
     model(input_data)
 
     # Randomly initialized Falcon decoder with a custom config.
@@ -70,6 +72,7 @@ def __init__(
         num_layers,
         num_attention_heads,
         hidden_dim,
+        num_kv_heads,
         intermediate_dim,
         layer_norm_epsilon=1e-5,
         attention_dropout_rate=0,
@@ -77,6 +80,8 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
+        use_bias = True if hidden_dim == 2048 else False
+
         # === Layers ===
         self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
@@ -92,7 +97,9 @@ def __init__(
                 intermediate_dim=intermediate_dim,
                 attention_dropout_rate=attention_dropout_rate,
                 feedforward_dropout_rate=feedforward_dropout_rate,
+                num_kv_heads=num_kv_heads,
                 dtype=dtype,
+                use_bias=use_bias,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)
@@ -134,6 +141,7 @@ def __init__(
         self.intermediate_dim = intermediate_dim
         self.attention_dropout_rate = attention_dropout_rate
         self.feedforward_dropout_rate = feedforward_dropout_rate
+        self.num_kv_heads = num_kv_heads
         self.layer_norm_epsilon = layer_norm_epsilon
 
     def get_config(self):
@@ -146,6 +154,7 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "attention_dropout_rate": self.attention_dropout_rate,
+                "num_kv_heads": self.num_kv_heads,
                 "feedforward_dropout_rate": self.feedforward_dropout_rate,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
             }
diff --git a/keras_hub/src/models/falcon/falcon_backbone_test.py b/keras_hub/src/models/falcon/falcon_backbone_test.py
index e46a3bd97b..924cf25f97 100644
--- a/keras_hub/src/models/falcon/falcon_backbone_test.py
+++ b/keras_hub/src/models/falcon/falcon_backbone_test.py
@@ -12,6 +12,7 @@ def setUp(self):
             "num_layers": 2,
             "num_attention_heads": 8,
             "hidden_dim": 16,
+            "num_kv_heads": 1,
             "intermediate_dim": 32,
         }
         self.input_data = {
diff --git a/keras_hub/src/models/falcon/falcon_causal_lm_test.py b/keras_hub/src/models/falcon/falcon_causal_lm_test.py
index 393f8a8e97..b41e342bb5 100644
--- a/keras_hub/src/models/falcon/falcon_causal_lm_test.py
+++ b/keras_hub/src/models/falcon/falcon_causal_lm_test.py
@@ -31,6 +31,7 @@ def setUp(self):
             vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
             num_layers=2,
             num_attention_heads=2,
+            num_kv_heads=1,
             hidden_dim=4,
             intermediate_dim=16,
         )
diff --git a/keras_hub/src/models/falcon/falcon_transformer_decoder.py b/keras_hub/src/models/falcon/falcon_transformer_decoder.py
index 7eaca5f22e..ceedff7562 100644
--- a/keras_hub/src/models/falcon/falcon_transformer_decoder.py
+++ b/keras_hub/src/models/falcon/falcon_transformer_decoder.py
@@ -17,6 +17,8 @@ def __init__(
         self,
         num_attention_heads,
         intermediate_dim,
+        num_kv_heads,
+        use_bias=False,
         layer_norm_epsilon=1e-5,
         attention_dropout_rate=0,
         feedforward_dropout_rate=0,
@@ -28,11 +30,15 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.attention_dropout_rate = attention_dropout_rate
         self.feedforward_dropout_rate = feedforward_dropout_rate
+        self.num_kv_heads = num_kv_heads
+        self.use_bias = use_bias
 
     def build(self, decoder_sequence_shape):
         self.hidden_dim = decoder_sequence_shape[-1]
         self.input_layernorm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            center=True if self.use_bias else False,
+            scale=True,
             dtype=self.dtype_policy,
             name="input_layernorm",
         )
@@ -43,7 +49,9 @@ def build(self, decoder_sequence_shape):
         self.attention_layer = FalconAttention(
             num_heads=self.num_attention_heads,
             attention_dropout_rate=self.attention_dropout_rate,
+            num_kv_heads=self.num_kv_heads,
             dtype=self.dtype_policy,
+            use_bias=self.use_bias,
             name="attention",
         )
         self.attention_layer.build(
@@ -58,6 +66,8 @@ def build(self, decoder_sequence_shape):
 
         self.post_attention_layernorm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            center=True if self.use_bias else False,
+            scale=True,
             dtype=self.dtype_policy,
             name="post_attention_layernorm",
         )
@@ -69,7 +79,7 @@ def build(self, decoder_sequence_shape):
         self.dense_h_to_4h = keras.layers.Dense(
             self.intermediate_dim,
             activation=keras.activations.gelu,
-            use_bias=True,
+            use_bias=self.use_bias,
             dtype=self.dtype_policy,
             name="dense_h_to_4h",
         )
@@ -77,7 +87,7 @@ def build(self, decoder_sequence_shape):
 
         self.dense_4h_to_h = keras.layers.Dense(
             self.hidden_dim,
-            use_bias=True,
+            use_bias=self.use_bias,
             dtype=self.dtype_policy,
             name="dense_4h_to_h",
         )
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index affb2362c6..400365ed4f 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -60,6 +60,7 @@
 HF_TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
 SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"
 SAFETENSOR_FILE = "model.safetensors"
+PYTORCH_BIN_FILE = "pytorch_model.bin"
 
 # Global state for preset registry.
 BUILTIN_PRESETS = {}
diff --git a/keras_hub/src/utils/transformers/convert_falcon.py b/keras_hub/src/utils/transformers/convert_falcon.py
new file mode 100644
index 0000000000..74a98a16f7
--- /dev/null
+++ b/keras_hub/src/utils/transformers/convert_falcon.py
@@ -0,0 +1,122 @@
+import numpy as np
+
+from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone
+from keras_hub.src.utils.preset_utils import load_json
+
+backbone_cls = FalconBackbone
+
+
+def convert_backbone_config(transformers_config):
+    if transformers_config.get("multi_query", False):
+        num_kv_heads = 1
+    else:
+        num_kv_heads = transformers_config.get(
+            "num_kv_heads", transformers_config["num_attention_heads"]
+        )
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_attention_heads": transformers_config["num_attention_heads"],
+        "intermediate_dim": transformers_config.get(
+            "ffn_hidden_size", 4 * transformers_config["hidden_size"]
+        ),
+        "num_kv_heads": num_kv_heads,
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    hidden_dim = transformers_config["hidden_size"]
+    num_attention_heads = transformers_config["num_attention_heads"]
+    head_dim = hidden_dim // num_attention_heads
+    if transformers_config.get("multi_query", False):
+        num_kv_heads = 1
+    else:
+        num_kv_heads = transformers_config.get(
+            "num_kv_heads", num_attention_heads
+        )
+
+    # Embeddings
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        hf_weight_key="word_embeddings.weight",
+    )
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
+
+        # Norm layer
+        loader.port_weight(
+            keras_variable=decoder_layer.input_layernorm.gamma,
+            hf_weight_key=f"h.{i}.input_layernorm.weight",
+        )
+
+        if decoder_layer.input_layernorm.beta is not None:
+            loader.port_weight(
+                keras_variable=decoder_layer.input_layernorm.beta,
+                hf_weight_key=f"h.{i}.input_layernorm.bias",
+            )
+        # Attention layers
+        loader.port_weight(
+            keras_variable=decoder_layer.attention_layer.output_dense.kernel,
+            hf_weight_key=f"h.{i}.self_attention.dense.weight",
+            hook_fn=lambda x, y: np.transpose(x),
+        )
+
+        # Load the combined QKV weight
+        hf_qkv_tensor = loader.get_tensor(
+            f"h.{i}.self_attention.query_key_value.weight"
+        )
+
+        if hf_qkv_tensor.shape[0] != hidden_dim:
+            hf_qkv_tensor = np.transpose(hf_qkv_tensor)
+
+        query_output_dim = num_attention_heads * head_dim
+        kv_output_dim = num_kv_heads * head_dim
+        query_kernel = hf_qkv_tensor[:, :query_output_dim]
+        key_kernel = hf_qkv_tensor[
+            :, query_output_dim : query_output_dim + kv_output_dim
+        ]
+        value_kernel = hf_qkv_tensor[:, query_output_dim + kv_output_dim :]
+        query_kernel = query_kernel.reshape(
+            hidden_dim, num_attention_heads, head_dim
+        )
+        key_kernel = key_kernel.reshape(hidden_dim, num_kv_heads, head_dim)
+        value_kernel = value_kernel.reshape(hidden_dim, num_kv_heads, head_dim)
+        decoder_layer.attention_layer.query_dense.kernel.assign(query_kernel)
+        decoder_layer.attention_layer.key_dense.kernel.assign(key_kernel)
+        decoder_layer.attention_layer.value_dense.kernel.assign(value_kernel)
+
+        # MLP dense layers
+        loader.port_weight(
+            keras_variable=decoder_layer.dense_h_to_4h.kernel,
+            hf_weight_key=f"h.{i}.mlp.dense_h_to_4h.weight",
+            hook_fn=lambda x, y: np.transpose(x),
+        )
+
+        loader.port_weight(
+            keras_variable=decoder_layer.dense_4h_to_h.kernel,
+            hf_weight_key=f"h.{i}.mlp.dense_4h_to_h.weight",
+            hook_fn=lambda x, y: np.transpose(x),
+        )
+
+    if hasattr(backbone, "final_layernorm"):
+        loader.port_weight(
+            keras_variable=backbone.final_layernorm.gamma,
+            hf_weight_key="ln_f.weight",
+        )
+        loader.port_weight(
+            keras_variable=backbone.final_layernorm.beta,
+            hf_weight_key="ln_f.bias",
+        )
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    tokenizer_data = load_json(preset, "tokenizer.json")
+    vocab = tokenizer_data["model"]["vocab"]
+    merges = tokenizer_data["model"].get("merges", None)
+    tokenizer_kwargs = {"vocabulary": vocab}
+    if merges is not None:
+        tokenizer_kwargs["merges"] = merges
+    tokenizer_kwargs.update(kwargs)
+    return cls(**tokenizer_kwargs)
diff --git a/keras_hub/src/utils/transformers/convert_falcon_test.py b/keras_hub/src/utils/transformers/convert_falcon_test.py
new file mode 100644
index 0000000000..12190c4426
--- /dev/null
+++ b/keras_hub/src/utils/transformers/convert_falcon_test.py
@@ -0,0 +1,23 @@
+import pytest
+
+from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone
+from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM
+from keras_hub.src.tests.test_case import TestCase
+
+
+class TestTask(TestCase):
+    @pytest.mark.large
+    def test_convert_tiny_preset(self):
+        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-rw-1b")
+        prompt = "What is your favorite condiment?"
+        model.generate([prompt], max_length=15)
+
+    @pytest.mark.large
+    def test_class_detection(self):
+        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-rw-1b")
+        self.assertIsInstance(model, FalconCausalLM)
+        model = FalconBackbone.from_preset(
+            "hf://tiiuae/falcon-rw-1b",
+            load_weights=False,
+        )
+        self.assertIsInstance(model, FalconBackbone)
diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py
index bfca6e7bc5..c639d46703 100644
--- a/keras_hub/src/utils/transformers/preset_loader.py
+++ b/keras_hub/src/utils/transformers/preset_loader.py
@@ -10,6 +10,7 @@
 from keras_hub.src.utils.transformers import convert_dinov2
 from keras_hub.src.utils.transformers import convert_distilbert
 from keras_hub.src.utils.transformers import convert_esm
+from keras_hub.src.utils.transformers import convert_falcon
 from keras_hub.src.utils.transformers import convert_gemma
 from keras_hub.src.utils.transformers import convert_gpt2
 from keras_hub.src.utils.transformers import convert_llama3
@@ -55,6 +56,8 @@ def __init__(self, preset, config):
             self.converter = convert_pali_gemma
         elif model_type == "vit":
             self.converter = convert_vit
+        elif model_type == "falcon":
+            self.converter = convert_falcon
         elif model_type == "qwen2":
             self.converter = convert_qwen
         elif model_type == "mixtral":
@@ -76,6 +79,8 @@ def check_backbone_class(self):
 
     def load_backbone(self, cls, load_weights, **kwargs):
         keras_config = self.converter.convert_backbone_config(self.config)
+        if "num_kv_heads" in keras_config:
+            kwargs["num_kv_heads"] = keras_config.pop("num_kv_heads")
         backbone = cls(**{**keras_config, **kwargs})
         if load_weights:
             jax_memory_cleanup(backbone)
diff --git a/keras_hub/src/utils/transformers/safetensor_utils.py b/keras_hub/src/utils/transformers/safetensor_utils.py
index 24c8dff338..03dbcebe07 100644
--- a/keras_hub/src/utils/transformers/safetensor_utils.py
+++ b/keras_hub/src/utils/transformers/safetensor_utils.py
@@ -1,11 +1,17 @@
 import contextlib
 
+from keras_hub.src.utils.preset_utils import PYTORCH_BIN_FILE
 from keras_hub.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE
 from keras_hub.src.utils.preset_utils import SAFETENSOR_FILE
 from keras_hub.src.utils.preset_utils import check_file_exists
 from keras_hub.src.utils.preset_utils import get_file
 from keras_hub.src.utils.preset_utils import load_json
 
+try:
+    import torch
+except ImportError:
+    torch = None
+
 try:
     import safetensors
 except ImportError:
@@ -65,25 +71,53 @@ def get_prefixed_key(self, hf_weight_key, dict_like):
         return hf_weight_key
 
     def get_tensor(self, hf_weight_key):
-        if self.safetensor_config is None:
-            fname = self.fname if self.fname is not None else SAFETENSOR_FILE
-        else:
+        if self.safetensor_config is not None:
             full_key = self.get_prefixed_key(
                 hf_weight_key, self.safetensor_config["weight_map"]
             )
             fname = self.safetensor_config["weight_map"][full_key]
-
-        if fname in self.safetensor_files:
-            file = self.safetensor_files[fname]
+        elif self.fname is not None:
+            fname = self.fname
         else:
+            if check_file_exists(self.preset, SAFETENSOR_FILE):
+                fname = SAFETENSOR_FILE
+            elif check_file_exists(self.preset, PYTORCH_BIN_FILE):
+                fname = PYTORCH_BIN_FILE
+            else:
+                raise FileNotFoundError(
+                    f"No supported weight file found for preset {self.preset}"
+                )
+
+        if fname.endswith(".safetensors"):
+            if fname in self.safetensor_files:
+                file = self.safetensor_files[fname]
+            else:
+                path = get_file(self.preset, fname)
+                file = self.enter_context(
+                    safetensors.safe_open(path, framework="np")
+                )
+                self.safetensor_files[fname] = file
+
+            full_key = self.get_prefixed_key(hf_weight_key, file)
+            return file.get_tensor(full_key)
+        elif fname.endswith(".bin"):
+            if torch is None:
+                raise ImportError(
+                    "Loading a `.bin` file requires PyTorch. "
+                    "Please install with `pip install torch`."
+                )
             path = get_file(self.preset, fname)
-            file = self.enter_context(
-                safetensors.safe_open(path, framework="np")
-            )
-            self.safetensor_files[fname] = file
-
-        full_key = self.get_prefixed_key(hf_weight_key, file)
-        return file.get_tensor(full_key)
+            if fname in self.safetensor_files:
+                file = self.safetensor_files[fname]
+            else:
+                state_dict = torch.load(path, map_location="cpu")
+                file = {
+                    k: v.to(torch.float32).numpy()
+                    for k, v in state_dict.items()
+                }
+                self.safetensor_files[fname] = file
+            full_key = self.get_prefixed_key(hf_weight_key, file)
+            return file[full_key]
 
     def port_weight(self, keras_variable, hf_weight_key, hook_fn=None):
         hf_tensor = self.get_tensor(hf_weight_key)
diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py
index 68d2ca2578..7358a7393d 100644
--- a/tools/checkpoint_conversion/convert_falcon_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py
@@ -16,6 +16,8 @@
 ```
 python tools/checkpoint_conversion/convert_falcon_checkpoints.py \
     --preset falcon_refinedweb_1b_en
+python tools/checkpoint_conversion/convert_falcon_checkpoints.py \
+    --preset falcon_7b_instruct
 ```
 """
 
@@ -35,6 +37,7 @@
 
 PRESET_MAP = {
     "falcon_refinedweb_1b_en": "tiiuae/falcon-rw-1b",
+    "falcon_7b_instruct": "tiiuae/falcon-7b-instruct",
 }
 
 EXTRACT_DIR = "./model"
@@ -50,7 +53,7 @@ def download_hf_model(hf_model_name):
     hf_model_dir = huggingface_hub.snapshot_download(
         repo_id=hf_model_name,
-        allow_patterns=["*.json", "*.bin"],
+        allow_patterns=["*.json", "*.bin", "*.safetensors"],
         ignore_patterns=["onnx/*"],
         local_dir=EXTRACT_DIR,
     )
 
@@ -68,6 +71,12 @@ def convert_model(hf_model):
     kwargs["intermediate_dim"] = 4 * kwargs["hidden_dim"]
     kwargs["feedforward_dropout_rate"] = hf_config["hidden_dropout"]
     kwargs["attention_dropout_rate"] = hf_config["attention_dropout"]
+    if hf_config.get("multi_query", False):
+        kwargs["num_kv_heads"] = 1
+    else:
+        kwargs["num_kv_heads"] = hf_config.get(
+            "num_kv_heads", kwargs["num_attention_heads"]
+        )
 
     return keras_hub.models.FalconBackbone(**kwargs)
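Note: the trickiest step in `convert_weights` above is splitting the Hugging Face fused `query_key_value` kernel into the separate query/key/value EinsumDense kernels. The snippet below is a minimal standalone sketch of that column split, assuming the blocked `[query heads | key | value]` layout used by multi-query checkpoints; the dimensions are illustrative only and the array is random rather than a real checkpoint tensor.

```python
import numpy as np

# Illustrative dimensions only (not a real Falcon configuration).
hidden_dim, num_heads, num_kv_heads = 64, 8, 1
head_dim = hidden_dim // num_heads

# HF stores Linear weights as (out_features, in_features); transpose so the
# split runs along the last axis, mirroring the shape check in convert_weights.
fused_qkv = np.random.rand((num_heads + 2 * num_kv_heads) * head_dim, hidden_dim)
if fused_qkv.shape[0] != hidden_dim:
    fused_qkv = np.transpose(fused_qkv)

query_out = num_heads * head_dim
kv_out = num_kv_heads * head_dim
query = fused_qkv[:, :query_out].reshape(hidden_dim, num_heads, head_dim)
key = fused_qkv[:, query_out : query_out + kv_out].reshape(
    hidden_dim, num_kv_heads, head_dim
)
value = fused_qkv[:, query_out + kv_out :].reshape(
    hidden_dim, num_kv_heads, head_dim
)

print(query.shape, key.shape, value.shape)
# (64, 8, 8) (64, 1, 8) (64, 1, 8)
```

Non-multi-query checkpoints may lay the fused kernel out differently, so the split should be validated against the Hugging Face reference implementation before relying on it for those models.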