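"""Zotero search demo: a FastHTML app serving full-text search over arXiv
papers stored in LanceDB, re-ranked via a remote /rerank API, alongside
page-image results ("Byaldi Results") fetched from a remote /get_pages server."""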
import json
import os
from datetime import datetime

import dotenv
import lancedb
import requests
from datasets import load_dataset
from fasthtml.common import *  # noqa
from huggingface_hub import login, whoami
server_ip = "147.189.194.113"
# server_ip = "47.47.180.31"


def get_images(query: str):
    url = f"http://{server_ip}:80/get_pages"
    response = requests.get(url, params={"query": query})
    return response.json()
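# Note: the /get_pages endpoint is assumed to return a JSON list of
# base64-encoded page images (without a "data:...;base64," prefix),
# since ImageResult() below embeds each string directly into a data URI.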
def rerank_api(query, docs):
    url = f"http://{server_ip}:80/rerank"
    data = {"query": query, "docs": docs}
    response = requests.post(url, json=data)  # use POST and send the payload as JSON
    return response.json()
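# Note: the /rerank endpoint is assumed to return JSON of the form
# {"ranked_doc_ids": [<index into docs>, ...]} ordered best-first, since
# retrieve_and_rerank() below indexes back into `retrieved` with those ids.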
dotenv.load_dotenv()
login(token=os.environ.get("HF_TOKEN"))
hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]

HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
abstract_ds = load_dataset(HF_REPO_ID_TXT, "abstracts")["train"]
article_ds = load_dataset(HF_REPO_ID_TXT, "articles")["train"]

# ranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")

uri = "data/zotero-fts"
db = lancedb.connect(uri)

id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
id2title = {example["arxiv_id"]: example["title"] for example in article_ds}
arxiv_ids = set(id2abstract.keys())
data = []
for arxiv_id in arxiv_ids:
    abstract = id2abstract[arxiv_id]
    title = id2title[arxiv_id]
    # concatenate the title and every section (heading + body) into one searchable string
    full_text = title
    for item in id2content[arxiv_id]:
        full_text += f"\n\n{item['title']}\n\n{item['content']}"
    data.append(
        {
            "arxiv_id": arxiv_id,
            "title": title,
            "abstract": abstract,
            "full_text": full_text,
        }
    )
table = db.create_table("articles", data=data, mode="overwrite")
table.create_fts_index("full_text", replace=True)
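# A minimal sanity check of the FTS index (a sketch with a hypothetical
# query; uncomment to try locally):
# hits = table.search("attention", vector_column_name="", query_type="fts").limit(3).to_list()
# print([h["title"] for h in hits])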
# format results ----
def _format_results(results):
    ret = []
    for result in results:
        arx_id = result["arxiv_id"]
        title = result["title"]
        abstract = result["abstract"]
        if "Abstract\n\n" in abstract:
            abstract = abstract.split("Abstract\n\n")[-1]
        this_ex = {
            "title": title,
            "url": f"https://arxiv.org/abs/{arx_id}",
            "abstract": abstract,
        }
        ret.append(this_ex)
    return ret
def retrieve_and_rerank(query, k=3):
    # retrieve ---
    n_fetch = 25
    retrieved = (
        table.search(query, vector_column_name="", query_type="fts")
        .limit(n_fetch)
        .select(["arxiv_id", "title", "abstract"])
        .to_list()
    )
    print(f"Retrieved {len(retrieved)} documents")

    # re-rank ---
    docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
    # results = ranker.rank(query=query, docs=docs)
    ranked_doc_ids = rerank_api(query, docs)["ranked_doc_ids"][:k]
    # ranked_doc_ids = [result.doc_id for result in results[:k]]

    final_results = [retrieved[idx] for idx in ranked_doc_ids]
    final_results = _format_results(final_results)
    return final_results
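# Example usage (a sketch with a hypothetical query, assuming the rerank
# server is reachable):
# for r in retrieve_and_rerank("mixture of experts", k=3):
#     print(r["title"], r["url"])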
###########################################################################
# FastHTML app -----
###########################################################################
style = Style("""
:root {
    color-scheme: dark;
}
body {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    line-height: 1.6;
}
#query {
    width: 100%;
    margin-bottom: 1rem;
}
#search-form button {
    width: 100%;
}
#search-results, #log-entries {
    margin-top: 2rem;
}
.log-entry {
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
.log-entry pre {
    white-space: pre-wrap;
    word-wrap: break-word;
}
.htmx-indicator {
    display: none;
}
.htmx-request .htmx-indicator {
    display: inline-block;
}
.spinner {
    display: inline-block;
    width: 2.5em;
    height: 2.5em;
    border: 0.3em solid rgba(255, 255, 255, .3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
    margin-left: 10px;
    vertical-align: middle;
}
@keyframes spin {
    to { transform: rotate(360deg); }
}
.searching-text {
    font-size: 1.2em;
    font-weight: bold;
    color: #fff;
    margin-right: 10px;
    vertical-align: middle;
}
.image-results {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    margin-top: 20px;
}
.image-result {
    width: calc(33% - 10px);
    text-align: center;
}
.image-result img {
    max-width: 100%;
    height: auto;
    border-radius: 5px;
}
""")
# get the fast app and router
app, rt = fast_app(hdrs=(style,))
# Initialize a database to store search logs --
os.makedirs("log_data", exist_ok=True)  # sqlite does not create missing parent dirs
log_db = database("log_data/search_logs.db")  # renamed to avoid shadowing the LanceDB connection
search_logs = log_db.t.search_logs
if search_logs not in log_db.t:
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
SearchLog = search_logs.dataclass()
def insert_log_entry(log_entry):
    "Insert a log entry into the database"
    return search_logs.insert(
        SearchLog(
            timestamp=log_entry["timestamp"].isoformat(),
            query=log_entry["query"],
            results=json.dumps(log_entry["results"]),
        )
    )
@rt("/")
async def get():
    query_form = Form(
        Textarea(id="query", name="query", placeholder="Enter your query..."),
        Button("Submit", type="submit"),
        Div(
            Span("Searching...", cls="searching-text htmx-indicator"),
            Span(cls="spinner htmx-indicator"),
            cls="indicator-container",
        ),
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    results_div = Div(Div(id="search-results", cls="results-container"))
    view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")
    return Titled(
        "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
    )
def SearchResult(result):
    "Custom component for displaying a search result"
    return Card(
        H4(A(result["title"], href=result["url"], target="_blank")),
        P(result["abstract"]),
        footer=A("Read more →", href=result["url"], target="_blank"),
    )
# def base64_to_pil(base64_string):
#     # Remove the "data:image/png;base64," part if it exists
#     if "base64," in base64_string:
#         base64_string = base64_string.split("base64,")[1]
#     # Decode the base64 string
#     img_data = base64.b64decode(base64_string)
#     # Open the image using PIL
#     img = Image.open(BytesIO(img_data))
#     return img


# def process_image(image, max_size=(500, 500), quality=85):
#     pil_image = base64_to_pil(image)
#     img_byte_arr = io.BytesIO()
#     pil_image.thumbnail(max_size)
#     pil_image.save(img_byte_arr, format="JPEG", quality=quality, optimize=True)
#     return f"data:image/jpeg;base64,{base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')}"
def ImageResult(image):
    return Div(
        Img(src=f"data:image/jpeg;base64,{image}", alt="arxiv image"),
        cls="image-result",
    )
# def ImageResult(image):
#     return Div(
#         Img(src=process_image(image), alt="arxiv image"),
#         cls="image-result",
#     )
def log_query_and_results(query, results):
    log_entry = {
        "timestamp": datetime.now(),
        "query": query,
        "results": [{"title": r["title"], "url": r["url"]} for r in results],
    }
    insert_log_entry(log_entry)
@rt("/search")
async def post(query: str):
    image_results = get_images(query)
    results = retrieve_and_rerank(query)
    log_query_and_results(query, results)
    return Div(
        Br(),
        H3("Byaldi Results"),
        Div(*[ImageResult(img) for img in image_results], cls="image-results"),
        Br(),
        H3("Text Results"),
        Div(*[SearchResult(r) for r in results], id="text-results"),
        id="search-results",
    )
def LogEntry(entry):
    return Div(
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
        cls="log-entry",
    )
@rt("/logs")
async def get():
    logs = search_logs(order_by="-id", limit=50)  # get the latest 50 logs
    log_entries = [LogEntry(log) for log in logs]
    return Titled(
        "Logs",
        Div(
            H2("Recent Search Logs"),
            Div(*log_entries, id="log-entries"),
            A("Back to Search", href="/", cls="back-link"),
            cls="container",
        ),
    )
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
    # run_uv()