Skip to content

Commit e683079

Browse files
authored
Add custom model loading (#2)
* Add custom model loading
* Improve edge case condition checks
* Use __all__ for imports
1 parent 781dcb5 commit e683079

File tree

3 files changed

+43
-20
lines changed

3 files changed

+43
-20
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
setuptools.setup(
2121
name="spacy-udpipe",
22-
version="0.0.1",
22+
version="0.0.2",
2323
description="Use fast UDPipe models directly in spaCy",
2424
long_description=long_description,
2525
long_description_content_type="text/markdown",

spacy_udpipe/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
1-
from .language import UDPipeLanguage, UDPipeModel, load
1+
from .language import UDPipeLanguage, UDPipeModel, load, load_from_path
22
from .util import download
3+
4+
__all__ = ["UDPipeLanguage", "UDPipeModel",
5+
"load", "load_from_path", "download"]

spacy_udpipe/language.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,27 @@ def load(lang):
1818
mimics spacy.load.
1919
2020
lang (unicode): ISO 639-1 language code or shorthand UDPipe model name.
21-
RETURNS (spacy.language.Language): The UDPipeLanguage object.
21+
RETURNS (spacy.language.Language): The UDPipeLanguage object.
2222
"""
2323
model = UDPipeModel(lang)
2424
nlp = UDPipeLanguage(model)
2525
return nlp
2626

2727

28+
def load_from_path(lang, path, meta=None):
29+
"""Convenience function for initializing the Language class and loading
30+
a custom UDPipe model via the path argument.
31+
32+
lang (unicode): ISO 639-1 language code.
33+
path (unicode): Path to the UDPipe model.
34+
meta (dict): Meta-information about the UDPipe model.
35+
RETURNS (spacy.language.Language): The UDPipeLanguage object.
36+
"""
37+
model = UDPipeModel(lang, path, meta)
38+
nlp = UDPipeLanguage(model)
39+
return nlp
40+
41+
2842
class UDPipeLanguage(Language):
2943

3044
def __init__(self, udpipe_model, meta=None, **kwargs):
@@ -93,7 +107,7 @@ def __call__(self, text):
93107
udpipe_sents = self.model(text) if text else [Sentence()]
94108
text = " ".join(s.getText() for s in udpipe_sents)
95109
tokens, heads = self.get_tokens_with_heads(udpipe_sents)
96-
if not len(tokens):
110+
if not tokens:
97111
return Doc(self.vocab)
98112

99113
words = []
@@ -186,32 +200,38 @@ def check_aligned(self, text, tokens):
186200

187201
class UDPipeModel:
188202

189-
def __init__(self, lang):
203+
def __init__(self, lang, path=None, meta=None):
190204
"""Load UDPipe model for given language.
191205
192206
lang (unicode): ISO 639-1 language code or shorthand UDPipe model name.
207+
path (unicode): Path to UDPipe model.
208+
meta (dict): Meta-information about the UDPipe model.
193209
RETURNS (UDPipeModel): Language specific UDPipeModel.
194210
"""
195-
path = get_path(lang)
211+
if path is None:
212+
path = get_path(lang)
196213
self.model = Model.load(path)
197-
if not self.model:
214+
if self.model is None:
198215
msg = "Cannot load UDPipe model from " \
199216
"file '{}'".format(path)
200217
raise Exception(msg)
201218
self._lang = lang.split('-')[0]
202-
self._meta = {'authors': ("Milan Straka, "
203-
"Jana Straková"),
204-
'description': "UDPipe pretrained model.",
205-
'email': '[email protected]',
206-
'lang': 'udpipe_' + self._lang,
207-
'license': 'CC BY-NC-SA 4.0',
208-
'name': path.split('/')[-1],
209-
'parent_package': 'spacy_udpipe',
210-
'pipeline': 'Tokenizer, POS Tagger, Lemmatizer, Parser',
211-
'source': 'Universal Dependencies 2.4',
212-
'url': 'http://ufal.mff.cuni.cz/udpipe',
213-
'version': '1.2.0'
214-
}
219+
if meta is None:
220+
self._meta = {'authors': ("Milan Straka, "
221+
"Jana Straková"),
222+
'description': "UDPipe pretrained model.",
223+
'email': '[email protected]',
224+
'lang': 'udpipe_' + self._lang,
225+
'license': 'CC BY-NC-SA 4.0',
226+
'name': path.split('/')[-1],
227+
'parent_package': 'spacy_udpipe',
228+
'pipeline': 'Tokenizer, POS Tagger, Lemmatizer, Parser',
229+
'source': 'Universal Dependencies 2.4',
230+
'url': 'http://ufal.mff.cuni.cz/udpipe',
231+
'version': '1.2.0'
232+
}
233+
else:
234+
self._meta = meta
215235

216236
def __call__(self, text):
217237
"""Tokenize, tag and parse the text and return it in an UDPipe

0 commit comments

Comments
 (0)