12 changes: 7 additions & 5 deletions keyphrasetransformer/keyphrasetransformer.py
@@ -1,7 +1,7 @@
# import
import os
import sys
import nltk
import torch
from nltk.corpus import words
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import AutoTokenizer, T5ForConditionalGeneration, MT5ForConditionalGeneration
@@ -12,12 +12,14 @@
class KeyPhraseTransformer:
def __init__(self, model_type: str = "t5", model_name: str = "snrspeaks/KeyPhraseTransformer"):
self.model_name = model_name
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if model_type == "t5":
self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if model_type == "mt5":
self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

def split_into_paragraphs(self, doc: str, max_tokens_per_para: int = 128):
sentences = sent_tokenize(doc.strip())
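
The core of this hunk is the device-placement pattern in `__init__`: detect CUDA once, store the resulting `torch.device`, and move the loaded model onto it. A minimal standalone sketch of the same pattern follows; it is illustrative only and not code from this PR, though it reuses the checkpoint name the class defaults to.

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Pick the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model weights are moved to `device`; the tokenizer is plain Python and stays on CPU.
model_name = "snrspeaks/KeyPhraseTransformer"  # same default checkpoint as the class
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```

Chaining `.to(device)` directly onto `from_pretrained(...)` is idiomatic because `nn.Module.to` returns the module itself, which is exactly what the updated constructor does for both the T5 and MT5 branches.
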
@@ -72,7 +74,7 @@ def filter_outputs(self, key_phrases, text):
def predict(self, doc: str):
input_ids = self.tokenizer.encode(
doc, return_tensors="pt", add_special_tokens=True
)
).to(self.device)
generated_ids = self.model.generate(
input_ids=input_ids,
num_beams=2,
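
The `predict` change follows from the same requirement: tensors passed to `generate` must live on the same device as the model's weights, so the encoded ids are moved with `.to(self.device)` before generation. A hedged sketch of that pattern, continuing the illustrative names from the snippet above; the hunk is truncated after `num_beams=2`, so the remaining generation arguments and the decoding step are assumptions, not part of this diff.

```python
# Encode the document and move the ids to the model's device before calling generate().
doc = "Transfer learning with transformer models has improved keyphrase extraction."
input_ids = tokenizer.encode(
    doc, return_tensors="pt", add_special_tokens=True
).to(device)

# num_beams=2 matches the diff; max_length is an illustrative assumption.
generated_ids = model.generate(input_ids=input_ids, num_beams=2, max_length=64)

# Decoding back to text is assumed here; the PR hunk ends before predict()'s return.
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```

Without the `.to(self.device)` call on the inputs, running the model on GPU would raise a device-mismatch error at generation time, which is why both the constructor and `predict` are touched in the same PR.
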