Safetensors conversion #2290
base: master
Conversation
Thanks for the PR, will take a look in a bit :)
Thanks! Just left some initial comments.
Let's add a unit test that calls this util and tries loading the result with transformers, to see if it works. It's OK to add transformers to our CI environment here: https://github.com/keras-team/keras-hub/blob/master/requirements-common.txt
import os

import torch
from safetensors.torch import save_file
does this work on all backends? or do we need to flip between versions depending on the backend? worth testing out
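One way to "flip between versions depending on the backend", sketched as a pure lookup so it can be tested without any framework installed. The backend strings match what `keras.config.backend()` returns; the mapping itself is an assumption to be verified:

```python
# Hypothetical sketch: resolve the right safetensors saver per Keras backend
# instead of importing safetensors.torch unconditionally.
def resolve_save_file(backend):
    """Return (module_path, function_name) for the matching saver."""
    savers = {
        "torch": ("safetensors.torch", "save_file"),
        "tensorflow": ("safetensors.tensorflow", "save_file"),
        "jax": ("safetensors.flax", "save_file"),
    }
    if backend not in savers:
        raise ValueError(f"Unsupported backend for safetensors export: {backend}")
    return savers[backend]
```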
Nice! Please address the changes from the earlier PR as well
keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py
Thanks, nice work!
    return hf_config


def export_to_hf(keras_model, path):
We should add the API export decorator here, similar to this: https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/bloom/bloom_backbone.py#L15-L16
Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API.
So, this is what the directory keras_hub/src/utils/transformers/convert_to_safetensor/ will look like:
- export.py: this will have the common code. We will expose this as the API. This will also check whether we support safetensors conversion for a given passed model yet.
- gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the weight conversion function specific to the specified model.
Pinging @mattdangerw to confirm if we should do this now or at a later point.
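The proposed export.py/gemma.py split could be sketched roughly as below. The registry keys, module names, and `get_weight_map` helper are all hypothetical illustrations of the idea, not the agreed design:

```python
# Hypothetical export.py: common entry point that dispatches to a
# per-model converter module, and rejects unsupported models early.
MODEL_EXPORTERS = {
    # backbone class name -> converter module name (e.g. gemma.py)
    "GemmaBackbone": "gemma",
}


def export_to_hf(keras_model, path):
    backbone_name = type(keras_model.backbone).__name__
    if backbone_name not in MODEL_EXPORTERS:
        raise ValueError(
            f"Safetensors export is not yet supported for {backbone_name}."
        )
    # e.g. from . import gemma; weights = gemma.get_weight_map(keras_model)
    ...
```

Adding a new model then means adding one converter module and one registry entry.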
I think we could land this and do the API bit at a later point. Though I agree it's an important concern. I'm not sure if we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?
I think structuring the export logic with a utility function (export_to_hf) and model-specific mappings (gemma.py) will enhance scalability and maintainability. New models can be added by creating a new file, while existing tests only need an import update.
+1 to Abheesht's comment. We need an API instead of a script; for Gemma, we already have that:
https://github.com/keras-team/keras-hub/blob/master/tools/gemma/export_gemma_to_hf.py
Leaving comments since I don't see the changes we discussed last week.
@pytest.mark.large
def test_export_to_hf(self):
    # Load Keras model
    keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
We discussed this last week. In order to make GPU tests work, we need to use a smaller, randomly initialised Gemma model so that we don't hit OOM.
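A sketch of the tiny-model approach: instead of the 2B preset, construct a randomly initialised backbone from small hyperparameters. The keyword names follow GemmaBackbone's constructor; the specific values below are arbitrary test-sized assumptions:

```python
# Hypothetical config for a tiny Gemma used only in tests; the numbers are
# deliberately small so the test fits on a CI GPU without OOM.
TINY_GEMMA_KWARGS = dict(
    vocabulary_size=256,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=32,
    intermediate_dim=64,
    head_dim=8,
)

# In the test, assuming GemmaBackbone is importable from keras_hub.models:
#   backbone = GemmaBackbone(**TINY_GEMMA_KWARGS)
```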
Args:
    keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
    path: str. Path of the directory to which the safetensors file,
        config and tokenizer will be saved.
Indent this
from safetensors.flax import save_file as flax_save_file
from safetensors.tensorflow import save_file as tf_save_file
from safetensors.torch import save_file as torch_save_file
Discussed last week. We are supposed to import these conditionally; we don't want to import all of these in every case. If the backend is JAX, import the Flax one; if the backend is Torch, import the Torch one, etc. You can raise an ImportError if they are not present. Maybe something like this?
keras-hub/keras_hub/src/utils/transformers/safetensor_utils.py
Lines 9 to 12 in 25c9062
try:
    import safetensors
except ImportError:
    safetensors = None
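Putting the two suggestions together, a conditional import per backend might look like the sketch below. The function name and error message are illustrative assumptions, not the final code:

```python
# Hypothetical sketch: import only the saver for the active backend, and
# raise a clear ImportError when the needed package is missing.
def import_save_file(backend):
    try:
        if backend == "torch":
            from safetensors.torch import save_file
        elif backend == "tensorflow":
            from safetensors.tensorflow import save_file
        elif backend == "jax":
            from safetensors.flax import save_file
        else:
            raise ValueError(f"Unsupported backend: {backend}")
    except ImportError:
        raise ImportError(
            "Exporting to safetensors requires the `safetensors` package "
            f"with support for the `{backend}` backend."
        )
    return save_file
```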
Description of the change
Reference
Colab Notebook: https://colab.research.google.com/drive/1naqf0sO2J40skndWbVMeQismjL7MuEjd?usp=sharing
Checklist