Spaces: Running
anakin87 committed · Commit 5b26a96
1 Parent(s): 5fe5c67

add LLM explanation feat

Browse files:
- Rock_fact_checker.py +9 -2
- app_utils/backend_utils.py +29 -4
- app_utils/config.py +9 -0

Rock_fact_checker.py
CHANGED
@@ -5,7 +5,7 @@ from json import JSONDecodeError
 
 import streamlit as st
 
-from app_utils.backend_utils import load_statements,
+from app_utils.backend_utils import load_statements, check_statement, explain_using_llm
 from app_utils.frontend_utils import (
     set_state_if_absent,
     reset_results,
@@ -80,7 +80,7 @@ def main():
         st.session_state.statement = statement
         with st.spinner("🧠 Performing neural search on documents..."):
             try:
-                st.session_state.results =
+                st.session_state.results = check_statement(statement, RETRIEVER_TOP_K)
                 print(f"S: {statement}")
                 time_end = time.time()
                 print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
@@ -121,5 +121,12 @@ def main():
         str_wiki_pages += f"[{doc}]({url}) "
         st.markdown(str_wiki_pages)
 
+        if max_key != "neutral":
+            explanation = explain_using_llm(
+                statement=statement, documents=docs, entailment_or_contradiction=max_key
+            )
+            explanation = "#### Explanation 🧠 (experimental):\n" + explanation
+            st.markdown(explanation)
+
 
 main()
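With this change the page first runs the retrieval + entailment pipeline via check_statement and then, whenever the aggregated verdict is not "neutral", asks the LLM for an explanation via explain_using_llm. Below is a minimal sketch of that flow outside Streamlit, assuming the pipeline result exposes the retrieved documents and an aggregated score per entailment label; the example statement and the aggregate_entailment_info key are illustrative, not taken from this commit.

# Hedged sketch of the new flow, not the Space's exact code.
from app_utils.backend_utils import check_statement, explain_using_llm

statement = "The Beatles were formed in Liverpool"        # illustrative input
results = check_statement(statement, retriever_top_k=5)   # retrieval + entailment pipeline

docs = results["documents"]                  # retrieved Wikipedia passages
agg = results["aggregate_entailment_info"]   # assumed key: scores for entailment/contradiction/neutral
max_key = max(agg, key=agg.get)              # label with the highest aggregated score

if max_key != "neutral":                     # only explain clear verdicts, as in the new UI block
    explanation = explain_using_llm(
        statement=statement, documents=docs, entailment_or_contradiction=max_key
    )
    print(explanation)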
app_utils/backend_utils.py
CHANGED
@@ -1,7 +1,9 @@
 import shutil
+from typing import List
 
+from haystack import Document
 from haystack.document_stores import FAISSDocumentStore
-from haystack.nodes import EmbeddingRetriever
+from haystack.nodes import EmbeddingRetriever, PromptNode
 from haystack.pipelines import Pipeline
 import streamlit as st
 
@@ -12,6 +14,7 @@ from app_utils.config import (
     RETRIEVER_MODEL,
     RETRIEVER_MODEL_FORMAT,
     NLI_MODEL,
+    PROMPT_MODEL,
 )
 
 
@@ -53,15 +56,37 @@ def start_haystack():
     pipe = Pipeline()
     pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
     pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
-    return pipe
 
+    prompt_node = PromptNode(model_name_or_path=PROMPT_MODEL, max_length=150)
 
+    return pipe, prompt_node
+
+
-pipe = start_haystack()
+pipe, prompt_node = start_haystack()
 
 # the pipeline is not included as parameter of the following function,
 # because it is difficult to cache
 @st.cache(allow_output_mutation=True)
-def
+def check_statement(statement: str, retriever_top_k: int = 5):
     """Run query and verify statement"""
     params = {"retriever": {"top_k": retriever_top_k}}
     return pipe.run(statement, params=params)
+
+
+@st.cache(
+    hash_funcs={"tokenizers.Tokenizer": lambda _: None}, allow_output_mutation=True
+)
+def explain_using_llm(
+    statement: str, documents: List[Document], entailment_or_contradiction: str
+) -> str:
+    """Explain entailment/contradiction, by prompting a LLM"""
+    premise = " \n".join([doc.content.replace("\n", ". ") for doc in documents])
+    if entailment_or_contradiction == "entailment":
+        verb = "entails"
+    elif entailment_or_contradiction == "contradiction":
+        verb = "contradicts"
+
+    prompt = f"Premise: {premise}; Hypothesis: {statement}; Please explain in detail why the Premise {verb} the Hypothesis. Step by step Explanation:"
+
+    print(prompt)
+    return prompt_node(prompt)[0]
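The new explain_using_llm reuses the documents already retrieved for the entailment check and sends a single free-form prompt to Haystack's PromptNode. A standalone sketch of what that call looks like, assuming a flan-t5 checkpoint and one illustrative document (on the Space, PROMPT_MODEL comes from the secrets instead):

from haystack import Document
from haystack.nodes import PromptNode

# Illustrative model choice; the Space reads PROMPT_MODEL from st.secrets.
prompt_node = PromptNode(model_name_or_path="google/flan-t5-small", max_length=150)

documents = [
    Document(content="The Beatles were an English rock band formed in Liverpool in 1960.")
]
statement = "The Beatles were formed in Liverpool"

# Same prompt shape that explain_using_llm builds for an "entailment" verdict.
premise = " \n".join(doc.content.replace("\n", ". ") for doc in documents)
prompt = (
    f"Premise: {premise}; Hypothesis: {statement}; "
    "Please explain in detail why the Premise entails the Hypothesis. "
    "Step by step Explanation:"
)

print(prompt_node(prompt)[0])  # first generated explanation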
app_utils/config.py
CHANGED
@@ -14,3 +14,12 @@ try:
 except:
     NLI_MODEL = "valhalla/distilbart-mnli-12-1"
 print(f"Used NLI model: {NLI_MODEL}")
+
+
+# In HF Space, we use google/flan-t5-large
+# for local testing, a smaller model is better
+try:
+    PROMPT_MODEL = st.secrets["PROMPT_MODEL"]
+except:
+    PROMPT_MODEL = "google/flan-t5-small"
+print(f"Used Prompt model: {PROMPT_MODEL}")
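The prompt model is resolved the same way as the NLI model above: a PROMPT_MODEL secret wins, otherwise the small flan-t5 fallback is used for local testing. A minimal sketch of overriding it locally, assuming Streamlit's standard secrets file (the path and model name below are illustrative):

# .streamlit/secrets.toml (local) could contain, for example:
#   PROMPT_MODEL = "google/flan-t5-base"
import streamlit as st

try:
    PROMPT_MODEL = st.secrets["PROMPT_MODEL"]  # HF Space secret or local secrets.toml entry
except Exception:                              # no secret configured: small local fallback
    PROMPT_MODEL = "google/flan-t5-small"
print(f"Used Prompt model: {PROMPT_MODEL}")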