nurasaki committed
Commit 5b68ef9 · 1 Parent(s): 9d1b8d4

Improved no-context response and logs

Files changed (4)
  1. app.py +4 -9
  2. config.yaml +10 -2
  3. src/tools.py +5 -3
  4. src/vectorstore.py +9 -6
app.py CHANGED
@@ -72,18 +72,11 @@ def completion(history, model, system_prompt: str, tools=None, chat_params=chat_
     }
     if tools:
         request_params.update({"tool_choice": "auto", "tools": tools})
-
-    cprint("=" * 150, "green")
-    print(json.dumps(request_params, indent=2, ensure_ascii=False))
-    cprint("=" * 150, "green")
 
     return client.chat.completions.create(**request_params)
 
 
 def llm_in_loop(history, system_prompt, recursive):
-
-    cprint(history, "cyan")
-    cprint("~" * 150, "yellow")
 
     try:
         models = client.models.list()
@@ -113,6 +106,7 @@ def llm_in_loop(history, system_prompt, recursive):
             history[-1].content += chunk.choices[0].delta.content
             yield history[recursive:]
 
+    # Convert arguments to a valid JSON
     arguments = clean_json_string(arguments) if arguments else "{}"
     arguments = json.loads(arguments)
 
@@ -122,7 +116,6 @@
     if name:
         try:
             result = str(tools[name].invoke(input=arguments))
-            cprint(f"*** Tool {name} invoked with arguments {arguments}, result: {result}", "yellow")
 
         except Exception as err:
             result = f"💥 Error: {err}"
@@ -137,6 +130,7 @@
 
 
 def respond(message, history, additional_inputs):
+
     history.append(ChatMessage(role="user", content=message))
     yield from llm_in_loop(history, additional_inputs, -1)
 
@@ -144,6 +138,7 @@
 
 if __name__ == "__main__":
 
-    system_prompt = gr.Textbox(label="System prompt", value=cfg.system_prompt_template, lines=10)
+    # system_prompt = gr.State(value=cfg.system_prompt_template)
+    system_prompt = gr.Textbox(label="System prompt", value=cfg.system_prompt_template, lines=10, visible=False)
     demo = gr.ChatInterface(respond, type="messages", additional_inputs=[system_prompt])
     demo.launch()
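
Note on the removals above: the ad-hoc cprint/print debug dumps are dropped here, while src/tools.py and src/vectorstore.py (below) move to per-module loggers. A minimal sketch of that pattern, assuming logging is configured once at the entry point (the basicConfig call is an assumption; this commit does not show where logging is configured):

    import logging

    log = logging.getLogger(__name__)  # one logger per module, as in src/tools.py below

    if __name__ == "__main__":
        # Assumed one-time setup at the entry point; not shown in this commit.
        logging.basicConfig(level=logging.INFO,
                            format="%(asctime)s %(name)s %(levelname)s: %(message)s")
        log.info("request params ready")  # plays the role of the removed debug prints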
config.yaml CHANGED
@@ -5,14 +5,22 @@ vdb:
   embeddings_model: BAAI/bge-m3
   number_of_contexts: 4
   vs_local_path: data/vdb
-  embedding_score_threshold: 0.4
+  embedding_score_threshold: 0.3
 
   # Context formatting parameters
   context_fmt: "Context document {num_document}:\n{document_content}"
   join_str: "\n\n"
   header_context_str: "The following is the context to help you answer the question (sorted from most to least relevant):\n\n"
   footer_context_str: "\n\nAnswer based only on the above context."
-  no_context_str: "Answer 'no relevant context found'."
+  no_context_str: |
+    Answer exactly with the following text respecting HTML tags:
+    "No relevant context found. Here are Aina Kit and Discord links for more information:
+    1. Aina Kit official: <a href="https://langtech-bsc.gitbook.io/aina-kit">https://langtech-bsc.gitbook.io/aina-kit</a>
+    2. Discord community: <a href="https://discord.com/invite/twy3Gn">https://discord.com/invite/twy3Gn</a>"
+
+  # https://discord.com/invite/twy3GnBCaY
+  # https://discord.com/invite/twy3GnBCaY
+  # https://langtech-bsc.gitbook.io/aina-kit
 
   # LLM client configuration
   # ================================================================================
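
The new no_context_str uses a YAML literal block scalar (|), so the HTML-tagged fallback text loads as one multi-line string with newlines preserved. A quick sketch of reading it, assuming the file is loaded with OmegaConf (as src/tools.py does) and that the keys sit under the vdb section shown in the hunk header above:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")
    print(cfg.vdb.embedding_score_threshold)  # 0.3 after this commit
    print(cfg.vdb.no_context_str)             # multi-line fallback text, newlines kept by "|"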
src/tools.py CHANGED
@@ -4,6 +4,8 @@ from typing import Dict, Union, get_origin, get_args
 from pydantic import BaseModel, Field
 from types import UnionType
 import logging
+log = logging.getLogger(__name__)
+
 from src.vectorstore import VectorStore
 from omegaconf import OmegaConf
 
@@ -104,14 +106,14 @@ def tool_register(cls: BaseModel):
 @tool_register
 class retrieve_aina_data(ToolBase):
     """Retrieves relevant information from Aina Challenge vectorstore, based on the user's query."""
-    logging.info("@tool_register: retrieve_aina_data()")
+    log.info("@tool_register: retrieve_aina_data()")
 
     query: str = Field(description="The user's input or question, used to search in Aina Challenge vectorstore.")
-    logging.info(f"query: {query}")
+    log.info(f"query: {query}")
 
     @classmethod
     def invoke(cls, input: Dict) -> str:
-        logging.info(f"retrieve_aina_data.invoke() input: {input}")
+        log.info(f"retrieve_aina_data.invoke() input: {input}")
 
         # Check if the input is a dictionary
         query = input.get("query", None)
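
One caveat worth knowing when reading this hunk: statements in a Python class body run once, when the class object is created, so the two log.info calls in retrieve_aina_data's body fire at import time (the second logs the Field descriptor, not a user query); only the call inside invoke() logs per request. A small illustration (the Demo class is hypothetical):

    import logging
    log = logging.getLogger(__name__)

    class Demo:
        # Runs once, at class-definition (import) time, like the
        # log.info calls in retrieve_aina_data's class body.
        log.info("class body executes at definition time")

        @classmethod
        def invoke(cls, input: dict) -> str:
            # Runs on every call: the right place for per-query logging.
            log.info(f"invoke() input: {input}")
            return str(input)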
src/vectorstore.py CHANGED
@@ -3,6 +3,7 @@ from langchain_community.vectorstores import FAISS
 from langchain_huggingface import HuggingFaceEmbeddings
 from huggingface_hub import snapshot_download
 import logging
+log = logging.getLogger(__name__)
 
 from termcolor import cprint
 
@@ -55,7 +56,7 @@ class VectorStore:
         Defaults to "\n\nAnswer based only on the above context.".
         """
 
-        logging.info("Loading vectorstore...")
+        log.info("Loading vectorstore...")
 
         # Retrieval parameters
         self.number_of_contexts = number_of_contexts
@@ -69,22 +70,22 @@
         self.no_context_str = no_context_str
 
         embeddings = HuggingFaceEmbeddings(model_name=embeddings_model)
-        logging.info(f"Loaded embeddings model: {embeddings_model}")
+        log.info(f"Loaded embeddings model: {embeddings_model}")
 
         if vs_hf_path:
             hf_vectorstore = snapshot_download(repo_id=vs_hf_path)
             self.vdb = FAISS.load_local(hf_vectorstore, embeddings, allow_dangerous_deserialization=True)
-            logging.info(f"Loaded vectorstore from {vs_hf_path}")
+            log.info(f"Loaded vectorstore from {vs_hf_path}")
         else:
             self.vdb = FAISS.load_local(vs_local_path, embeddings, allow_dangerous_deserialization=True)
-            logging.info(f"Loaded vectorstore from {vs_local_path}")
+            log.info(f"Loaded vectorstore from {vs_local_path}")
 
 
     def get_context(self, query,):
 
         # Retrieve documents
         results = self.vdb.similarity_search_with_relevance_scores(query=query, k=self.number_of_contexts, score_threshold=self.embedding_score_threshold)
-        logging.info(f"Retrieved {len(results)} documents from the vectorstore.")
+        log.info(f"Retrieved {len(results)} documents from the vectorstore.")
 
         # Return formatted context
         return self._beautiful_context(results)
@@ -92,7 +93,7 @@
 
     def _beautiful_context(self, docs):
 
-        logging.info(f"Formatting {len(docs)} contexts...")
+        log.info(f"Formatting {len(docs)} contexts...")
 
         # If no documents are retrieved, return the no_context_str
         if not docs:
@@ -101,6 +102,8 @@
         contexts = []
         for i, doc in enumerate(docs):
 
+            log.info(f"Document {i+1} (score: {doc[1]:.4f}): {repr(doc[0].page_content[:100])}...")
+
             # Format each context document using the provided template
             context = self.context_fmt.format(num_document=i + 1, document_content=doc[0].page_content)
             contexts.append(context)
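
For context on the threshold change (0.4 to 0.3 in config.yaml): similarity_search_with_relevance_scores returns (Document, score) pairs with scores normalized to [0, 1] and drops results below score_threshold, so lowering it admits weaker matches before the no_context_str fallback fires. A sketch of the call made in get_context(), with a hypothetical query and vs assumed to be a loaded VectorStore instance:

    results = vs.vdb.similarity_search_with_relevance_scores(
        query="How do I join the Aina Discord?",  # hypothetical query
        k=4,                  # number_of_contexts in config.yaml
        score_threshold=0.3,  # lowered from 0.4 in this commit
    )
    for doc, score in results:  # (Document, relevance score) pairs
        print(f"{score:.4f}  {doc.page_content[:80]!r}")
    # An empty result list is what triggers the new no_context_str fallback
    # in _beautiful_context().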