diff --git a/keras_hub/src/layers/preprocessing/audio_converter.py b/keras_hub/src/layers/preprocessing/audio_converter.py index 8e655702c4..4069c40a5a 100644 --- a/keras_hub/src/layers/preprocessing/audio_converter.py +++ b/keras_hub/src/layers/preprocessing/audio_converter.py @@ -3,7 +3,6 @@ PreprocessingLayer, ) from keras_hub.src.utils.preset_utils import builtin_presets -from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -89,10 +88,7 @@ class like `keras_hub.models.AudioConverter.from_preset()`, or from a ``` """ loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - if cls.backbone_cls != backbone_cls: - cls = find_subclass(preset, cls, backbone_cls) - return loader.load_audio_converter(cls, **kwargs) + return loader.load_audio_converter(cls=cls, kwargs=kwargs) def save_to_preset(self, preset_dir): """Save audio converter to a preset directory. diff --git a/keras_hub/src/layers/preprocessing/audio_converter_test.py b/keras_hub/src/layers/preprocessing/audio_converter_test.py index f0580970c2..839e291b2c 100644 --- a/keras_hub/src/layers/preprocessing/audio_converter_test.py +++ b/keras_hub/src/layers/preprocessing/audio_converter_test.py @@ -29,7 +29,7 @@ def test_from_preset(self): @pytest.mark.large def test_from_preset_errors(self): - with self.assertRaises(ValueError): + with self.assertRaises(FileNotFoundError): AudioConverter.from_preset("bert_tiny_en_uncased") with self.assertRaises(ValueError): # No loading on a non-keras model. diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index b86cd6d2a0..2fd88da82c 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -11,7 +11,6 @@ ) from keras_hub.src.utils.keras_utils import standardize_data_format from keras_hub.src.utils.preset_utils import builtin_presets -from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -380,10 +379,7 @@ def from_preset( ``` """ loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - if cls.backbone_cls != backbone_cls: - cls = find_subclass(preset, cls, backbone_cls) - return loader.load_image_converter(cls, **kwargs) + return loader.load_image_converter(cls=cls, kwargs=kwargs) def save_to_preset(self, preset_dir): """Save image converter to a preset directory. diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py index 8d47872a43..c40464d5d4 100644 --- a/keras_hub/src/layers/preprocessing/image_converter_test.py +++ b/keras_hub/src/layers/preprocessing/image_converter_test.py @@ -122,7 +122,7 @@ def test_from_preset(self): @pytest.mark.large def test_from_preset_errors(self): - with self.assertRaises(ValueError): + with self.assertRaises(FileNotFoundError): ImageConverter.from_preset("bert_tiny_en_uncased") with self.assertRaises(ValueError): # No loading on a non-keras model. 
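A minimal usage sketch of the converter loading path above, assuming `keras_hub` is installed and that the `whisper_tiny_en` and `bert_tiny_en_uncased` presets resolve (the preset names are illustrative, not taken from this diff):

```python
# Sketch only: assumes keras_hub is installed and these preset names resolve.
import keras_hub

# The converter subclass is now resolved inside the preset loader, so calling
# the base class works the same as calling the concrete subclass.
converter = keras_hub.layers.AudioConverter.from_preset("whisper_tiny_en")

# A preset that has no audio converter config now surfaces the missing file
# as a FileNotFoundError (matching the updated test expectation above).
try:
    keras_hub.layers.AudioConverter.from_preset("bert_tiny_en_uncased")
except FileNotFoundError as err:
    print(err)
```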
diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index cd4c7ecaf9..91d68d44b0 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -168,14 +168,9 @@ class like `keras_hub.models.Backbone.from_preset()`, or from ``` """ loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - if not issubclass(backbone_cls, cls): - raise ValueError( - f"Saved preset has type `{backbone_cls.__name__}` which is not " - f"a subclass of calling class `{cls.__name__}`. Call " - f"`from_preset` directly on `{backbone_cls.__name__}` instead." - ) - return loader.load_backbone(backbone_cls, load_weights, **kwargs) + return loader.load_backbone( + cls=cls, load_weights=load_weights, kwargs=kwargs + ) def save_to_preset(self, preset_dir, max_shard_size=10): """Save backbone to a preset directory. diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py index de2a94af1d..eddd50b7ae 100644 --- a/keras_hub/src/models/preprocessor.py +++ b/keras_hub/src/models/preprocessor.py @@ -6,7 +6,6 @@ ) from keras_hub.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE from keras_hub.src.utils.preset_utils import builtin_presets -from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -171,7 +170,7 @@ def from_preset( ) ``` """ - if cls == Preprocessor: + if cls is Preprocessor: raise ValueError( "Do not call `Preprocessor.from_preset()` directly. Instead " "choose a particular task preprocessing class, e.g. " @@ -179,35 +178,30 @@ def from_preset( ) loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - # Detect the correct subclass if we need to. - if cls.backbone_cls != backbone_cls: - cls = find_subclass(preset, cls, backbone_cls) - return loader.load_preprocessor(cls, config_file, **kwargs) + return loader.load_preprocessor( + cls=cls, config_file=config_file, kwargs=kwargs + ) @classmethod - def _add_missing_kwargs(cls, loader, kwargs): - """Fill in required kwargs when loading from preset. - - This is a private method hit when loading a preprocessing layer that - was not directly saved in the preset. This method should fill in - all required kwargs required to call the class constructor. For almost, - all preprocessors, the only required args are `tokenizer`, - `image_converter`, and `audio_converter`, but this can be overridden, - e.g. for a preprocessor with multiple tokenizers for different - encoders. + def _from_defaults(cls, loader, kwargs): + """Load a preprocessor from default values. + + This is a private method hit for loading a preprocessing layer that was + not directly saved in the preset. Usually this means loading a + tokenizer, image_converter and/or audio_converter and calling the + constructor. But this can be overridden by subclasses as needed. + """ + defaults = {} + # Allow loading any tokenizer, image_converter or audio_converter config + # we find on disk. We allow mixing and matching tokenizers and + # preprocessing layers (though this is usually not a good idea).
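The new `_from_defaults` hook builds loader-provided defaults and merges them under the caller's kwargs, so user-supplied arguments always win. A standalone sketch of that merge pattern; the classes here are illustrative stand-ins, not keras_hub APIs:

```python
# Illustrative stand-ins only; not keras_hub classes.
class FakeLoader:
    def load_tokenizer(self):
        return "tokenizer-from-preset"


class FakePreprocessor:
    tokenizer_cls = object  # truthy marker, mirroring `cls.tokenizer_cls`

    def __init__(self, tokenizer=None, sequence_length=128):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length

    @classmethod
    def _from_defaults(cls, loader, kwargs):
        defaults = {}
        if "tokenizer" not in kwargs and cls.tokenizer_cls:
            defaults["tokenizer"] = loader.load_tokenizer()
        # User-supplied kwargs take precedence over loader defaults.
        return cls(**{**defaults, **kwargs})


p = FakePreprocessor._from_defaults(FakeLoader(), {"sequence_length": 512})
assert p.tokenizer == "tokenizer-from-preset"
assert p.sequence_length == 512
```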
if "tokenizer" not in kwargs and cls.tokenizer_cls: - kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls) + defaults["tokenizer"] = loader.load_tokenizer() if "audio_converter" not in kwargs and cls.audio_converter_cls: - kwargs["audio_converter"] = loader.load_audio_converter( - cls.audio_converter_cls - ) + defaults["audio_converter"] = loader.load_audio_converter() if "image_converter" not in kwargs and cls.image_converter_cls: - kwargs["image_converter"] = loader.load_image_converter( - cls.image_converter_cls - ) - return kwargs + defaults["image_converter"] = loader.load_image_converter() + return cls(**{**defaults, **kwargs}) def load_preset_assets(self, preset): """Load all static assets needed by the preprocessing layer. diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index d273759b46..9af1df4d6d 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -6,13 +6,11 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.utils.keras_utils import print_msg from keras_hub.src.utils.pipeline_model import PipelineModel from keras_hub.src.utils.preset_utils import builtin_presets -from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -175,7 +173,7 @@ def from_preset( ) ``` """ - if cls == Task: + if cls is Task: raise ValueError( "Do not call `Task.from_preset()` directly. Instead call a " "particular task class, e.g. " @@ -183,19 +181,30 @@ def from_preset( ) loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - # Detect the correct subclass if we need to. - if ( - issubclass(backbone_cls, Backbone) - and cls.backbone_cls != backbone_cls - ): - cls = find_subclass(preset, cls, backbone_cls) - # Specifically for classifiers, we never load task weights if - # num_classes is supplied. We handle this in the task base class because - # it is the same logic for classifiers regardless of modality (text, - # images, audio). - load_task_weights = "num_classes" not in kwargs - return loader.load_task(cls, load_weights, load_task_weights, **kwargs) + return loader.load_task( + cls=cls, load_weights=load_weights, kwargs=kwargs + ) + + @classmethod + def _from_defaults(cls, loader, load_weights, kwargs, backbone_kwargs): + """Load a task from default values. + + This is a private method hit for loading a task layer that was + not directly saved in the preset. Usually this means loading a backbone + and preprocessor and calling the constructor. But this can be overridden + by subclasses as needed. + """ + defaults = {} + if "backbone" not in kwargs: + defaults["backbone"] = loader.load_backbone( + load_weights=load_weights, kwargs=backbone_kwargs + ) + if "preprocessor" not in kwargs and cls.preprocessor_cls: + # Only load the "matching" preprocessor class for a task class. 
+ defaults["preprocessor"] = loader.load_preprocessor( + cls=cls.preprocessor_cls + ) + return cls(**{**defaults, **kwargs}) def load_task_weights(self, filepath): """Load only the tasks specific weights not in the backbone.""" diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py index 5e8986a89e..7a4e9673e9 100644 --- a/keras_hub/src/tokenizers/tokenizer.py +++ b/keras_hub/src/tokenizers/tokenizer.py @@ -7,7 +7,6 @@ from keras_hub.src.utils.preset_utils import ASSET_DIR from keras_hub.src.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_hub.src.utils.preset_utils import builtin_presets -from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_file from keras_hub.src.utils.preset_utils import get_preset_loader from keras_hub.src.utils.preset_utils import get_preset_saver @@ -257,7 +256,6 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from ``` """ loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - if cls.backbone_cls != backbone_cls: - cls = find_subclass(preset, cls, backbone_cls) - return loader.load_tokenizer(cls, config_file, **kwargs) + return loader.load_tokenizer( + cls=cls, config_file=config_file, kwargs=kwargs + ) diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 9df23b2568..8a1a432449 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -100,26 +100,6 @@ def list_subclasses(cls): return subclasses -def find_subclass(preset, cls, backbone_cls): - """Find a subclass that is compatible with backbone_cls.""" - subclasses = list_subclasses(cls) - subclasses = filter(lambda x: x.backbone_cls == backbone_cls, subclasses) - subclasses = list(subclasses) - if not subclasses: - raise ValueError( - f"Unable to find a subclass of {cls.__name__} that is compatible " - f"with {backbone_cls.__name__} found in preset '{preset}'." - ) - # If we find multiple subclasses, try to filter to direct subclasses of - # the class we are trying to instantiate. - if len(subclasses) > 1: - directs = list(filter(lambda x: x in cls.__bases__, subclasses)) - if len(directs) > 1: - subclasses = directs - # Return the subclass that was registered first (prefer built-in classes). - return subclasses[0] - - def get_file(preset, path): """Download a preset file in necessary and return the local path.""" # TODO: Add tests for FileNotFound exceptions. @@ -175,8 +155,9 @@ def get_file(preset, path): from modelscope.hub.snapshot_download import snapshot_download except ImportError: raise ImportError( - "To load a preset from ModelScope {preset} using from_preset," - "install the modelscope package with: pip install modelscope." + "`from_preset()` requires the `modelscope` package to " + f"load from '{preset}'. " + "Please install with `pip install modelscope`." ) modelscope_handle = preset.removeprefix(MODELSCOPE_SCHEME + "://") try: @@ -204,9 +185,9 @@ def get_file(preset, path): elif scheme == HF_SCHEME: if huggingface_hub is None: raise ImportError( - "`from_preset()` requires the `huggingface_hub` package to " - "load from '{preset}'. " - "Please install with `pip install huggingface_hub`." + "`from_preset()` requires the `huggingface-hub` package to " + f"load from '{preset}'. " + "Please install with `pip install huggingface-hub`." 
) hf_handle = preset.removeprefix(HF_SCHEME + "://") try: @@ -580,123 +561,143 @@ def __init__(self, preset, config): self.config = config self.preset = preset - def get_backbone_kwargs(self, **kwargs): + def split_backbone_kwargs(self, kwargs): backbone_kwargs = {} - # Forward `dtype` to backbone. backbone_kwargs["dtype"] = kwargs.pop("dtype", None) - # Forward `height` and `width` to backbone when using `TextToImage`. if "image_shape" in kwargs: backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None) - return backbone_kwargs, kwargs def check_backbone_class(self): - """Infer the backbone architecture.""" + """Check the cls of a saved backbone.""" raise NotImplementedError - def load_backbone(self, cls, load_weights, **kwargs): + def find_compatible_subclass(self, cls): + """Find a subclass that is compatible with backbone_cls.""" + backbone_cls = self.check_backbone_class() + if cls.backbone_cls == backbone_cls: + return cls + subclasses = list_subclasses(cls) + subclasses = filter( + lambda x: x.backbone_cls == backbone_cls, subclasses + ) + subclasses = list(subclasses) + if not subclasses: + raise ValueError( + f"Unable to find a subclass of {cls} that is compatible " + f"with {backbone_cls} found in preset '{self.preset}'." + ) + # If we find multiple subclasses, try to filter to direct subclasses of + # the class we are trying to instantiate. + if len(subclasses) > 1: + directs = list(filter(lambda x: x in cls.__bases__, subclasses)) + if len(directs) > 1: + subclasses = directs + # Return the subclass registered first (prefer built-in classes). + return subclasses[0] + + def load_backbone(self, cls=None, load_weights=True, kwargs=None): """Load the backbone model from the preset.""" raise NotImplementedError - def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): + def load_tokenizer( + self, cls=None, config_file=TOKENIZER_CONFIG_FILE, kwargs=None + ): """Load a tokenizer layer from the preset.""" raise NotImplementedError - def load_audio_converter(self, cls, **kwargs): + def load_audio_converter(self, cls=None, kwargs=None): """Load an audio converter layer from the preset.""" raise NotImplementedError - def load_image_converter(self, cls, **kwargs): + def load_image_converter(self, cls=None, kwargs=None): """Load an image converter layer from the preset.""" raise NotImplementedError - def load_task(self, cls, load_weights, load_task_weights, **kwargs): - """Load a task model from the preset. - - By default, we create a task from a backbone and preprocessor with - default arguments. This means - """ - if "backbone" not in kwargs: - backbone_class = cls.backbone_cls - backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs) - kwargs["backbone"] = self.load_backbone( - backbone_class, load_weights, **backbone_kwargs - ) - if "preprocessor" not in kwargs and cls.preprocessor_cls: - kwargs["preprocessor"] = self.load_preprocessor( - cls.preprocessor_cls, - ) - - return cls(**kwargs) + def load_task(self, cls, load_weights=True, kwargs=None): + """Load a task model from the preset.""" + kwargs = kwargs or {} + cls = self.find_compatible_subclass(cls) + backbone_kwargs, kwargs = self.split_backbone_kwargs(kwargs) + return cls._from_defaults( + self, + load_weights=load_weights, + kwargs=kwargs, + backbone_kwargs=backbone_kwargs, + ) def load_preprocessor( - self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, kwargs=None ): - """Load a prepocessor layer from the preset. 
- - By default, we create a preprocessor from a tokenizer with default - arguments. This allow us to support transformers checkpoints by - only converting the backbone and tokenizer. - """ - kwargs = cls._add_missing_kwargs(self, kwargs) - return cls(**kwargs) + """Load a preprocessor layer from the preset.""" + kwargs = kwargs or {} + cls = self.find_compatible_subclass(cls) + return cls._from_defaults(self, kwargs=kwargs) class KerasPresetLoader(PresetLoader): def check_backbone_class(self): return check_config_class(self.config) - def load_backbone(self, cls, load_weights, **kwargs): - backbone = self._load_serialized_object(self.config, **kwargs) + def load_backbone(self, cls=None, load_weights=True, kwargs=None): + kwargs = kwargs or {} + backbone = self._load_serialized_object(cls, self.config, kwargs) if load_weights: jax_memory_cleanup(backbone) self._load_backbone_weights(backbone) return backbone - def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): + def load_tokenizer( + self, cls=None, config_file=TOKENIZER_CONFIG_FILE, kwargs=None + ): + kwargs = kwargs or {} tokenizer_config = load_json(self.preset, config_file) - tokenizer = self._load_serialized_object(tokenizer_config, **kwargs) + tokenizer = self._load_serialized_object(cls, tokenizer_config, kwargs) if hasattr(tokenizer, "load_preset_assets"): tokenizer.load_preset_assets(self.preset) return tokenizer - def load_audio_converter(self, cls, **kwargs): + def load_audio_converter(self, cls=None, kwargs=None): + kwargs = kwargs or {} converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE) - return self._load_serialized_object(converter_config, **kwargs) + return self._load_serialized_object(cls, converter_config, kwargs) - def load_image_converter(self, cls, **kwargs): + def load_image_converter(self, cls=None, kwargs=None): + kwargs = kwargs or {} converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE) - return self._load_serialized_object(converter_config, **kwargs) + return self._load_serialized_object(cls, converter_config, kwargs) - def load_task(self, cls, load_weights, load_task_weights, **kwargs): + def load_task(self, cls, load_weights=True, kwargs=None): + kwargs = kwargs or {} # If there is no `task.json` or it's for the wrong class delegate to the # super class loader. if not check_file_exists(self.preset, TASK_CONFIG_FILE): return super().load_task( - cls, load_weights, load_task_weights, **kwargs + cls=cls, load_weights=load_weights, kwargs=kwargs ) task_config = load_json(self.preset, TASK_CONFIG_FILE) - if not issubclass(check_config_class(task_config), cls): + if cls and not issubclass(check_config_class(task_config), cls): return super().load_task( - cls, load_weights, load_task_weights, **kwargs + cls=cls, load_weights=load_weights, kwargs=kwargs ) # We found a `task.json` with a complete config for our class. # Forward backbone args.
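`find_compatible_subclass` (now a loader method, defined above) resolves the registered subclass whose `backbone_cls` matches the saved backbone. A simplified standalone sketch of that resolution step, covering direct subclasses only; the real method also walks nested subclasses and breaks ties:

```python
# Simplified stand-in for the loader's subclass resolution; illustrative only.
class BackboneA: ...
class BackboneB: ...


class Tokenizer:
    backbone_cls = None


class TokenizerA(Tokenizer):
    backbone_cls = BackboneA


class TokenizerB(Tokenizer):
    backbone_cls = BackboneB


def find_compatible_subclass(cls, backbone_cls):
    if cls.backbone_cls is backbone_cls:
        return cls
    matches = [s for s in cls.__subclasses__() if s.backbone_cls is backbone_cls]
    if not matches:
        raise ValueError(
            f"Unable to find a subclass of {cls.__name__} compatible "
            f"with {backbone_cls.__name__}."
        )
    # Prefer the first registered match, mirroring the loader's behavior.
    return matches[0]


assert find_compatible_subclass(Tokenizer, BackboneB) is TokenizerB
assert find_compatible_subclass(TokenizerA, BackboneA) is TokenizerA
```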
- backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs) + backbone_kwargs, kwargs = self.split_backbone_kwargs(kwargs) if "backbone" in task_config["config"]: backbone_config = task_config["config"]["backbone"]["config"] backbone_config = {**backbone_config, **backbone_kwargs} task_config["config"]["backbone"]["config"] = backbone_config - task = self._load_serialized_object(task_config, **kwargs) + task = self._load_serialized_object(cls, task_config, kwargs) if task.preprocessor and hasattr( task.preprocessor, "load_preset_assets" ): task.preprocessor.load_preset_assets(self.preset) if load_weights: has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE) - if has_task_weights and load_task_weights: + # Skip head weights for classifiers if num_classes is provided. + if has_task_weights and "num_classes" not in kwargs: jax_memory_cleanup(task) task_weights = get_file(self.preset, TASK_WEIGHTS_FILE) task.load_task_weights(task_weights) @@ -706,22 +707,32 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return task def load_preprocessor( - self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, kwargs=None ): + kwargs = kwargs or {} # If there is no `preprocessing.json` or it's for the wrong class, # delegate to the super class loader. if not check_file_exists(self.preset, config_file): - return super().load_preprocessor(cls, **kwargs) - preprocessor_json = load_json(self.preset, config_file) - if not issubclass(check_config_class(preprocessor_json), cls): - return super().load_preprocessor(cls, **kwargs) + return super().load_preprocessor(cls=cls, kwargs=kwargs) + preprocessor_config = load_json(self.preset, config_file) + if cls and not issubclass(check_config_class(preprocessor_config), cls): + return super().load_preprocessor(cls=cls, kwargs=kwargs) # We found a `preprocessing.json` with a complete config for our class. - preprocessor = self._load_serialized_object(preprocessor_json, **kwargs) + preprocessor = self._load_serialized_object( + cls, preprocessor_config, kwargs + ) if hasattr(preprocessor, "load_preset_assets"): preprocessor.load_preset_assets(self.preset) return preprocessor - def _load_serialized_object(self, config, **kwargs): + def _load_serialized_object(self, cls, config, kwargs=None): + kwargs = kwargs or {} + config_cls = check_config_class(config) + if cls and not issubclass(config_cls, cls): + raise ValueError( + f"Unable to load config with saved class {config_cls.__name__} " + f"as an object of class {cls.__name__}" + ) # `dtype` in config might be a serialized `DTypePolicy` or # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
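With the class check folded into `_load_serialized_object`, asking for a config saved under an incompatible class now raises the explicit ValueError above. A hedged sketch, assuming `keras_hub` is installed and the `gpt2_base_en` and `bert_tiny_en_uncased` presets exist:

```python
# Sketch only: assumes keras_hub is installed and these presets are available.
import keras_hub

# The base class loads fine; the saved class is a subclass of it.
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset("gpt2_base_en")

# A concrete class that does not match the saved config now hits the
# "Unable to load config with saved class ..." ValueError.
try:
    keras_hub.tokenizers.GPT2Tokenizer.from_preset("bert_tiny_en_uncased")
except ValueError as err:
    print(err)
```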
dtype = kwargs.pop("dtype", None) diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 18b6180e68..9036ebdc5c 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -1,5 +1,6 @@ """Convert timm models to KerasHub.""" +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup @@ -37,8 +38,10 @@ def __init__(self, preset, config): def check_backbone_class(self): return self.converter.backbone_cls - def load_backbone(self, cls, load_weights, **kwargs): + def load_backbone(self, cls=None, load_weights=True, kwargs=None): + kwargs = kwargs or {} keras_config = self.converter.convert_backbone_config(self.config) + cls = self.check_backbone_class() backbone = cls(**{**keras_config, **kwargs}) if load_weights: jax_memory_cleanup(backbone) @@ -47,42 +50,42 @@ def load_backbone(self, cls, load_weights, **kwargs): self.converter.convert_weights(backbone, loader, self.config) return backbone - def load_task(self, cls, load_weights, load_task_weights, **kwargs): - if not load_task_weights or not issubclass(cls, ImageClassifier): - return super().load_task( - cls, load_weights, load_task_weights, **kwargs - ) + def load_task(self, cls, load_weights=True, kwargs=None): + kwargs = kwargs or {} + if not issubclass(cls, ImageClassifier) or "num_classes" in kwargs: + return super().load_task(cls, load_weights, kwargs) # Support loading the classification head for classifier models. - kwargs["num_classes"] = self.config["num_classes"] + if "num_classes" in self.config: + kwargs["num_classes"] = self.config["num_classes"] + # TODO: Move arch specific config to the converter. if ( - "num_features" in self.config - and "mobilenet" in self.config["architecture"] + self.config["architecture"].startswith("mobilenet") + and "num_features" not in kwargs + and "num_features" in self.config ): kwargs["num_features"] = self.config["num_features"] - - task = super().load_task(cls, load_weights, load_task_weights, **kwargs) - if load_task_weights: + task = super().load_task(cls, load_weights, kwargs) + if load_weights: with SafetensorLoader(self.preset, prefix="") as loader: self.converter.convert_head(task, loader, self.config) return task - def load_image_converter(self, cls, **kwargs): + def load_image_converter(self, cls=None, kwargs=None): + kwargs = kwargs or {} + cls = self.find_compatible_subclass(cls or ImageConverter) pretrained_cfg = self.config.get("pretrained_cfg", None) if not pretrained_cfg or "input_size" not in pretrained_cfg: return None # This assumes the same basic setup for all timm preprocessing, We may # need to extend this as we cover more model types. - input_size = pretrained_cfg["input_size"] + defaults = {} + defaults["image_size"] = pretrained_cfg["input_size"][1:] mean = pretrained_cfg["mean"] std = pretrained_cfg["std"] - scale = [1.0 / 255.0 / s for s in std] - offset = [-m / s for m, s in zip(mean, std)] + defaults["scale"] = [1.0 / 255.0 / s for s in std] + defaults["offset"] = [-m / s for m, s in zip(mean, std)] interpolation = pretrained_cfg["interpolation"] if interpolation not in ("bilinear", "nearest", "bicubic"): interpolation = "bilinear" # Unsupported interpolation type. 
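A hedged usage sketch of the timm path above, assuming `keras_hub` and `huggingface_hub` are installed and the timm handle below is reachable (the handle is an assumption, not taken from this diff):

```python
# Sketch only: assumes keras_hub and huggingface_hub are installed and the
# timm handle below is reachable.
import keras_hub

# Head weights are converted from the timm checkpoint by default.
classifier = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/resnet18.a1_in1k"
)

# Passing num_classes builds a fresh head, so the converted head weights are
# skipped (the `"num_classes" in kwargs` branch above).
finetune = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/resnet18.a1_in1k", num_classes=10
)
```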
- return cls( - image_size=input_size[1:], - scale=scale, - offset=offset, - interpolation=interpolation, - ) + defaults["interpolation"] = interpolation + return cls(**{**defaults, **kwargs}) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index fe49a9b269..a182e544e1 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -1,6 +1,7 @@ """Convert huggingface models to KerasHub.""" from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.transformers import convert_albert @@ -65,7 +66,9 @@ def __init__(self, preset, config): def check_backbone_class(self): return self.converter.backbone_cls - def load_backbone(self, cls, load_weights, **kwargs): + def load_backbone(self, cls=None, load_weights=True, kwargs=None): + cls = self.check_backbone_class() + kwargs = kwargs or {} keras_config = self.converter.convert_backbone_config(self.config) backbone = cls(**{**keras_config, **kwargs}) if load_weights: @@ -74,28 +77,23 @@ def load_backbone(self, cls, load_weights, **kwargs): self.converter.convert_weights(backbone, loader, self.config) return backbone - def load_task(self, cls, load_weights, load_task_weights, **kwargs): - architecture = self.config["architectures"][0] - if ( - not load_task_weights - or not issubclass(cls, ImageClassifier) - or architecture == "ViTModel" - ): - return super().load_task( - cls, load_weights, load_task_weights, **kwargs - ) + def load_task(self, cls, load_weights=True, kwargs=None): + kwargs = kwargs or {} + if not issubclass(cls, ImageClassifier) or "num_classes" in kwargs: + return super().load_task(cls, load_weights, kwargs) # Support loading the classification head for classifier models. - if "ForImageClassification" in architecture: - kwargs["num_classes"] = len(self.config["id2label"]) - task = super().load_task(cls, load_weights, load_task_weights, **kwargs) - if load_task_weights: - with SafetensorLoader(self.preset, prefix="") as loader: + kwargs["num_classes"] = len(self.config["id2label"]) + task = super().load_task(cls, load_weights, kwargs) + if load_weights: + with SafetensorLoader(self.preset) as loader: self.converter.convert_head(task, loader, self.config) return task - def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs): + def load_tokenizer(self, cls=None, config_file=None, kwargs=None): + kwargs = kwargs or {} + cls = self.find_compatible_subclass(cls or Tokenizer) return self.converter.convert_tokenizer(cls, self.preset, **kwargs) - def load_image_converter(self, cls, **kwargs): - # TODO: set image size for pali gemma checkpoints. + def load_image_converter(self, cls=None, **kwargs): + # TODO. return None
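The transformers loader now resolves the backbone and tokenizer classes itself, so raw Hugging Face checkpoints can be loaded through the base classes. A hedged sketch, assuming `keras_hub` and `huggingface_hub` are installed and the checkpoint below has a supported converter:

```python
# Sketch only: assumes keras_hub and huggingface_hub are installed and the
# checkpoint below is supported by a transformers converter.
import keras_hub

backbone = keras_hub.models.Backbone.from_preset(
    "hf://google-bert/bert-base-uncased"
)
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset(
    "hf://google-bert/bert-base-uncased"
)

# Image converters are not converted from transformers checkpoints yet; the
# loader above simply returns None for them.
```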