Commit e04cd14
Parent(s): 54abba0

first attempt to hf spaces

- config/config.yaml +1 -1
- config/document_retriever/multiquery_retriever.yaml +1 -0
- config/gradio_config.yaml +27 -0
- data +1 -0
- src/demo.py +1 -1
- src/document_retriever/multiquery_retriever.py +37 -0
- src/gradio.py +0 -17
- src/gradio_app.py +68 -0
- src/llm4scilit_gradio_interface.py +508 -0
- src/question_answering/huggingface.py +7 -5
config/config.yaml
CHANGED

@@ -3,7 +3,7 @@ defaults:
   - text_splitter: spacy
   - text_embedding: huggingface
   - vector_store: faiss
-  - document_retriever:
+  - document_retriever: multiquery_retriever
   - question_answering: huggingface
   - _self_
   - override hydra/hydra_logging: disabled
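Because document_retriever is a Hydra defaults group, the same switch can also be made per run instead of editing the file. A minimal sketch, assuming the config is composed from src/ as elsewhere in this repo (the group-override syntax is standard Hydra; the print is only for inspection):

# Sketch only: compose the updated config and check which retriever group was selected.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="../config"):
    cfg = compose(config_name="config", overrides=["document_retriever=multiquery_retriever"])
    print(OmegaConf.to_yaml(cfg.document_retriever))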
config/document_retriever/multiquery_retriever.yaml
ADDED

@@ -0,0 +1 @@
+_target_: document_retriever.multiquery_retriever.MultiQueryDocumentRetriever
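The new group only carries a _target_ entry, which suggests the retriever is built through Hydra's instantiate mechanism; the actual wiring is not shown in this commit. A hedged sketch of that pattern, with vector_store passed at call time to match the constructor added in src/document_retriever/multiquery_retriever.py below:

# Sketch only (assumed wiring, not part of this commit): build the retriever from
# the config group and inject the runtime-only vector store dependency.
from hydra.utils import instantiate

def build_retriever(cfg, vector_store):
    retriever = instantiate(cfg.document_retriever, vector_store=vector_store)
    retriever.initialize()  # creates the LLM pipeline and the MultiQueryRetriever
    return retriever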
config/gradio_config.yaml
ADDED

@@ -0,0 +1,27 @@
+defaults:
+  - document_loader: grobid
+  - text_splitter: spacy
+  - text_embedding: huggingface
+  - vector_store: faiss
+  - document_retriever: simple_retriever
+  - question_answering: huggingface
+  - _self_
+  - override hydra/hydra_logging: disabled
+  - override hydra/job_logging: disabled
+
+storage_path:
+  base: ./data
+  documents: ${storage_path.base}/papers
+  documents_processed: ${storage_path.documents}_processed
+  vector_store: ${storage_path.base}/vector_store
+
+mode: interactive
+debug:
+  is_debug: false
+  force_rebuild_storage: false
+
+document_parsing:
+  enabled: false
+
+hydra:
+  verbose: false
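The storage_path entries rely on OmegaConf interpolation, so every path is derived from base. A small sketch of how those references resolve (same keys as above; values are resolved on access):

# Sketch only: resolution of the ${...} interpolations used in gradio_config.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "storage_path": {
        "base": "./data",
        "documents": "${storage_path.base}/papers",
        "documents_processed": "${storage_path.documents}_processed",
        "vector_store": "${storage_path.base}/vector_store",
    }
})
print(cfg.storage_path.documents)            # ./data/papers
print(cfg.storage_path.documents_processed)  # ./data/papers_processed
print(cfg.storage_path.vector_store)         # ./data/vector_store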
data
ADDED

@@ -0,0 +1 @@
+/data/tommaso/llm4scilit/data/
src/demo.py
CHANGED

@@ -114,7 +114,7 @@ class App:

     def ask_chat(self, line, history):
         # print(f"\nLLM4SciLit: a bunch of nonsense\n")
-        return self.qa_model.answer_question(line, {})
+        return self.qa_model.answer_question(line, {})


 ##################################################################################################
src/document_retriever/multiquery_retriever.py
ADDED

@@ -0,0 +1,37 @@
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain.retrievers.multi_query import MultiQueryRetriever
+
+# Set logging for the queries
+import logging
+
+logging.basicConfig()
+
+
+class MultiQueryDocumentRetriever:
+    def __init__(self, vector_store):
+        self.vector_store = vector_store
+        self.retriever = None
+        self.llm = None
+        # self.token = "LL-1kuyxK1z5NQYOiOsf5UdozHJuLhV6udoDGxL8NfM7brWCUbF0uqlii15sso8GNrd"
+
+    def initialize(self):
+        # self.llama = LlamaAPI(self.token)
+        self.llm = HuggingFacePipeline.from_model_id(
+            # model_id="bigscience/bloom-1b7",
+            model_id="bigscience/bloomz-1b7",
+            task="text-generation",
+            # device=1,
+            # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
+            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
+            # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
+            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
+        )
+
+        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
+        self.retriever = MultiQueryRetriever.from_llm(
+            retriever=self.vector_store.db.as_retriever(search_kwargs={"k": 4, "fetch_k": 40}),
+            llm=self.llm
+        )
+
+    def retrieve(self, query: str, k: int = 4):
+        pass
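retrieve() is still a stub, so callers presumably use the underlying LangChain retriever directly once initialize() has run. A hedged usage sketch (the vector_store argument is assumed to expose a FAISS store at .db, as the code above expects; get_relevant_documents is the standard LangChain retriever call in this version):

# Sketch only: expected use of the class above. MultiQueryRetriever asks the LLM
# for several rephrasings of the query, retrieves documents for each, and returns
# the deduplicated union.
def retrieve_with_multiquery(vector_store, query: str):
    retriever = MultiQueryDocumentRetriever(vector_store)
    retriever.initialize()
    return retriever.retriever.get_relevant_documents(query)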
src/gradio.py
DELETED

@@ -1,17 +0,0 @@
-import gradio as gr
-from hydra import compose, initialize
-from omegaconf import OmegaConf
-
-from demo import App
-
-def main():
-    with initialize(version_base=None, config_path="../config", job_name="gradio_app"):
-        cfg = compose(config_name="config", overrides=["document_parsing.enabled=False"])
-
-    app = App(cfg)
-
-    webapp = gr.ChatInterface(fn=app.ask_chat, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
-    webapp.launch(share=True)
-
-if __name__ == "__main__":
-    main()
src/gradio_app.py
ADDED

@@ -0,0 +1,68 @@
+import hydra
+from omegaconf import DictConfig
+from demo import App
+
+from llm4scilit_gradio_interface import LLM4SciLitChatInterface
+
+def echo(text, history):
+    asdf = "asdf"
+    values = [f"{x}\n{x*2}" for x in asdf]
+    return text, *values
+
+
+@hydra.main(version_base=None, config_path="../config", config_name="gradio_config")
+def main(cfg : DictConfig) -> None:
+    cfg.document_parsing['enabled'] = False
+
+    app = App(cfg)
+    app._bootstrap()
+
+    def wrapped_ask_chat(text, history):
+        result = app.ask_chat(text, history)
+        sources = [
+            f"{x.metadata['paper_title']}\n{x.page_content}"
+            for x in result['source_documents']
+        ]
+        return result['result'], *sources
+
+
+    LLM4SciLitChatInterface(wrapped_ask_chat, title="LLM4SciLit").launch()
+    # LLM4SciLitChatInterface(echo, title="LLM4SciLit").launch()
+
+    # textbox = gr.Textbox(placeholder="Ask a question about scientific literature", lines=2, label="Question", elem_id="textbox")
+    # chatbot = gr.Chatbot(label="LLM4SciLit", elem_id="chat")
+    # gr.Interface(fn=echo, inputs=[textbox, chatbot], outputs=[chatbot], title="LLM4SciLit").launch()
+
+    # with gr.Blocks() as demo:
+    #     chatbot = gr.Chatbot()
+    #     msg = gr.Textbox(container=False)
+    #     clear = gr.ClearButton([msg, chatbot])
+
+    #     def respond(message, chat_history):
+    #         bot_message = "How are you?"
+    #         chat_history.append((message, bot_message))
+    #         return "", chat_history
+
+    #     msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+
+
+    # with gr.Blocks(title="LLM4SciLit") as demo:
+    #     with gr.Row():
+    #         with gr.Column(scale=5):
+    #             with gr.Row():
+    #                 gr.Chatbot(fn=echo)
+    #             with gr.Row():
+    #                 gr.Button("Submit")
+
+    #         with gr.Column(scale=5):
+    #             with gr.Accordion("Retrieved documents"):
+    #                 gr.Label("Document 1")
+
+    # webapp = gr.ChatInterface(fn=app.ask_chat, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
+    # webapp = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="LLM4SciLit")
+    # demo.launch()
+    # webapp.launch(share=True)
+
+if __name__ == "__main__":
+    main() # pylint: disable=no-value-for-parameter
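wrapped_ask_chat returns a tuple whose first element feeds the chatbot and whose remaining elements fill the extra output textboxes, so its length has to line up with the four "Document" boxes created in src/llm4scilit_gradio_interface.py below. A hedged sketch of that contract (the result/source_documents keys match the RetrievalQA output used in src/question_answering/huggingface.py; padding to a fixed width is an assumption here, not something this commit enforces):

# Sketch only: shaping a RetrievalQA result into (answer, doc_1, ..., doc_n).
def to_interface_outputs(result: dict, n_boxes: int = 4):
    sources = [
        f"{doc.metadata['paper_title']}\n{doc.page_content}"
        for doc in result["source_documents"]
    ]
    sources = (sources + [""] * n_boxes)[:n_boxes]  # always exactly n_boxes entries
    return result["result"], *sources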
src/llm4scilit_gradio_interface.py
ADDED

@@ -0,0 +1,508 @@
+"""
+This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
+"""
+
+
+from __future__ import annotations
+
+import inspect
+from typing import AsyncGenerator, Callable
+
+import anyio
+from gradio_client import utils as client_utils
+from gradio_client.documentation import document, set_documentation_group
+
+from gradio.blocks import Blocks
+from gradio.components import (
+    Button,
+    Chatbot,
+    IOComponent,
+    Markdown,
+    State,
+    Textbox,
+    get_component_instance,
+)
+from gradio.events import Dependency, EventListenerMethod, on
+from gradio.helpers import create_examples as Examples  # noqa: N812
+from gradio.layouts import Accordion, Column, Group, Row
+from gradio.themes import ThemeClass as Theme
+from gradio.utils import SyncToAsyncIterator, async_iteration
+
+set_documentation_group("chatinterface")
+
+
+@document()
+class LLM4SciLitChatInterface(Blocks):
+    """
+    ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
+    a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
+    takes a function that governs the response of the chatbot based on the user input and chat history. Additional
+    parameters can be used to control the appearance and behavior of the demo.
+
+    Example:
+        import gradio as gr
+
+        def echo(message, history):
+            return message
+
+        demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
+        demo.launch()
+    Demos: chatinterface_random_response, chatinterface_streaming_echo
+    Guides: creating-a-chatbot-fast, sharing-your-app
+    """
+
+    def __init__(
+        self,
+        fn: Callable,
+        *,
+        chatbot: Chatbot | None = None,
+        textbox: Textbox | None = None,
+        additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
+        additional_inputs_accordion_name: str = "Additional Inputs",
+        examples: list[str] | None = None,
+        cache_examples: bool | None = None,
+        title: str | None = None,
+        description: str | None = None,
+        theme: Theme | str | None = None,
+        css: str | None = None,
+        analytics_enabled: bool | None = None,
+        submit_btn: str | None | Button = "Submit",
+        stop_btn: str | None | Button = "Stop",
+        retry_btn: str | None | Button = "🔄 Retry",
+        undo_btn: str | None | Button = "↩️ Undo",
+        clear_btn: str | None | Button = "🗑️ Clear",
+        autofocus: bool = True,
+    ):
+        """
+        Parameters:
+            fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
+            chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
+            textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
+            additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
+            additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
+            examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
+            cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
+            title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
+            description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
+            theme: Theme to use, loaded from gradio.themes.
+            css: custom css or path to custom css file to use with interface.
+            analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
+            submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
+            stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
+            retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
+            undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
+            clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
+            autofocus: If True, autofocuses to the textbox when the page loads.
+        """
+        super().__init__(
+            analytics_enabled=analytics_enabled,
+            mode="chat_interface",
+            css=css,
+            title=title or "Gradio",
+            theme=theme,
+        )
+        self.fn = fn
+        self.is_async = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.is_generator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.examples = examples
+        if self.space_id and cache_examples is None:
+            self.cache_examples = True
+        else:
+            self.cache_examples = cache_examples or False
+        self.buttons: list[Button] = []
+
+        if additional_inputs:
+            if not isinstance(additional_inputs, list):
+                additional_inputs = [additional_inputs]
+            self.additional_inputs = [
+                get_component_instance(i) for i in additional_inputs  # type: ignore
+            ]
+        else:
+            self.additional_inputs = []
+        self.additional_inputs_accordion_name = additional_inputs_accordion_name
+
+        self.additional_outputs = []
+
+        with self:
+            if title:
+                Markdown(
+                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
+                )
+            if description:
+                Markdown(description)
+
+            with Row():
+                with Column(variant="panel", scale=1):
+                    if chatbot:
+                        self.chatbot = chatbot.render()
+                    else:
+                        self.chatbot = Chatbot(label="Chatbot")
+
+                    with Group():
+                        with Row():
+                            if textbox:
+                                textbox.container = False
+                                textbox.show_label = False
+                                self.textbox = textbox.render()
+                            else:
+                                self.textbox = Textbox(
+                                    container=False,
+                                    show_label=False,
+                                    label="Message",
+                                    placeholder="Type a message...",
+                                    scale=7,
+                                    autofocus=autofocus,
+                                )
+                            if submit_btn:
+                                if isinstance(submit_btn, Button):
+                                    submit_btn.render()
+                                elif isinstance(submit_btn, str):
+                                    submit_btn = Button(
+                                        submit_btn,
+                                        variant="primary",
+                                        scale=1,
+                                        min_width=150,
+                                    )
+                                else:
+                                    raise ValueError(
+                                        f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
+                                    )
+                            if stop_btn:
+                                if isinstance(stop_btn, Button):
+                                    stop_btn.visible = False
+                                    stop_btn.render()
+                                elif isinstance(stop_btn, str):
+                                    stop_btn = Button(
+                                        stop_btn,
+                                        variant="stop",
+                                        visible=False,
+                                        scale=1,
+                                        min_width=150,
+                                    )
+                                else:
+                                    raise ValueError(
+                                        f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
+                                    )
+                            self.buttons.extend([submit_btn, stop_btn])
+
+                    with Row():
+                        for btn in [retry_btn, undo_btn, clear_btn]:
+                            if btn:
+                                if isinstance(btn, Button):
+                                    btn.render()
+                                elif isinstance(btn, str):
+                                    btn = Button(btn, variant="secondary")
+                                else:
+                                    raise ValueError(
+                                        f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
+                                    )
+                            self.buttons.append(btn)
+
+                    self.fake_api_btn = Button("Fake API", visible=False)
+                    self.fake_response_textbox = Textbox(
+                        label="Response", visible=False
+                    )
+                    (
+                        self.submit_btn,
+                        self.stop_btn,
+                        self.retry_btn,
+                        self.undo_btn,
+                        self.clear_btn,
+                    ) = self.buttons
+
+                with Column(variant="panel", scale=2):
+                    for i in range(4):
+                        self.additional_outputs.append(
+                            Textbox(
+                                interactive=False,
+                                label=f"Document {i+1}"
+                            )
+                        )
+
+            if examples:
+                if self.is_generator:
+                    examples_fn = self._examples_stream_fn
+                else:
+                    examples_fn = self._examples_fn
+
+                self.examples_handler = Examples(
+                    examples=examples,
+                    inputs=[self.textbox] + self.additional_inputs,
+                    outputs=self.chatbot,
+                    fn=examples_fn,
+                )
+
+            any_unrendered_inputs = any(
+                not inp.is_rendered for inp in self.additional_inputs
+            )
+            if self.additional_inputs and any_unrendered_inputs:
+                with Accordion(self.additional_inputs_accordion_name, open=False):
+                    for input_component in self.additional_inputs:
+                        if not input_component.is_rendered:
+                            input_component.render()
+
+            # The example caching must happen after the input components have rendered
+            if cache_examples:
+                client_utils.synchronize_async(self.examples_handler.cache)
+
+            self.saved_input = State()
+            self.chatbot_state = State([])
+
+            self._setup_events()
+            self._setup_api()
+
+    def _setup_events(self) -> None:
+        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
+        submit_triggers = (
+            [self.textbox.submit, self.submit_btn.click]
+            if self.submit_btn
+            else [self.textbox.submit]
+        )
+        submit_event = (
+            on(
+                submit_triggers,
+                self._clear_and_save_textbox,
+                [self.textbox],
+                [self.textbox, self.saved_input],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                self._display_input,
+                [self.saved_input, self.chatbot_state],
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                submit_fn,
+                [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                [self.chatbot, self.chatbot_state] + self.additional_outputs,
+                api_name=False,
+            )
+        )
+        self._setup_stop_events(submit_triggers, submit_event)
+
+        if self.retry_btn:
+            retry_event = (
+                self.retry_btn.click(
+                    self._delete_prev_fn,
+                    [self.chatbot_state],
+                    [self.chatbot, self.saved_input, self.chatbot_state],
+                    api_name=False,
+                    queue=False,
+                )
+                .then(
+                    self._display_input,
+                    [self.saved_input, self.chatbot_state],
+                    [self.chatbot, self.chatbot_state],
+                    api_name=False,
+                    queue=False,
+                )
+                .then(
+                    submit_fn,
+                    [self.saved_input, self.chatbot_state] + self.additional_inputs,
+                    [self.chatbot, self.chatbot_state],
+                    api_name=False,
+                )
+            )
+            self._setup_stop_events([self.retry_btn.click], retry_event)
+
+        if self.undo_btn:
+            self.undo_btn.click(
+                self._delete_prev_fn,
+                [self.chatbot_state],
+                [self.chatbot, self.saved_input, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            ).then(
+                lambda x: x,
+                [self.saved_input],
+                [self.textbox],
+                api_name=False,
+                queue=False,
+            )
+
+        if self.clear_btn:
+            self.clear_btn.click(
+                lambda: ([], [], None),
+                None,
+                [self.chatbot, self.chatbot_state, self.saved_input],
+                queue=False,
+                api_name=False,
+            )
+
+    def _setup_stop_events(
+        self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
+    ) -> None:
+        if self.stop_btn and self.is_generator:
+            if self.submit_btn:
+                for event_trigger in event_triggers:
+                    event_trigger(
+                        lambda: (
+                            Button.update(visible=False),
+                            Button.update(visible=True),
+                        ),
+                        None,
+                        [self.submit_btn, self.stop_btn],
+                        api_name=False,
+                        queue=False,
+                    )
+                event_to_cancel.then(
+                    lambda: (Button.update(visible=True), Button.update(visible=False)),
+                    None,
+                    [self.submit_btn, self.stop_btn],
+                    api_name=False,
+                    queue=False,
+                )
+            else:
+                for event_trigger in event_triggers:
+                    event_trigger(
+                        lambda: Button.update(visible=True),
+                        None,
+                        [self.stop_btn],
+                        api_name=False,
+                        queue=False,
+                    )
+                event_to_cancel.then(
+                    lambda: Button.update(visible=False),
+                    None,
+                    [self.stop_btn],
+                    api_name=False,
+                    queue=False,
+                )
+            self.stop_btn.click(
+                None,
+                None,
+                None,
+                cancels=event_to_cancel,
+                api_name=False,
+            )
+
+    def _setup_api(self) -> None:
+        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
+
+        self.fake_api_btn.click(
+            api_fn,
+            [self.textbox, self.chatbot_state] + self.additional_inputs,
+            [self.textbox, self.chatbot_state],
+            api_name="chat",
+        )
+
+    def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
+        return "", message
+
+    def _display_input(
+        self, message: str, history: list[list[str | None]]
+    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+        history.append([message, None])
+        return history, history
+
+    async def _submit_fn(
+        self,
+        message: str,
+        history_with_input: list[list[str | None]],
+        *args,
+    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+        history = history_with_input[:-1]
+        if self.is_async:
+            [response, *other_outputs] = await self.fn(message, history, *args)
+        else:
+            [response, *other_outputs] = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+        history.append([message, response])
+
+        return history, history, *other_outputs
+
+    async def _stream_fn(
+        self,
+        message: str,
+        history_with_input: list[list[str | None]],
+        *args,
+    ) -> AsyncGenerator:
+        history = history_with_input[:-1]
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        try:
+            first_response = await async_iteration(generator)
+            update = history + [[message, first_response]]
+            yield update, update
+        except StopIteration:
+            update = history + [[message, None]]
+            yield update, update
+        async for response in generator:
+            update = history + [[message, response]]
+            yield update, update
+
+    async def _api_submit_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> tuple[str, list[list[str | None]]]:
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+        history.append([message, response])
+        return response, history
+
+    async def _api_stream_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        try:
+            first_response = await async_iteration(generator)
+            yield first_response, history + [[message, first_response]]
+        except StopIteration:
+            yield None, history + [[message, None]]
+        async for response in generator:
+            yield response, history + [[message, response]]
+
+    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        if self.is_async:
+            response = await self.fn(message, [], *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+        return [[message, response]]
+
+    async def _examples_stream_fn(
+        self,
+        message: str,
+        *args,
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, [], *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        async for response in generator:
+            yield [[message, response]]
+
+    def _delete_prev_fn(
+        self, history: list[list[str | None]]
+    ) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
+        try:
+            message, _ = history.pop()
+        except IndexError:
+            message = ""
+        return history, message or "", history
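The class is essentially a copy of Gradio's ChatInterface with one change: a second panel of four read-only "Document N" textboxes registered as additional_outputs, so the wrapped fn is expected to return an answer plus up to four source strings. A minimal standalone usage sketch with a dummy function (not the real QA chain):

# Sketch only: exercising LLM4SciLitChatInterface with a stub that returns an
# answer plus four "retrieved document" strings, mirroring wrapped_ask_chat.
from llm4scilit_gradio_interface import LLM4SciLitChatInterface

def fake_qa(message, history):
    docs = [f"Document {i + 1}: context for '{message}'" for i in range(4)]
    return f"Answer to: {message}", *docs

if __name__ == "__main__":
    LLM4SciLitChatInterface(fake_qa, title="LLM4SciLit (demo)").launch()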
src/question_answering/huggingface.py
CHANGED

@@ -1,15 +1,15 @@
-from langchain import PromptTemplate
+from langchain.prompts.prompt import PromptTemplate
 from langchain.chains import RetrievalQA
-from langchain.llms import HuggingFacePipeline
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline

 class HuggingFaceQuestionAnswering:
     def __init__(self, retriever) -> None:
         self.retriever = retriever
         self.llm = HuggingFacePipeline.from_model_id(
             # model_id="bigscience/bloom-1b7",
-            model_id="bigscience/bloomz-
+            model_id="bigscience/bloomz-1b7",
             task="text-generation",
-            device=1,
+            # device=1,
             # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
             model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
             # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},

@@ -27,6 +27,7 @@ class HuggingFaceQuestionAnswering:

     def answer_question(self, question: str, filter_dict):
         retriever = self.retriever.vector_store.db.as_retriever(search_kwargs={"filter": filter_dict, "fetch_k": 150})
+        # retriever = self.retriever.retriever

         try:
             self.chain = RetrievalQA.from_chain_type(self.llm, retriever=retriever, return_source_documents=True)

@@ -36,5 +37,6 @@ class HuggingFaceQuestionAnswering:
             Retrieved Documents:
             {docs if docs != "" else "No documents found."}""")
             return result
-        except:
+        except Exception as e:
+            print(e)
             return {"result": "Error generating answer."}
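Both imports now point at the concrete modules (langchain.prompts.prompt and langchain.llms.huggingface_pipeline), which keeps working on LangChain releases where the top-level re-exports were removed, and the bare except is narrowed so the underlying error is at least printed. For reference, a hedged sketch of the output shape answer_question ultimately returns, which the Gradio wrapper above relies on (keys are those produced by RetrievalQA when return_source_documents=True):

# Sketch only: consuming the chain output the way gradio_app.wrapped_ask_chat does.
def split_answer(output: dict):
    # output comes from chain({"query": ...}) with return_source_documents=True
    return output["result"], output.get("source_documents", [])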