diff --git a/keyphrasetransformer/keyphrasetransformer.py b/keyphrasetransformer/keyphrasetransformer.py index 6b48082..6b41985 100644 --- a/keyphrasetransformer/keyphrasetransformer.py +++ b/keyphrasetransformer/keyphrasetransformer.py @@ -1,7 +1,7 @@ -# import import os import sys import nltk +import torch from nltk.corpus import words from nltk.tokenize import word_tokenize, sent_tokenize from transformers import AutoTokenizer, T5ForConditionalGeneration, MT5ForConditionalGeneration @@ -12,12 +12,14 @@ class KeyPhraseTransformer: def __init__(self, model_type: str = "t5", model_name: str = "snrspeaks/KeyPhraseTransformer"): self.model_name = model_name + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if model_type == "t5": - self.model = T5ForConditionalGeneration.from_pretrained(self.model_name) + self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) if model_type == "mt5": - self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) def split_into_paragraphs(self, doc: str, max_tokens_per_para: int = 128): sentences = sent_tokenize(doc.strip()) @@ -72,7 +74,7 @@ def filter_outputs(self, key_phrases, text): def predict(self, doc: str): input_ids = self.tokenizer.encode( doc, return_tensors="pt", add_special_tokens=True - ) + ).to(self.device) generated_ids = self.model.generate( input_ids=input_ids, num_beams=2,