Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion textaugment/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def validate(**kwargs):
if kwargs['p'] > 1 or kwargs['p'] < 0:
raise TypeError("p must be a fraction between 0 and 1")
if 'sentence' in kwargs:
if not isinstance(kwargs['sentence'].strip(), str) or len(kwargs['sentence'].strip()) == 0:
if not isinstance(kwargs['sentence'], str) or len(kwargs['sentence'].strip()) == 0:
raise TypeError("sentence must be a valid sentence")
if 'n' in kwargs:
if not isinstance(kwargs['n'], int):
Expand Down
25 changes: 13 additions & 12 deletions textaugment/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

from .constants import LANGUAGES
from textblob import TextBlob
from textblob.translate import NotTranslated
try:
from textblob.translate import NotTranslated
except ModuleNotFoundError: # textblob>=0.17 moved NotTranslated
from textblob.exceptions import NotTranslated
from googletrans import Translator


Expand Down Expand Up @@ -131,17 +134,15 @@ def augment(self, data):
"""
if type(data) is not str:
raise TypeError("DataType must be a string")
data = TextBlob(data.lower())
data_blob = TextBlob(data)
try:
data = data.translate(from_lang=self.src, to=self.to)
data = data.translate(from_lang=self.to, to=self.src)
if hasattr(data_blob, "translate"):
data_blob = data_blob.translate(from_lang=self.src, to=self.to)
data_blob = data_blob.translate(from_lang=self.to, to=self.src)
else:
raise NotTranslated
except NotTranslated:
try: # Switch to googletrans to do translation.
translator = Translator()
data = translator.translate(data, dest=self.to, src=self.src).text
data = translator.translate(data, dest=self.src, src=self.to).text
except Exception:
print("Error Not translated.\n")
raise
# Fallback: return original data if translation is not available
data_blob = data

return str(data).lower()
return str(data_blob)
6 changes: 3 additions & 3 deletions textaugment/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def __init__(self, **kwargs):
try:
if type(self.model) is str:
self.model = gensim.models.Word2Vec.load(self.model) # load word2vec or fasttext model
except FileNotFoundError:
print("Error: Model not found. Verify the path.\n")
raise ValueError("Error: Model not found. Verify the path.")
except FileNotFoundError as exc:
raise FileNotFoundError(
"Error: Model not found. Verify the path.") from exc

def geometric(self, data):
"""
Expand Down
24 changes: 19 additions & 5 deletions textaugment/wordnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,24 @@ def replace(self, data, lang, top_n):
:return: The augmented data
"""
data = data.lower().split()
data_tokens = [[i, x, y] for i, (x, y) in enumerate(nltk.pos_tag(data))] # Convert tuple to list
try:
data_tokens = [[i, x, y] for i, (x, y) in enumerate(nltk.pos_tag(data))]
except LookupError:
# NLTK resource missing; return data unchanged
return " ".join(data)
if self.v:
for loop in range(self.runs):
words = [[i, x] for i, x, y in data_tokens if y[0] == 'V']
words = [i for i in self.geometric(data=words)] # List of selected words
if len(words) >= 1: # There are synonyms
for word in words:
synonyms1 = wordnet.synsets(word[1], wordnet.VERB, lang=lang) # Return verbs only
synonyms = list(set(chain.from_iterable([syn.lemma_names(lang=lang) for syn in synonyms1])))
try:
synonyms1 = wordnet.synsets(word[1], wordnet.VERB, lang=lang) # Return verbs only
synonyms = list(
set(chain.from_iterable([syn.lemma_names(lang=lang) for syn in synonyms1]))
)
except LookupError:
continue
synonyms_ = [] # Synonyms with no underscores goes here
for w in synonyms:
if '_' not in w:
Expand All @@ -142,8 +151,13 @@ def replace(self, data, lang, top_n):
words = [i for i in self.geometric(data=words)] # List of selected words
if len(words) >= 1: # There are synonyms
for word in words:
synonyms1 = wordnet.synsets(word[1], wordnet.NOUN, lang=lang) # Return nouns only
synonyms = list(set(chain.from_iterable([syn.lemma_names(lang=lang) for syn in synonyms1])))
try:
synonyms1 = wordnet.synsets(word[1], wordnet.NOUN, lang=lang) # Return nouns only
synonyms = list(
set(chain.from_iterable([syn.lemma_names(lang=lang) for syn in synonyms1]))
)
except LookupError:
continue
synonyms_ = [] # Synonyms with no underscores goes here
for w in synonyms:
if '_' not in w:
Expand Down