Added Falcon model converter and Falcon 7B support #2040
base: master
New file (`@@ -0,0 +1,60 @@`): the Falcon checkpoint converter.

```python
import numpy as np

from keras_hub.src.models.falcon import FalconBackbone
from keras_hub.src.utils.preset_utils import load_json

backbone_cls = FalconBackbone


def convert_backbone_config(transformers_config):
    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_attention_heads": transformers_config["num_attention_heads"],
        "hidden_dim": transformers_config["hidden_size"],
        # Falcon's MLP width defaults to 4x the hidden size, so derive it
        # from the config rather than hardcoding a fixed value.
        "intermediate_dim": transformers_config["hidden_size"] * 4,
    }


def transpose_and_reshape(x, shape):
    return np.reshape(np.transpose(x), shape)


def convert_weights(backbone, loader, transformers_config):
    # Embeddings.
    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key="word_embeddings.weight",
    )

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

        # Input norm layer.
        loader.port_weight(
            keras_variable=decoder_layer.input_layernorm.gamma,
            hf_weight_key=f"h.{i}.input_layernorm.weight",
        )

        # Attention output projection. HF stores dense kernels as
        # (out_features, in_features), so transpose into the Keras layout.
        loader.port_weight(
            keras_variable=decoder_layer.attention_layer.output_dense.kernel,
            hf_weight_key=f"h.{i}.self_attention.dense.weight",
            hook_fn=transpose_and_reshape,
        )

        # Post-attention norm layer.
        loader.port_weight(
            keras_variable=decoder_layer.post_attention_layernorm.gamma,
            hf_weight_key=f"h.{i}.post_attention_layernorm.weight",
        )

    # TODO: port the remaining weights (the fused query_key_value kernel,
    # the MLP dense layers, and the final layer norm); the converter is
    # incomplete without them.


def convert_tokenizer(cls, preset, **kwargs):
    tokenizer_data = load_json(preset, "tokenizer.json")
    vocab = tokenizer_data["model"]["vocab"]
    merges = tokenizer_data["model"].get("merges", None)

    tokenizer_kwargs = {"vocabulary": vocab, "merges": merges}
    return cls(**tokenizer_kwargs)
```
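For reference, a quick sketch of what `convert_backbone_config` produces for a Falcon-7B-style config. The dict below mirrors the published `tiiuae/falcon-7b` `config.json` values, quoted here for illustration only rather than read from the checkpoint:

```python
# Illustrative values matching tiiuae/falcon-7b's config.json (assumed
# here, not loaded from the hub).
hf_config = {
    "vocab_size": 65024,
    "num_hidden_layers": 32,
    "num_attention_heads": 71,
    "hidden_size": 4544,
}

print(convert_backbone_config(hf_config))
# {'vocabulary_size': 65024, 'num_layers': 32, 'num_attention_heads': 71,
#  'hidden_dim': 4544, 'intermediate_dim': 18176}
```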
New file (`@@ -0,0 +1,23 @@`): conversion tests for the Falcon presets.

```python
import pytest

from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone
from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM
from keras_hub.src.tests.test_case import TestCase


class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-7b")
        prompt = "What is your favorite condiment?"
        model.generate([prompt], max_length=15)

    @pytest.mark.large
    def test_class_detection(self):
        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-7b")
        self.assertIsInstance(model, FalconCausalLM)
        model = FalconBackbone.from_preset(
            "hf://tiiuae/falcon-7b",
            load_weights=False,
        )
        self.assertIsInstance(model, FalconBackbone)
```
I don't think we can afford to download this ~15 GB file in our testing setup. You could try the 1B model? Or create a small test model on HF, as was done for Llama and others.
@mattdangerw - I'll create a smaller test with the 1B Falcon model and commit again.
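A minimal sketch of what that smaller test could look like, assuming the 1B checkpoint referred to above is `tiiuae/falcon-rw-1b` (the preset handle and test name here are assumptions, not part of this PR):

```python
import pytest

from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM
from keras_hub.src.tests.test_case import TestCase


class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_1b_preset(self):
        # Assumption: tiiuae/falcon-rw-1b is the "1B model" suggested in
        # review; its checkpoint is a fraction of falcon-7b's ~15 GB.
        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-rw-1b")
        model.generate(["What is your favorite condiment?"], max_length=15)
```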