Commit: add image

Files changed:
- app.py (+47, -30)
- asset/teaser.png (+0, -0)

app.py — CHANGED
@@ -12,6 +12,9 @@ import openai
 import google.generativeai as genai
 
 
+# Set up LLM APIs
+llm_api_options = ['gemini-pro', 'gemini-1.5-flash', 'gpt-3.5-turbo-1106']
+
 def query_gpt_model(
     prompt: str,
     llm: str = 'gpt-3.5-turbo-1106',
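The hunk above lifts the list of supported backends to module level, next to `query_gpt_model`, whose body is not part of this diff. As a purely hypothetical sketch of what such a helper typically looks like with the OpenAI chat-completions client (parameter handling and return shape are assumptions; the real implementation in app.py may differ):

```python
import openai  # already imported at the top of app.py

def query_gpt_model_sketch(prompt: str,
                           llm: str = 'gpt-3.5-turbo-1106',
                           api_key: str = '') -> str:
    # Hypothetical helper mirroring the query_gpt_model signature shown above;
    # the api_key parameter and response handling are assumptions.
    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=llm,
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content
```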
@@ -81,8 +84,8 @@ def query_model(
     else:
         raise ValueError('Unexpected model_name: ', model_name)
 
-# Load QuALITY dataset
 
+# Load QuALITY dataset
 _ONE2ONE_FIELDS = (
     'article',
     'article_id',
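The `else: raise ValueError(...)` context implies that `query_model` dispatches on `model_name` between the Gemini and GPT backends named in `llm_api_options`. A minimal sketch of such a dispatch, assuming the Gemini branch goes through `google.generativeai` (the function's real body and parameter names are outside this diff):

```python
import google.generativeai as genai

def query_model_sketch(prompt: str, model_name: str, api_key: str = '') -> str:
    # Hypothetical dispatch; the exact branching and parameters are assumptions.
    if model_name.startswith('gemini'):
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name)
        return model.generate_content(prompt).text
    elif model_name.startswith('gpt'):
        # query_gpt_model is defined earlier in app.py (see the first hunk).
        return query_gpt_model(prompt, llm=model_name)
    else:
        raise ValueError('Unexpected model_name: ', model_name)
```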
@@ -166,7 +169,6 @@ def quality_gutenberg_parser(raw_article):
     return ' '.join(lines)
 
 
-
 # ReadAgent (1) Episode Pagination
 prompt_pagination_template = """
 You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.
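The pagination prompt describes a passage with "numbered labels between the paragraphs", presumably so the model can name a label at which the current episode (page) should end. As a purely illustrative sketch of how such labels could be interleaved before prompting (the label format and the real pagination code are not shown in this diff):

```python
def insert_break_labels_sketch(paragraphs: list[str]) -> str:
    # Hypothetical illustration only: interleave numbered labels between paragraphs
    # so the LLM can pick a break point. The <i> label format is an assumption.
    parts = []
    for i, paragraph in enumerate(paragraphs):
        parts.append(paragraph)
        parts.append(f'<{i}>')
    return '\n'.join(parts)
```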
@@ -250,8 +252,6 @@ def quality_pagination(example,
     text_output += f"\n\n[Pagination] Done with {len(pages)} pages"
     return pages, text_output
 
-# pages = quality_pagination(example)
-
 
 # ReadAgent (2) Memory Gisting
 prompt_shorten_template = """
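Memory Gisting applies `prompt_shorten_template` to each page produced by pagination. A hedged sketch of that per-page loop (the real `quality_gisting` body, the template text, and how a page is spliced into it are not shown here):

```python
def gist_pages_sketch(pages: list[str], query_fn) -> list[str]:
    # query_fn stands in for the app's LLM call (e.g. query_model); the template
    # below is a placeholder, not the real prompt_shorten_template.
    shorten_template = "Please shorten the following passage:\n\n{}"
    return [query_fn(shorten_template.format(page)) for page in pages]
```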
@@ -284,8 +284,6 @@ def quality_gisting(example, pages, model_name, client=None, word_limit=600, sta
     text_output += f"\n\ncompression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})"
     return output, text_output
 
-# example_with_gists = quality_gisting(example, pages)
-
 
 # ReadAgent (3) Look-Up
 prompt_lookup_template = """
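The compression-rate f-string in this hunk reports how much shorter the gists are than the original pages. A worked example of the same expression with made-up counts:

```python
# Illustrative numbers only.
word_count = 5000       # words across the original pages
gist_word_count = 900   # words across the gists
rate = round(100.0 - gist_word_count / word_count * 100, 2)
print(f"compression rate {rate}% ({gist_word_count}/{word_count})")
# -> compression rate 82.0% (900/5000)
```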
@@ -405,6 +403,7 @@ def quality_parallel_lookup(example, model_name, client, verbose=True):
     return text_outputs
 
 
+# ReadAgent
 def query_model_with_quality(
     index: int,
     model_name: str = 'gemini-pro',
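The name `quality_parallel_lookup` suggests the per-question look-up prompts are issued concurrently. A hypothetical sketch of that pattern with a thread pool (the real implementation is not part of this diff; the worker count is an arbitrary choice):

```python
from concurrent.futures import ThreadPoolExecutor

def parallel_lookup_sketch(lookup_prompts: list[str], query_fn, max_workers: int = 8) -> list[str]:
    # query_fn stands in for the app's LLM call; results come back in prompt order.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(query_fn, prompt) for prompt in lookup_prompts]
        return [future.result() for future in futures]
```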
@@ -421,38 +420,52 @@ def query_model_with_quality(
     genai.configure(api_key=api_key)
 
     example = quality_dev[index]
+    article = f"[Title: {example['title']}]\n\n{example['article']}"
     pages, pagination = quality_pagination(example, model_name, client)
     print('Finish Pagination.')
     example_with_gists, gisting = quality_gisting(example, pages, model_name, client)
     print('Finish Gisting.')
     answers = quality_parallel_lookup(example_with_gists, model_name, client)
-    return prompt_pagination_template, pagination, prompt_shorten_template, gisting, prompt_lookup_template, '\n\n'.join(answers)
+    # return prompt_pagination_template, pagination, prompt_shorten_template, gisting, prompt_lookup_template, '\n\n'.join(answers)
+    return article, pagination, gisting, '\n\n'.join(answers)
 
 
-llm_api_options = ['gemini-pro', 'gemini-1.5-flash', 'gpt-3.5-turbo-1106']
-
 with gr.Blocks() as demo:
     gr.Markdown(
         """
 # A Human-Inspired Reading Agent with Gist Memory of Very Long Contexts
+
+[[website]](https://read-agent.github.io/)
+[[view on huggingface]](https://huggingface.co/spaces/ReadAgent/read-agent)
+[[arXiv]](https://arxiv.org/abs/2402.09727)
+[[OpenReview]](https://openreview.net/forum?id=OTmcsyEO5G)
+
+
+
+The demo below showcases a version of the ReadAgent algorithm, which is inspired by how humans interactively read long documents.
+We implement ReadAgent as a simple prompting system that uses the advanced language capabilities of LLMs to (1) decide what content to store together in a memory episode (**Episode Pagination**), (2) compress those memory episodes into short episodic memories called gist memories (**Memory Gisting**), and (3) take actions to look up passages in the original text if ReadAgent needs to remind itself of relevant details to complete a task (**Parallel Lookup and QA**).
+This demo can handle long-document reading comprehension tasks ([QuALITY](https://arxiv.org/abs/2112.08608); max 6,000 words) efficiently.
+
+To get started, you can choose an index of the QuALITY dataset.
+This demo uses the Gemini API or the OpenAI API, so it requires the corresponding API key.
         """)
-    with gr.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    with gr.Row():
+        with gr.Column():
+            llm_options = gr.Radio(llm_api_options, label="Backend LLM API", value='gemini-pro')
+            llm_api_key = gr.Textbox(
+                label="Paste your OpenAI API key (sk-...) or Gemini API key",
+                lines=1,
+                type="password",
+            )
+            index = gr.Dropdown(list(range(len(quality_dev))), value=13, label="QuALITY Index")
+            button = gr.Button("Execute")
+    original_article = gr.Textbox(label="Original Article", lines=20)
+    # prompt_pagination = gr.Textbox(label="Episode Pagination Prompt Template", lines=5)
+    pagination_results = gr.Textbox(label="(1) Episode Pagination", lines=20)
+    # prompt_gisting = gr.Textbox(label="Memory Gisting Prompt Template", lines=5)
+    gisting_results = gr.Textbox(label="(2) Memory Gisting", lines=20)
+    # prompt_lookup = gr.Textbox(label="Parallel Lookup Prompt Template", lines=5)
+    lookup_qa_results = gr.Textbox(label="(3) Parallel Lookup and QA", lines=20)
 
     button.click(
         fn=query_model_with_quality,
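In the hunk above, `query_model_with_quality` now returns four values (article, pagination, gisting, joined answers) and the UI gains four matching output textboxes. The `button.click` wiring is positional: each input component feeds one argument, each returned value fills one output. A self-contained toy example of the same pattern (not the app's code; component names here are invented):

```python
import gradio as gr

def toy_fn(index: int, model_name: str, api_key: str):
    # Returns four values, one per output component below.
    return f"article #{index}", "pagination log", "gisting log", f"answers via {model_name}"

with gr.Blocks() as toy:
    idx = gr.Dropdown([0, 1, 2], value=0, label="Index")
    backend = gr.Radio(["gemini-pro", "gpt-3.5-turbo-1106"], value="gemini-pro", label="Backend")
    key = gr.Textbox(type="password", label="API key")
    run = gr.Button("Execute")
    out_article = gr.Textbox(label="Original Article")
    out_pages = gr.Textbox(label="(1) Episode Pagination")
    out_gists = gr.Textbox(label="(2) Memory Gisting")
    out_qa = gr.Textbox(label="(3) Parallel Lookup and QA")
    # inputs map positionally to toy_fn's arguments; outputs to its return tuple.
    run.click(fn=toy_fn, inputs=[idx, backend, key],
              outputs=[out_article, out_pages, out_gists, out_qa])
```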
@@ -462,12 +475,16 @@ with gr.Blocks() as demo:
             llm_api_key,
         ],
         outputs=[
-            prompt_pagination, pagination_results,
-            prompt_gisting, gisting_results,
-            prompt_lookup, lookup_qa_results,
+            # prompt_pagination, pagination_results,
+            # prompt_gisting, gisting_results,
+            # prompt_lookup, lookup_qa_results,
+            original_article,
+            pagination_results,
+            gisting_results,
+            lookup_qa_results,
         ]
     )
 
 
 if __name__ == '__main__':
-    demo.launch()
+    demo.launch(allowed_paths=['./asset/teaser.png'])
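The new `allowed_paths` argument is what lets the `` image added to the Markdown block actually render: Gradio only serves local files over its `/file=` route when they are whitelisted at launch time. An equivalent, slightly broader sketch that whitelists the whole asset directory instead of the single file:

```python
if __name__ == '__main__':
    # Assumption: whitelisting the directory also covers asset/teaser.png;
    # the commit itself whitelists only the single file.
    demo.launch(allowed_paths=['./asset'])
```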
asset/teaser.png — ADDED (new binary file)