Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit c99061d

Browse files
authored
Add list_models API (#1181)
1 parent 6fb43ec commit c99061d

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

docs/api/model.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ The `get_model` function returns a pre-defined model given the name of a
2020
registered model. The following sections of this page present a list of
2121
registered names for each model category.
2222

23+
Information about pretrained models
24+
-----------------------------------
25+
26+
.. autosummary::
27+
:nosignatures:
28+
29+
list_models
30+
2331
Language Modeling
2432
-----------------
2533

@@ -138,6 +146,7 @@ Sequence Sampling
138146
BeamSearchSampler
139147
SequenceSampler
140148

149+
141150
Other Modeling Utilities
142151
------------------------
143152

src/gluonnlp/model/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
convolutional_encoder, elmo, highway, language_model,
7272
lstmpcellwithclip, parameter, sampled_block,
7373
seq2seq_encoder_decoder, sequence_sampler, train, transformer,
74-
utils)
74+
utils, info)
7575
from .attention_cell import *
7676
from .bert import *
7777
from .bilm_encoder import BiLMEncoder
@@ -88,13 +88,14 @@
8888
from .transformer import *
8989
from .translation import *
9090
from .utils import *
91+
from .info import *
9192
from ..base import get_home_dir
9293

9394
__all__ = (language_model.__all__ + sequence_sampler.__all__ + attention_cell.__all__ +
9495
utils.__all__ + parameter.__all__ + block.__all__ + highway.__all__ +
9596
convolutional_encoder.__all__ + sampled_block.__all__ + bilm_encoder.__all__ +
9697
lstmpcellwithclip.__all__ + elmo.__all__ + seq2seq_encoder_decoder.__all__ +
97-
transformer.__all__ + bert.__all__ + ['train', 'get_model'])
98+
transformer.__all__ + bert.__all__ + info.__all__ + ['train', 'get_model'])
9899

99100

100101
def get_model(name, **kwargs):

src/gluonnlp/model/info.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""API to get list of pretrained models"""
18+
__all__ = ['list_models']
19+
20+
from . import (bert, bilm_encoder, elmo, language_model,
21+
transformer)
22+
23+
24+
def list_models():
25+
"""Returns the list of pretrained models
26+
"""
27+
models = (bert.__all__ + bilm_encoder.__all__ + elmo.__all__ +
28+
language_model.__all__ + transformer.__all__)
29+
30+
return models

tests/unittest/test_info.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import os
19+
import sys
20+
import warnings
21+
22+
import gluonnlp as nlp
23+
24+
def test_get_models():
25+
models = nlp.model.list_models()
26+
assert len(models)!=0

0 commit comments

Comments
 (0)