varun500 committed on
Commit 5d84094 · 1 Parent(s): bff3235

Upload 2 files

Files changed (2)
  1. app.py +145 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,145 @@
+ import os
+
+ import openai
+ import pinecone
+ import streamlit as st
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from transformers import pipeline
+
+
+ # connect to pinecone environment; the API key is read from an environment
+ # variable (PINECONE_API_KEY is an arbitrary name here; set it e.g. in the
+ # Space secrets) instead of being hardcoded in the source
+ pinecone.init(
+     api_key=os.environ["PINECONE_API_KEY"],
+     environment="us-east1-gcp"
+ )
+
+ index_name = "abstractive-question-answering"
+
+ index = pinecone.Index(index_name)
+
+ # Initialize models from Hugging Face
+
+ @st.cache_resource
+ def get_t5_model():
+     return pipeline("summarization", model="t5-base", tokenizer="t5-base")
+
+ @st.cache_resource
+ def get_flan_t5_model():
+     return pipeline("summarization", model="google/flan-t5-base", tokenizer="google/flan-t5-base")
+
+ @st.cache_resource
+ def get_embedding_model():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)
+     model.max_seq_length = 512
+     return model
+
+ @st.cache_data
+ def save_key(api_key):
+     return api_key
+
+ retriever_model = get_embedding_model()
+
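+ # The retriever below embeds the query with the mpnet sentence-transformer
+ # loaded above, so query and passage vectors share the same embedding space.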
+ def query_pinecone(query, top_k, model):
+     # generate a dense embedding for the query
+     xq = model.encode([query])[0].tolist()
+     # search the pinecone index for context passages that may hold the answer
+     xc = index.query(vector=xq, top_k=top_k, include_metadata=True)
+     return xc
+
+ def format_query(query_results):
+     # extract the merged_text metadata field from each Pinecone match
+     context = [result["metadata"]["merged_text"] for result in query_results["matches"]]
+     return context
+
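+ # Summarization prompt for GPT-3: appending "Tl;dr" nudges the model to
+ # condense the passage into a short answer.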
+ def gpt3_summary(text):
+     response = openai.Completion.create(
+         model="text-davinci-003",
+         prompt=text + "\n\nTl;dr",
+         temperature=0.1,
+         max_tokens=512,
+         top_p=1.0,
+         frequency_penalty=0.0,
+         presence_penalty=1
+     )
+     return response.choices[0].text
+
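+ # QA prompt for GPT-3: pairs the question with a retrieved passage in a
+ # Q/A template and stops generation at the first newline.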
+ def gpt3_qa(query, answer):
+     response = openai.Completion.create(
+         model="text-davinci-003",
+         prompt="Q: " + query + "\nA: " + answer,
+         temperature=0,
+         max_tokens=512,
+         top_p=1,
+         frequency_penalty=0.0,
+         presence_penalty=0.0,
+         stop=["\n"]
+     )
+     return response.choices[0].text
+
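+ # Streamlit UI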
+ st.title("Abstractive Question Answering - AAPL")
+
+ query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
+
+ num_results = int(st.number_input("Number of Results to query", 1, 5, value=2))
+
+ # retrieve the top-k passages for the query and pull out their text
+ query_results = query_pinecone(query_text, num_results, retriever_model)
+
+ context_list = format_query(query_results)
+
+ # Choose decoder model
+
+ models_choice = ["GPT3 (text_davinci)", "GPT3 - QA", "T5", "FLAN-T5"]
+
+ decoder_model = st.selectbox(
+     "Select Decoder Model",
+     models_choice)
+
+ st.subheader("Answer:")
+
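+ # Each decoder branch condenses every retrieved passage on its own, then runs
+ # the same model once more over the joined outputs for the final answer.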
+ if decoder_model == "GPT3 (text_davinci)":
+     openai_key = st.text_input("Enter OpenAI key", type="password")
+     api_key = save_key(openai_key)
+     openai.api_key = api_key
+     output_text = []
+     for context_text in context_list:
+         output_text.append(gpt3_summary(context_text))
+     generated_text = " ".join(output_text)
+     st.write(gpt3_summary(generated_text))
+
+ elif decoder_model == "GPT3 - QA":
+     openai_key = st.text_input("Enter OpenAI key", type="password")
+     api_key = save_key(openai_key)
+     openai.api_key = api_key
+     output_text = []
+     for context_text in context_list:
+         output_text.append(gpt3_qa(query_text, context_text))
+     generated_text = " ".join(output_text)
+     st.write(gpt3_qa(query_text, generated_text))
+
+ elif decoder_model == "T5":
+     t5_pipeline = get_t5_model()
+     output_text = []
+     for context_text in context_list:
+         output_text.append(t5_pipeline(context_text)[0]["summary_text"])
+     generated_text = " ".join(output_text)
+     st.write(t5_pipeline(generated_text)[0]["summary_text"])
+
+ elif decoder_model == "FLAN-T5":
+     flan_t5_pipeline = get_flan_t5_model()
+     output_text = []
+     for context_text in context_list:
+         output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
+     generated_text = " ".join(output_text)
+     st.write(flan_t5_pipeline(generated_text)[0]["summary_text"])
+
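+ # show the raw retrieved passages beneath the generated answer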
+ st.subheader("Retrieved Text:")
+
+ for context_text in context_list:
+     st.markdown(f"- {context_text}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ datasets
+ pinecone-client
+ sentence-transformers
+ torch
+ tqdm
+ openai
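
A minimal sketch for trying the Space locally (assuming the two files above, an already-populated abstractive-question-answering index, and the PINECONE_API_KEY variable used in app.py; streamlit is installed separately because Hugging Face Spaces provides it via the SDK and it is not in requirements.txt):

    pip install -r requirements.txt streamlit
    export PINECONE_API_KEY=<your-pinecone-key>
    streamlit run app.py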