Skip to content

Commit 9832d8a

Browse files
authored
Expose Task and Backbone (#1506)
These are already exposed on KerasCV, and I think it is time to also expose these in KerasNLP. This will give us a class to document common model functionality to all backbones such as `enable_lora` and `token_embedding` on keras.io. It can also open up a path for writing a custom architecture outside the library itself.
1 parent 26a2fb8 commit 9832d8a

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

keras_nlp/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
2222
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
23+
from keras_nlp.models.backbone import Backbone
2324
from keras_nlp.models.bart.bart_backbone import BartBackbone
2425
from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
2526
from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM
@@ -130,6 +131,7 @@
130131
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
131132
from keras_nlp.models.t5.t5_backbone import T5Backbone
132133
from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer
134+
from keras_nlp.models.task import Task
133135
from keras_nlp.models.whisper.whisper_audio_feature_extractor import (
134136
WhisperAudioFeatureExtractor,
135137
)

keras_nlp/models/backbone.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from keras_nlp.api_export import keras_nlp_export
1516
from keras_nlp.backend import config
1617
from keras_nlp.backend import keras
1718
from keras_nlp.utils.preset_utils import check_preset_class
@@ -20,7 +21,7 @@
2021
from keras_nlp.utils.python_utils import format_docstring
2122

2223

23-
@keras.saving.register_keras_serializable(package="keras_nlp")
24+
@keras_nlp_export("keras_nlp.models.Backbone")
2425
class Backbone(keras.Model):
2526
def __init__(self, *args, dtype=None, **kwargs):
2627
super().__init__(*args, **kwargs)

keras_nlp/models/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from rich import markup
1717
from rich import table as rich_table
1818

19+
from keras_nlp.api_export import keras_nlp_export
1920
from keras_nlp.backend import config
2021
from keras_nlp.backend import keras
2122
from keras_nlp.utils.keras_utils import print_msg
@@ -26,7 +27,7 @@
2627
from keras_nlp.utils.python_utils import format_docstring
2728

2829

29-
@keras.saving.register_keras_serializable(package="keras_nlp")
30+
@keras_nlp_export("keras_nlp.models.Task")
3031
class Task(PipelineModel):
3132
"""Base class for Task models."""
3233

0 commit comments

Comments
 (0)