Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions tableqa/clauses.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,15 @@

from tensorflow.keras.models import load_model
from sentence_transformers import SentenceTransformer
from numpy import asarray


class Clause:
def __init__(self):
self.bert_model = SentenceTransformer('bert-base-nli-mean-tokens')
self.model=load_model("Question_Classifier.h5")
self.types={0:'SELECT {} FROM {}', 1:'SELECT MAX({}) FROM {}', 2:'SELECT MIN({}) FROM {}', 3:'SELECT COUNT({}) FROM {}', 4:'SELECT SUM({}) FROM {}', 5:'SELECT AVG({}) FROM {}'}

def adapt(self,q,inttype=False,priority=False):
emb=asarray(self.bert_model.encode(q))
self.clause=self.types[self.model.predict_classes(emb)[0]]
#from nlp import qa
# class Clause:
# def __init__(self):

if priority and inttype and "COUNT" in self.clause:
self.clause= '''SELECT SUM({}) FROM {}'''
return self.clause







# self.base_q="what is {} here"
# self.types={"the entity":'SELECT {} FROM {}', "the maximum":'SELECT MAX({}) FROM {}', "the minimum":'SELECT MIN({}) FROM {}', "counted":'SELECT COUNT({}) FROM {}', "summed":'SELECT SUM({}) FROM {}', "averaged":'SELECT AVG({}) FROM {}'}

# def adapt(self,q,inttype=False,priority=False):
# scores={}
# for k,v in self.types.items():
# scores[k]=qa(q,self.base_q.format(k),return_score=True)[1]
# return self.types[max(scores, key=scores.get)]


20 changes: 17 additions & 3 deletions tableqa/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
from data_utils import data_utils
import os
from transformers import TFBertForQuestionAnswering, BertTokenizer
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf
from rake_nltk import Rake
import column_types
import json
from clauses import Clause
from conditionmaps import conditions


qa_model = TFBertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
qa_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')



import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
Expand Down Expand Up @@ -159,6 +158,22 @@ def _find(lst, sublst):

def _window_overlap(s1, e1, s2, e2):
return s2 <= e1 if s1 <s2 else s1 <= e2


class Clause:
def __init__(self):

self.base_q="what is {} here"
self.types={"the entity":'SELECT {} FROM {}', "the maximum":'SELECT MAX({}) FROM {}', "the minimum":'SELECT MIN({}) FROM {}', "counted":'SELECT COUNT({}) FROM {}', "summed":'SELECT SUM({}) FROM {}', "averaged":'SELECT AVG({}) FROM {}'}

def adapt(self,q,inttype=False,priority=False):
scores={}
for k,v in self.types.items():
scores[k]=qa(q,self.base_q.format(k),return_score=True)[1]
return self.types[max(scores, key=scores.get)]



class Nlp:
def __init__(self,data_dir,schema_dir):
self.data_dir=data_dir
Expand Down Expand Up @@ -262,7 +277,6 @@ def _is_numeric(typ):

return ret


def cond_map(self,s):
data=conditions

Expand Down
Loading