import json
import os
from datetime import datetime
from typing import ClassVar

# import dotenv
import lancedb
import srsly
from fasthtml.common import *  # noqa
from fasthtml_hf import setup_hf_backup
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise

# dotenv.load_dotenv()

# download the zotero index (~1200 papers as of July '24, currently hosted on HF) ----
def download_data():
    snapshot_download(
        repo_id="rbiswasfc/zotero_db",
        repo_type="dataset",
        local_dir="./data",
        token=os.environ["HF_TOKEN"],
    )
    print("Data downloaded!")


if not os.path.exists(
    "./data/.lancedb_zotero_v0"
):  # TODO: implement a better check / refresh mechanism
    download_data()


# cohere embedding utils ----
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    name: str = "embed-english-v3.0"
    client: ClassVar = None

    def ndims(self):
        # embed-english-v3.0 returns 1024-dim vectors (matches Vector(1024) in ArxivModel below)
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )
        return [emb for emb in rs.embeddings]

    def _init_client(self):
        cohere = attempt_import_or_raise("cohere")
        if CohereEmbeddingFunction_2.client is None:
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )


COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()
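# Minimal sanity-check sketch (not run at import time; assumes COHERE_API_KEY is set and the
# Cohere API is reachable):
#   vecs = COHERE_EMBEDDER.generate_embeddings(["retrieval augmented generation"])
#   assert len(vecs[0]) == COHERE_EMBEDDER.ndims()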


# LanceDB model ----
class ArxivModel(LanceModel):
    text: str = COHERE_EMBEDDER.SourceField()
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    title: str
    paper_title: str
    content_type: str
    arxiv_id: str


VERSION = "0.0.0a"
DB = lancedb.connect("./data/.lancedb_zotero_v0")
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
TBL = DB.open_table("arxiv_zotero_v0")
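# The table itself is assumed to have been built offline with this schema, roughly (sketch, not
# run here):
#   DB.create_table("arxiv_zotero_v0", schema=ArxivModel)
# Hybrid search below additionally assumes a full-text index on the text column, e.g.
#   TBL.create_fts_index("text")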


# format results ----
def _format_results(arxiv_refs):
    results = []
    for arx_id, paper_title in arxiv_refs.items():
        abstract = ID_TO_ABSTRACT.get(arx_id, "")

        # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
        if "Abstract\n\n" in abstract:
            abstract = abstract.split("Abstract\n\n")[-1]
        if paper_title in abstract:
            abstract = abstract.split(paper_title)[-1]
        if abstract.startswith("\n"):
            abstract = abstract[1:]
        if "\n\n" in abstract[:20]:
            abstract = "\n\n".join(abstract.split("\n\n")[1:])

        result = {
            "title": paper_title,
            "url": f"https://arxiv.org/abs/{arx_id}",
            "abstract": abstract,
        }
        results.append(result)
    return results
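# Shape sketch (the arXiv id here is hypothetical):
#   _format_results({"2401.00001": "Some Paper Title"})
#   -> [{"title": "Some Paper Title",
#        "url": "https://arxiv.org/abs/2401.00001",
#        "abstract": "..."}]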


# Search logic ----
def query_db(query: str, k: int = 10, reranker: str = "cohere"):
    raw_results = TBL.search(query, query_type="hybrid").limit(k)

    if reranker is not None:
        ranked_results = raw_results.rerank(reranker=RERANKERS[reranker])
    else:
        ranked_results = raw_results
    ranked_results = ranked_results.to_pandas()

    top_results = ranked_results.groupby("arxiv_id").agg({"_relevance_score": "sum"})
    top_results = top_results.sort_values(by="_relevance_score", ascending=False).head(3)

    top_results_dict = {
        row["arxiv_id"]: row["paper_title"]
        for index, row in ranked_results.iterrows()
        if row["arxiv_id"] in top_results.index
    }

    final_results = _format_results(top_results_dict)
    return final_results
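# Usage sketch (assumes the index above downloaded correctly and the API keys are set):
#   hits = query_db("state space models for long sequences", k=10, reranker="cohere")
#   for hit in hits:
#       print(hit["title"], hit["url"])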


###########################################################################
# FastHTML app -----
###########################################################################
style = Style("""
    :root {
        color-scheme: dark;
    }
    body {
        max-width: 1200px;
        margin: 0 auto;
        padding: 20px;
        line-height: 1.6;
    }
    #query {
        width: 100%;
        margin-bottom: 1rem;
    }
    #search-form button {
        width: 100%;
    }
    #search-results, #log-entries {
        margin-top: 2rem;
    }
    .log-entry {
        border: 1px solid #ccc;
        padding: 10px;
        margin-bottom: 10px;
    }
    .log-entry pre {
        white-space: pre-wrap;
        word-wrap: break-word;
    }
    .htmx-indicator {
        display: none;
    }
    .htmx-request .htmx-indicator {
        display: inline-block;
    }
    .spinner {
        display: inline-block;
        width: 2.5em;
        height: 2.5em;
        border: 0.3em solid rgba(255,255,255,.3);
        border-radius: 50%;
        border-top-color: #fff;
        animation: spin 1s ease-in-out infinite;
        margin-left: 10px;
        vertical-align: middle;
    }
    @keyframes spin {
        to { transform: rotate(360deg); }
    }
    .searching-text {
        font-size: 1.2em;
        font-weight: bold;
        color: #fff;
        margin-right: 10px;
        vertical-align: middle;
    }
""")

# get the fast app and route
app, rt = fast_app(live=True, hdrs=(style,))

# Initialize a database to store search logs --
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs
if search_logs not in db.t:
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
SearchLog = search_logs.dataclass()


def insert_log_entry(log_entry):
    "Insert a log entry into the database"
    return search_logs.insert(
        SearchLog(
            timestamp=log_entry["timestamp"].isoformat(),
            query=log_entry["query"],
            results=json.dumps(log_entry["results"]),
        )
    )
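# Sketch (values illustrative): a log entry is a dict like
#   {"timestamp": datetime.now(), "query": "mixture of experts", "results": [{"title": ..., "url": ...}]}
# and is stored as a row with an ISO timestamp and JSON-encoded results.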


@rt("/")
async def get():
    query_form = Form(
        Textarea(id="query", name="query", placeholder="Enter your query..."),
        Button("Submit", type="submit"),
        Div(
            Span("Searching...", cls="searching-text htmx-indicator"),
            Span(cls="spinner htmx-indicator"),
            cls="indicator-container",
        ),
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    # results_div = Div(H2("Search Results"), Div(id="search-results", cls="results-container"))
    results_div = Div(Div(id="search-results", cls="results-container"))
    view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")

    return Titled(
        "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
    )


def SearchResult(result):
    "Custom component for displaying a search result"
    return Card(
        H4(A(result["title"], href=result["url"], target="_blank")),
        P(result["abstract"]),
        footer=A("Read more →", href=result["url"], target="_blank"),
    )


def log_query_and_results(query, results):
    log_entry = {
        "timestamp": datetime.now(),
        "query": query,
        "results": [{"title": r["title"], "url": r["url"]} for r in results],
    }
    insert_log_entry(log_entry)


@rt("/search")
async def post(query: str):
    results = query_db(query)
    log_query_and_results(query, results)
    return Div(*[SearchResult(r) for r in results], id="search-results")
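# The search form above posts here via htmx; a manual check against a locally running app
# (assuming the default port below) might look like:
#   curl -X POST -d "query=state space models" http://localhost:7860/search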


def LogEntry(entry):
    return Div(
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
        cls="log-entry",
    )


@rt("/logs")
async def get():
    logs = search_logs(order_by="-id", limit=50)  # Get the latest 50 logs
    log_entries = [LogEntry(log) for log in logs]
    return Titled(
        "Logs",
        Div(
            H2("Recent Search Logs"),
            Div(*log_entries, id="log-entries"),
            A("Back to Search", href="/", cls="back-link"),
            cls="container",
        ),
    )


if __name__ == "__main__":
    import uvicorn

    setup_hf_backup(app)
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
    # run_uv()