Spaces:
Paused
Paused
Commit
·
b7068fd
1
Parent(s):
cf57696
question and answer postprocessing
Browse files
- benchmark/__main__.py +1 -0
- qa_engine/qa_engine.py +30 -1
benchmark/__main__.py
CHANGED
|
@@ -33,6 +33,7 @@ def main():
|
|
| 33 |
|
| 34 |
wandb.init(
|
| 35 |
project='HF-Docs-QA',
|
|
|
|
| 36 |
name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
|
| 37 |
mode='run', # run/disabled
|
| 38 |
config=filtered_config
|
|
|
|
| 33 |
|
| 34 |
wandb.init(
|
| 35 |
project='HF-Docs-QA',
|
| 36 |
+
entity='hf-qa-bot',
|
| 37 |
name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
|
| 38 |
mode='run', # run/disabled
|
| 39 |
config=filtered_config
|
qa_engine/qa_engine.py
CHANGED
|
@@ -228,6 +228,33 @@ class QAEngine():
|
|
| 228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
| 229 |
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
| 232 |
"""
|
| 233 |
Generate an answer to the specified question.
|
|
@@ -271,7 +298,9 @@ class QAEngine():
|
|
| 271 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
| 272 |
|
| 273 |
logger.info('Running LLM chain')
|
| 274 |
-
|
|
|
|
|
|
|
| 275 |
response.set_answer(answer)
|
| 276 |
logger.info('Received answer')
|
| 277 |
|
|
|
|
| 228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
| 229 |
|
| 230 |
|
| 231 |
+
@staticmethod
def _preprocess_question(question: str) -> str:
    """
    Ensure the question ends with a question mark.

    Args:
        question: raw question text.

    Returns:
        The question with a trailing '?' appended when it did not already
        end with one. An empty string is returned unchanged.
    """
    # Guard on emptiness and use endswith: indexing question[-1] raises
    # IndexError when the question is an empty string.
    if question and not question.endswith('?'):
        question += '?'
    return question
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@staticmethod
def _postprocess_answer(answer: str) -> str:
    """
    Postprocess the answer by removing unnecessary sequences and cutting it
    at the first stop sequence.

    Args:
        answer: raw answer text.

    Returns:
        The answer with every remove-sequence deleted, truncated before the
        earliest stop sequence, and stripped of surrounding whitespace.
    """
    remove_sequences = (
        'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]',
    )
    stop_sequences = ('\nUser:', '\nYou:')

    for seq in remove_sequences:
        answer = answer.replace(seq, '')
    for seq in stop_sequences:
        # partition yields the whole string as its head when seq is absent,
        # so no separate membership check is needed.
        answer = answer.partition(seq)[0]
    return answer.strip()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
| 259 |
"""
|
| 260 |
Generate an answer to the specified question.
|
|
|
|
| 298 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
| 299 |
|
| 300 |
logger.info('Running LLM chain')
|
| 301 |
+
question_processed = QAEngine._preprocess_question(question)
|
| 302 |
+
answer = self.llm_chain.run(question=question_processed, context=context)
|
| 303 |
+
answer = QAEngine._postprocess_answer(answer)
|
| 304 |
response.set_answer(answer)
|
| 305 |
logger.info('Received answer')
|
| 306 |
|