diff --git a/camel_tools/disambig/bert/unfactored.py b/camel_tools/disambig/bert/unfactored.py index 3ab0b09..e8fe866 100644 --- a/camel_tools/disambig/bert/unfactored.py +++ b/camel_tools/disambig/bert/unfactored.py @@ -87,8 +87,9 @@ def labels(self): Returns: :obj:`list` of :obj:`str`: List of Morph labels. """ - - return list(self._labels_map.values()) + num_labels = len(self._labels_map) + labels = [self._labels_map[i] for i in range(num_labels)] + return labels def _align_predictions(self, predictions, label_ids, sent_ids): """Aligns the predictions of the model with the inputs and it takes @@ -147,14 +148,17 @@ def predict(self, sentences, batch_size=32, max_seq_length=512): test_dataset = MorphDataset(sentences=sorted_sentences_text, tokenizer=self._tokenizer, - labels=list(self._labels_map.values()), + labels=self.labels(), max_seq_length=max_seq_length) data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=self._collate_fn) - predictions = [] + label_ids = None + preds = None + sent_ids = None + device = ('cuda' if self._use_gpu and torch.cuda.is_available() else 'cpu') self._model.to(device) @@ -167,20 +171,25 @@ def predict(self, sentences, batch_size=32, max_seq_length=512): 'token_type_ids': batch['token_type_ids'], 'attention_mask': batch['attention_mask']} - label_ids = batch['label_ids'] - sent_ids = batch['sent_id'] + label_ids = (batch['label_ids'] if label_ids is None + else torch.cat((label_ids, batch['label_ids']))) + sent_ids = (batch['sent_id'] if sent_ids is None + else torch.cat((sent_ids, batch['sent_id']))) + logits = self._model(**inputs)[0] - preds = logits - prediction = self._align_predictions(preds.cpu().numpy(), - label_ids.cpu().numpy(), - sent_ids.cpu().numpy()) - predictions.extend(prediction) + + preds = logits if preds is None else torch.cat((preds, logits), + dim=0) + predictions = self._align_predictions(preds.cpu().numpy(), + label_ids.cpu().numpy(), + sent_ids.cpu().numpy()) sorted_predictions_pair = zip(sorted_sentences_idx, predictions) sorted_predictions = sorted(sorted_predictions_pair, key=lambda x: x[0]) + predictions = [i[1] for i in sorted_predictions] - return [i[1] for i in sorted_predictions] + return predictions def _collate_fn(self, batch): input_ids = [] @@ -412,7 +421,8 @@ def _predict_sentences(self, sentences): morphosyntactic labels for the given sentences. """ - preds = self._model['unfactored'].predict(sentences, self._batch_size) + preds = self._model['unfactored'].predict(sentences, + batch_size=self._batch_size) parsed_predictions = [] for sent, pred in zip(sentences, preds): @@ -446,7 +456,7 @@ def _predict_sentence(self, sentence): parsed_predictions = [] model = self._model['unfactored'] - preds = model.predict([sentence], self._batch_size)[0] + preds = model.predict([sentence], batch_size=self._batch_size)[0] for word, pred in zip(sentence, preds): d = {}