Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions camel_tools/disambig/bert/unfactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def labels(self):
Returns:
:obj:`list` of :obj:`str`: List of Morph labels.
"""

return list(self._labels_map.values())
num_labels = len(self._labels_map)
labels = [self._labels_map[i] for i in range(num_labels)]
return labels

def _align_predictions(self, predictions, label_ids, sent_ids):
"""Aligns the predictions of the model with the inputs and it takes
Expand Down Expand Up @@ -147,14 +148,17 @@ def predict(self, sentences, batch_size=32, max_seq_length=512):

test_dataset = MorphDataset(sentences=sorted_sentences_text,
tokenizer=self._tokenizer,
labels=list(self._labels_map.values()),
labels=self.labels(),
max_seq_length=max_seq_length)

data_loader = DataLoader(test_dataset, batch_size=batch_size,
shuffle=False, drop_last=False,
collate_fn=self._collate_fn)

predictions = []
label_ids = None
preds = None
sent_ids = None

device = ('cuda' if self._use_gpu and torch.cuda.is_available()
else 'cpu')
self._model.to(device)
Expand All @@ -167,20 +171,25 @@ def predict(self, sentences, batch_size=32, max_seq_length=512):
'token_type_ids': batch['token_type_ids'],
'attention_mask': batch['attention_mask']}

label_ids = batch['label_ids']
sent_ids = batch['sent_id']
label_ids = (batch['label_ids'] if label_ids is None
else torch.cat((label_ids, batch['label_ids'])))
sent_ids = (batch['sent_id'] if sent_ids is None
else torch.cat((sent_ids, batch['sent_id'])))

logits = self._model(**inputs)[0]
preds = logits
prediction = self._align_predictions(preds.cpu().numpy(),
label_ids.cpu().numpy(),
sent_ids.cpu().numpy())
predictions.extend(prediction)

preds = logits if preds is None else torch.cat((preds, logits),
dim=0)
predictions = self._align_predictions(preds.cpu().numpy(),
label_ids.cpu().numpy(),
sent_ids.cpu().numpy())

sorted_predictions_pair = zip(sorted_sentences_idx, predictions)
sorted_predictions = sorted(sorted_predictions_pair,
key=lambda x: x[0])
predictions = [i[1] for i in sorted_predictions]

return [i[1] for i in sorted_predictions]
return predictions

def _collate_fn(self, batch):
input_ids = []
Expand Down Expand Up @@ -412,7 +421,8 @@ def _predict_sentences(self, sentences):
morphosyntactic labels for the given sentences.
"""

preds = self._model['unfactored'].predict(sentences, self._batch_size)
preds = self._model['unfactored'].predict(sentences,
batch_size=self._batch_size)
parsed_predictions = []

for sent, pred in zip(sentences, preds):
Expand Down Expand Up @@ -446,7 +456,7 @@ def _predict_sentence(self, sentence):

parsed_predictions = []
model = self._model['unfactored']
preds = model.predict([sentence], self._batch_size)[0]
preds = model.predict([sentence], batch_size=self._batch_size)[0]

for word, pred in zip(sentence, preds):
d = {}
Expand Down