Commit · 5c20978
1 Parent(s): e87a6a0

feat: replace postgres with sqlite

Files changed:
- app.py (+10 -2)
- custom_pgvector.py (+47 -25)
app.py
CHANGED

@@ -2,6 +2,7 @@ import json
 import os
 
 import sqlalchemy
+import sqlite_vss
 import streamlit as st
 import streamlit.components.v1 as components
 from langchain import OpenAI
@@ -9,13 +10,14 @@ from langchain.callbacks import get_openai_callback
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.conversation.memory import ConversationBufferMemory
 from langchain.embeddings import GPT4AllEmbeddings
+from sqlalchemy import event
 
 from chat_history import insert_chat_history, insert_chat_history_articles
 from css import load_css
 from custom_pgvector import CustomPGVector
 from message import Message
 
-CONNECTION_STRING = "
+CONNECTION_STRING = "sqlite:///data/sorbobot.db"
 
 st.set_page_config(layout="wide")
 
@@ -26,10 +28,16 @@ chat_column, doc_column = st.columns([2, 1])
 
 def connect() -> sqlalchemy.engine.Connection:
     engine = sqlalchemy.create_engine(CONNECTION_STRING)
+
+    @event.listens_for(engine, "connect")
+    def receive_connect(connection, _):
+        connection.enable_load_extension(True)
+        sqlite_vss.load(connection)
+        connection.enable_load_extension(False)
+
     conn = engine.connect()
     return conn
 
-
 conn = connect()
 
 
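The connect() change is what makes the rest of the commit work: SQLAlchemy's "connect" event fires for every new DBAPI connection the engine creates, so the sqlite-vss extension is loaded before any query runs on that connection. Below is a minimal standalone sketch of the same pattern, assuming the sqlite_vss package is installed and the local Python/SQLite build allows loadable extensions (some distributions compile this out); the vss_version() check is only there to confirm the extension actually loaded.

import sqlalchemy
import sqlite_vss
from sqlalchemy import event, text

engine = sqlalchemy.create_engine("sqlite:///data/sorbobot.db")

@event.listens_for(engine, "connect")
def receive_connect(dbapi_connection, _record):
    # Runs once per pooled DBAPI connection, so vss0/vector0 are
    # available on every connection handed out by the engine.
    dbapi_connection.enable_load_extension(True)
    sqlite_vss.load(dbapi_connection)
    dbapi_connection.enable_load_extension(False)

with engine.connect() as conn:
    # Resolves only if the extension was loaded on this connection.
    print(conn.execute(text("select vss_version()")).scalar())

Note that the listener must be registered before the first engine.connect() call, as in the commit; connections created earlier would miss the extension.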
custom_pgvector.py
CHANGED

@@ -4,6 +4,7 @@ import contextlib
 import enum
 import json
 import logging
+import struct
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type
 
 import pandas as pd
@@ -348,33 +349,54 @@ class CustomPGVector(VectorStore):
         k: int = 4,
     ) -> List[Any]:
         """Query the collection."""
-
-
-
-
+        vector = bytearray(struct.pack("f" * len(embedding), *embedding))
+
+        cursor = self._conn.execute(
+            text("""
+                with matches as (
                     select
-
-
-
-
-
-
-
-
-                    left join article_keyword ON article_keyword.article_id = a.id
-                    left join keyword on article_keyword.keyword_id = keyword.id
-                    left join article_author ON article_author.article_id = a.id
-                    left join author on author.id = article_author.author_id
-                    where abstract != 'NaN'
-                    GROUP BY a.id
-                    ORDER BY distance
-                    LIMIT {k};
-                    """
+                        rowid,
+                        distance
+                    from vss_article
+                    where vss_search(
+                        abstract_embedding,
+                        :vector
+                    )
+                    limit :limit
                 )
-
-
-
-
+                select
+                    article.id,
+                    article.title,
+                    article.doi,
+                    article.abstract,
+                    group_concat(keyword."name", ',') as keywords,
+                    group_concat(author."name", ',') as authors,
+                    matches.distance
+                from matches
+                left join article on matches.rowid = article.rowid
+                left join article_keyword ak ON ak.article_id = article.id
+                left join keyword on ak.keyword_id = keyword.id
+                left join article_author ON article_author.article_id = article.id
+                left join author on author.id = article_author.author_id
+                group by article.id
+                order by distance;
+                """),
+            {"vector": vector, "limit": k}
+        )
+        results = cursor.fetchall()
+        results = pd.DataFrame(
+            results,
+            columns=[
+                "id",
+                "title",
+                "doi",
+                "abstract",
+                "keywords",
+                "authors",
+                "distance",
+            ],
+        )
+        results = results.to_dict(orient="records")
         return results
 
     def similarity_search_by_vector(
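The rewritten query does two things: it packs the query embedding into a raw float32 blob with struct.pack, and it pushes the nearest-neighbour search into the vss_article virtual table via vss_search(), joining the matching rowids back onto the article tables afterwards. The sketch below shows that blob-plus-vss_search pattern in isolation against a throwaway in-memory database; the table and column names (demo_vss, embedding) and the 3-dimensional vectors are illustrative only and not the project's schema.

import sqlite3
import struct

import sqlite_vss


def serialize(vector: list[float]) -> bytes:
    # Same packing as in the commit: one native-endian 4-byte float per dimension.
    return struct.pack("f" * len(vector), *vector)


db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vss.load(db)
db.enable_load_extension(False)

# Hypothetical 3-dimensional table; the app's real table is vss_article with
# an abstract_embedding column sized to the embedding model's dimension.
db.execute("create virtual table demo_vss using vss0(embedding(3))")
db.executemany(
    "insert into demo_vss(rowid, embedding) values (?, ?)",
    [
        (1, serialize([1.0, 0.0, 0.0])),
        (2, serialize([0.0, 1.0, 0.0])),
        (3, serialize([0.9, 0.1, 0.0])),
    ],
)
db.commit()

# Two nearest neighbours of the query vector, closest first.
rows = db.execute(
    "select rowid, distance from demo_vss "
    "where vss_search(embedding, ?) limit 2",
    (serialize([1.0, 0.0, 0.0]),),
).fetchall()
print(rows)  # e.g. [(1, 0.0), (3, ...)]

One caveat to hedge: depending on the SQLite version bundled with Python, the two-argument vss_search() combined with a plain LIMIT can be rejected at query time; sqlite-vss documents vss_search_params(vector, k) as the fallback form, which this commit does not use.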