Create app.py
app.py
ADDED
@@ -0,0 +1,175 @@
import gradio as gr
import numpy as np
import time
import hashlib
import os
import torch
import textract
import pandas as pd
from datetime import datetime
from scipy.special import softmax
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The app caches embeddings, extracted text, and results on disk; create the
# folders up front so the np.save / open calls below do not fail on a fresh Space.
os.makedirs("encoded_gradio", exist_ok=True)
os.makedirs("text_gradio", exist_ok=True)
os.makedirs("HISTORY", exist_ok=True)
# Retriever: dense sentence-embedding model used to score document windows.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1").to(device).eval()

# Reader: extractive QA model that pulls an answer span out of a passage.
tokenizer_ans = AutoTokenizer.from_pretrained("deepset/roberta-large-squad2")
model_ans = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-large-squad2").to(device).eval()

if device == "cuda:0":
    pipe = pipeline("question-answering", model=model_ans, tokenizer=tokenizer_ans, device=0)
else:
    pipe = pipeline("question-answering", model=model_ans, tokenizer=tokenizer_ans)
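
# Retrieve-then-read flow implemented below: encode_docs() splits an uploaded
# file into overlapping word windows and embeds each one with the retriever;
# predict() dot-product-scores the query against every window and hands the
# top-k windows to the extractive QA pipeline above.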

def cls_pooling(model_output):
    # Represent each window by its [CLS] token embedding.
    return model_output.last_hidden_state[:, 0]

def encode_query(query):
    encoded_input = tokenizer(query, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)

    embeddings = cls_pooling(model_output)

    return embeddings.cpu()

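# Usage sketch for encode_query (illustrative only; the query string is made up):
#
#   q_emb = encode_query("What is open-domain question answering?")
#   q_emb.shape  # torch.Size([1, 768]) for this mpnet-base retriever
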
def encode_docs(docs, maxlen=64, stride=32):
    # Split a document into overlapping word windows and embed each window.
    encoded_input = []
    embeddings = []
    spans = []
    file_names = []
    name, text = docs

    text = text.split(" ")
    if len(text) < maxlen:
        # Bug fix: the original assigned the joined string back to `text`,
        # leaving `temp_text` undefined on this branch.
        temp_text = " ".join(text)
        encoded_input.append(tokenizer(temp_text, return_tensors="pt", truncation=True).to(device))
        spans.append(temp_text)
        file_names.append(name)
    else:
        num_iters = int(len(text) / maxlen) + 1
        for i in range(num_iters):
            if i == 0:
                temp_text = " ".join(text[i * maxlen:(i + 1) * maxlen + stride])
            else:
                # Prepend the last `stride` words of the previous window so
                # consecutive spans overlap.
                temp_text = " ".join(text[(i - 1) * maxlen:i * maxlen][-stride:] + text[i * maxlen:(i + 1) * maxlen])

            encoded_input.append(tokenizer(temp_text, return_tensors="pt", truncation=True).to(device))
            spans.append(temp_text)
            file_names.append(name)

    with torch.no_grad():
        for encoded in tqdm(encoded_input):
            model_output = model(**encoded, return_dict=True)
            embeddings.append(cls_pooling(model_output))

    embeddings = np.float32(torch.stack(embeddings).transpose(0, 1).cpu())

    # Cache embeddings, spans, and source names so re-uploading the same file
    # skips re-encoding.
    np.save("encoded_gradio/emb_{}.npy".format(name), dict(zip(range(len(embeddings)), embeddings)))
    np.save("encoded_gradio/spans_{}.npy".format(name), dict(zip(range(len(spans)), spans)))
    np.save("encoded_gradio/file_{}.npy".format(name), dict(zip(range(len(file_names)), file_names)))

    return embeddings, spans, file_names

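# Worked example of the chunking above (assumed input, not executed): with
# maxlen=64 and stride=32, a 100-word document yields num_iters = 2 windows,
# words 0-95 (the first window takes maxlen + stride words) and words 32-99,
# so adjacent windows overlap by 32 words of context.
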
def predict(query, data):
    # Gradio uploads arrive as temp files; recover a stable document name.
    # os.path.basename replaces the original Windows-only split on "\\";
    # [:-8] strips the random suffix Gradio appends to the temp file name
    # (kept from the original code).
    name_to_save = os.path.basename(data.name).split(".")[0][:-8]
    st = str([query, name_to_save])
    # Use the sha256 digest as the cache key: Python's built-in hash() is
    # randomised per process, so the original hash(st) could never hit the
    # cache across restarts.
    cache_key = hashlib.sha256(st.encode()).hexdigest()
    hist = st + " " + cache_key
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")

    # Return the cached result if this (query, document) pair was seen before.
    try:
        df = pd.read_csv("HISTORY/{}.csv".format(cache_key))
        return df
    except Exception as e:
        print(e)
        print(st)

    if name_to_save + ".txt" in os.listdir("text_gradio"):
        # Document already encoded: load cached embeddings, spans, and names.
        doc_emb = np.load("encoded_gradio/emb_{}.npy".format(name_to_save), allow_pickle=True).item()
        doc_text = np.load("encoded_gradio/spans_{}.npy".format(name_to_save), allow_pickle=True).item()
        file_names_dicto = np.load("encoded_gradio/file_{}.npy".format(name_to_save), allow_pickle=True).item()

        doc_emb = np.array(list(doc_emb.values())).reshape(-1, 768)
        doc_text = list(doc_text.values())
        file_names = list(file_names_dicto.values())
    else:
        # First sighting: extract raw text, encode it, and cache the text.
        text = textract.process(data.name).decode("utf8")
        text = text.replace("\r", " ")
        text = text.replace("\n", " ")
        text = text.replace(" . ", " ")

        doc_emb, doc_text, file_names = encode_docs((name_to_save, text), maxlen=64, stride=32)

        doc_emb = doc_emb.reshape(-1, 768)
        with open("text_gradio/{}.txt".format(name_to_save), "w", encoding="utf-8") as f:
            f.write(text)

    start = time.time()
    query_emb = encode_query(query)

    # Dot-product retrieval scores, one per document window.
    scores = np.matmul(query_emb, doc_emb.transpose(1, 0))[0].tolist()
    doc_score_pairs = sorted(zip(doc_text, scores, file_names), key=lambda x: x[1], reverse=True)
    k = 5
    probs = softmax(sorted(scores, reverse=True)[:k])
    table = {"Passage": [], "Answer": [], "Probabilities": [], "Source": []}

    for i, (passage, _, names) in enumerate(doc_score_pairs[:k]):
        passage = passage.replace("\n", "")
        passage = passage.replace(" . ", " ")

        # Run the reader only on likely passages, always considering the top 3
        # candidates at a lower threshold.
        if probs[i] > 0.1 or (i < 3 and probs[i] > 0.05):
            QA = {"question": query, "context": passage}
            ans = pipe(QA)
            probabilities = "P(a|p): {}, P(a|p,q): {}, P(p|q): {}".format(
                round(ans["score"], 5),
                round(ans["score"] * probs[i], 5),
                round(probs[i], 5))
            passage = passage.replace(str(ans["answer"]), str(ans["answer"]).upper())
            table["Passage"].append(passage)
            table["Passage"].append("---")
            table["Answer"].append(str(ans["answer"]).upper())
            table["Answer"].append("---")
            table["Probabilities"].append(probabilities)
            table["Probabilities"].append("---")
            table["Source"].append(names)
            table["Source"].append("---")
        else:
            table["Passage"].append(passage)
            table["Passage"].append("---")
            table["Answer"].append("no_answer_calculated")
            table["Answer"].append("---")
            table["Probabilities"].append("P(p|q): {}".format(round(probs[i], 5)))
            table["Probabilities"].append("---")
            table["Source"].append(names)
            table["Source"].append("---")

    df = pd.DataFrame(table)
    print("time: " + str(time.time() - start))

    with open("HISTORY.txt", "a", encoding="utf-8") as f:
        f.write(hist + " " + current_time + "\n")
    df.to_csv("HISTORY/{}.csv".format(cache_key), index=False)

    return df

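# Notation in the Probabilities column above: P(p|q) is the softmax-normalised
# retrieval score of the passage given the query, P(a|p) is the reader's span
# confidence within that passage, and P(a|p,q) is their product, used as a
# rough joint answer probability.
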
iface = gr.Interface(
    fn=predict,
    # predict() takes exactly (query, data); the original also declared an
    # unused Checkbox input, which made Gradio pass a third argument to
    # predict() and crash the Space at query time.
    inputs=[
        gr.inputs.Textbox(default="What is Open-domain question answering?"),
        gr.inputs.File(),
    ],
    outputs=[
        gr.outputs.Dataframe(),
    ],
    allow_flagging="manual",
    flagging_options=["correct", "wrong"],
    allow_screenshot=False)

iface.launch(share=True, enable_queue=True, show_error=True)
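
For the Space to build, the imports above must be declared in a requirements.txt next to app.py; a minimal sketch (package names inferred from the imports, versions unpinned and assumed; textract additionally relies on system tools for some file formats):

    gradio
    numpy
    torch
    transformers
    scipy
    pandas
    tqdm
    textract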