Update app.py
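Swap the hand-rolled multilingual FAISS lookup inside `respond()` for a reusable smolagents `RetrieverTool`, load the FAISS index with an explicit cosine distance strategy, and move generation from `mistralai/Mistral-7B-Instruct-v0.3` to `meta-llama/Meta-Llama-3-8B-Instruct`.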
app.py
CHANGED
@@ -1,12 +1,63 @@
 import gradio as gr
 from transformers import pipeline
 from huggingface_hub import InferenceClient, login, snapshot_download
-from langchain_community.vectorstores import FAISS
+from langchain_community.vectorstores import FAISS, DistanceStrategy
 from langchain_huggingface import HuggingFaceEmbeddings
 import os
 import pandas as pd
 from datetime import datetime

+from smolagents import Tool, HfApiModel, ToolCallingAgent
+from langchain_core.vectorstores import VectorStore
+
+
+class RetrieverTool(Tool):
+    name = "retriever"
+    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, vectordb: VectorStore, **kwargs):
+        super().__init__(**kwargs)
+        self.vectordb = vectordb
+
+    def forward(self, query: str) -> str:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        docs = self.vectordb.similarity_search(
+            query,
+            k=7,
+        )
+
+        df = pd.read_csv("bger_cedh_db 1954-2024.csv")
+
+        spacer = " \n"
+        context = ""
+        nb_char = 100
+
+        for doc in docs:
+            case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
+            index = case_text.find(doc.page_content)
+            start = max(0, index - nb_char)
+            end = min(len(case_text), index + len(doc.page_content) + nb_char)
+            case_text_summary = case_text[start:end]
+
+            context += "#######" + spacer
+            context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
+            context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
+            context += "# Case date: " + doc.metadata["case_date"] + spacer
+            context += "# Case url: " + doc.metadata["case_url"] + spacer
+            #context += "# Case text: " + doc.page_content + spacer
+            context += "# Case extract: " + case_text_summary + spacer
+
+        return "\nRetrieved documents:\n" + context
+

 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
@@ -14,9 +65,7 @@ For more information on `huggingface_hub` Inference API support, please check th
 HF_TOKEN=os.getenv('TOKEN')
 login(HF_TOKEN)

-
-#model = "google/mt5-small"
-model = "mistralai/Mistral-7B-Instruct-v0.3"
+model = "meta-llama/Meta-Llama-3-8B-Instruct"

 client = InferenceClient(model)

@@ -24,98 +73,36 @@ folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", lo

 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

-vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True)
+vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)

 df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")

-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    score,
-):
-    #messages = [{"role": "system", "content": system_message}]
-
-    print(system_message)
-
-    # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
-    # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
-    # retriever = vector_db.as_retriever(search_type="mmr")
-    # documents = retriever.invoke(message)
-
-    documents_en = vector_db.similarity_search_with_score(prompt_en, k=4)
-    print(prompt_en)
-
-    documents_de = vector_db.similarity_search_with_score(prompt_de, k=4)
-    print(prompt_de)
-
-    documents_fr = vector_db.similarity_search_with_score(prompt_fr, k=4)
-    print(prompt_fr)
-
-    documents_it = vector_db.similarity_search_with_score(prompt_it, k=4)
-    print(prompt_it)
-
-    documents = documents_en + documents_de + documents_fr + documents_it
-
-    documents = sorted(documents, key=lambda x: x[1])[:4]
-
-    nb_char = 2000
-
-    #print(message)
-    print(f"* Documents found: {len(documents)}")
-
-    for doc in documents:
-        case_text = df[df["case_url"] == doc[0].metadata["case_url"]].case_text.values[0]
-        index = case_text.find(doc[0].page_content)
-        start = max(0, index - nb_char)
-        end = min(len(case_text), index + len(doc[0].page_content) + nb_char)
-        case_text_summary = case_text[start:end]
-
-        context += "#######" + spacer
-        context += "# Case number: " + doc[0].metadata["case_nb"] + spacer
-        context += "# Case source: " + ("Swiss Federal Court" if doc[0].metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
-        context += "# Case date: " + doc[0].metadata["case_date"] + spacer
-        context += "# Case url: " + doc[0].metadata["case_url"] + spacer
-        #context += "# Case text: " + doc[0].page_content + spacer
-        context += "Case extract: " + case_text_summary + spacer
-
-        #print("# Case number: " + doc.metadata["case_nb"] + spacer)
-        #print("# Case url: " + doc.metadata["case_url"] + spacer)
-
-    system_message += f"""A user is asking you the following question: {message}
-Please answer the user in the same language that he used in his question using ONLY the following given context not any prior knowledge or information found on the internet.
-# Context:
-The following case extracts have been found in either Swiss Federal Court or European Court of Human Rights cases and could fit the question:
-{context}
-# Task:
-If the retrieved context is not relevant cases or the issue has not been addressed within the context, just say "I can't find enough relevant information".
-Don't make up an answer or give irrelevant information not requested by the user.
-Otherwise, if relevant cases were found, answer in the user's question's language using the context that you found relevant and reference the sources, including the urls and dates.
-# Instructions:
-Always answer the user using the language used in his question: {message}
-"""
-
-    print(system_message)
-    messages = [{"role": "system", "content": system_message}]
-
+retriever_tool = RetrieverTool(vector_db)
+agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
+
+def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score):
+
+    print(datetime.now())
+    context = retriever_tool(message)
+
+    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
+Respond only to the question asked; the response should be concise, relevant to the question, and in the same language as the question.
+Provide the number of the source document when relevant, as well as the link to the document.
+If you cannot find information, do not give up and try calling your retriever again with different arguments!
+
+Question:
+{message}
+
+{context}
+"""
+
+    messages = [{"role": "user", "content": prompt}]
+
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})

     messages.append({"role": "user", "content": message})

@@ -129,6 +116,8 @@ Always answer the user using the language used in his question: {message}
         top_p=top_p,
     ):
         token = message.choices[0].delta.content
+
+        # answer = client.chat_completion(messages, temperature=0.1).choices[0].message.content

         response += token
         yield response
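Review note: the new `RetrieverTool` can be exercised outside Gradio. A minimal sketch, assuming `RetrieverTool` is in scope (e.g., copied from the diff above), that the dataset snapshot lands in the working directory (`local_dir` is an assumption; the `snapshot_download` call is truncated in this view), and that `bger_cedh_db 1954-2024.csv` is reachable at the relative path the tool reads:

```python
from huggingface_hub import snapshot_download
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings

# Fetch the prebuilt FAISS index and case database (same dataset app.py uses).
snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=".")  # local_dir assumed

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
vector_db = FAISS.load_local(
    "faiss_index_mpnet_cos",
    embeddings,
    allow_dangerous_deserialization=True,
    distance_strategy=DistanceStrategy.COSINE,  # match the strategy the index was built with
)

retriever_tool = RetrieverTool(vector_db)
# smolagents tools are callable; use the affirmative form, per the tool's input description.
print(retriever_tool("right to a fair trial in criminal proceedings"))
```

Passing `DistanceStrategy.COSINE` explicitly matters because `FAISS.load_local` otherwise defaults to Euclidean distance, which would silently change how `similarity_search` ranks the mpnet embeddings.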
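The commit also wires up a `ToolCallingAgent`, though `respond()` still calls the tool directly and streams from `InferenceClient`. If the agentic path is the eventual goal, a sketch of how it would be driven (reusing `retriever_tool` from the sketch above; the question is illustrative):

```python
from smolagents import HfApiModel, ToolCallingAgent

agent = ToolCallingAgent(
    tools=[retriever_tool],
    model=HfApiModel("meta-llama/Meta-Llama-3-8B-Instruct"),
)

# The agent decides when, and how often, to call the retriever, which fits the
# prompt's instruction to retry the retriever with different arguments.
answer = agent.run("Which cases address the presumption of innocence?")
print(answer)
```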
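Finally, the streaming loop at the end of `respond()` follows the usual `InferenceClient.chat_completion(stream=True)` pattern. Isolated below for clarity (parameter values are illustrative); note the final streamed chunk's `delta.content` can be `None`, which the current loop does not guard against:

```python
from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "What does Article 6 ECHR guarantee?"}]

response = ""
for chunk in client.chat_completion(
    messages,
    max_tokens=512,
    stream=True,
    temperature=0.7,
    top_p=0.95,
):
    token = chunk.choices[0].delta.content
    if token:  # the last chunk may carry no content
        response += token
print(response)
```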