Spaces:
Sleeping
Sleeping
Léo Bourrel
committed on
Commit
·
24d1b6f
1
Parent(s):
b8c8744
feat: replace sqlalchemy query by executing SQL
Browse files - custom_pgvector.py +33 -14
custom_pgvector.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
-
|
| 3 |
import asyncio
|
| 4 |
import contextlib
|
| 5 |
import enum
|
|
@@ -28,6 +28,7 @@ from langchain.vectorstores.utils import maximal_marginal_relevance
|
|
| 28 |
from pgvector.sqlalchemy import Vector
|
| 29 |
from sqlalchemy import delete
|
| 30 |
from sqlalchemy.orm import Session, declarative_base, relationship
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
class DistanceStrategy(str, enum.Enum):
|
|
@@ -348,16 +349,19 @@ class CustomPGVector(VectorStore):
|
|
| 348 |
docs = [
|
| 349 |
(
|
| 350 |
Document(
|
| 351 |
-
page_content=result.
|
| 352 |
metadata={
|
| 353 |
-
"id": result.
|
| 354 |
-
"title": result.
|
| 355 |
-
"
|
|
|
|
|
|
|
|
|
|
| 356 |
},
|
| 357 |
),
|
| 358 |
result.distance if self.embedding_function is not None else None,
|
| 359 |
)
|
| 360 |
-
for result in results
|
| 361 |
]
|
| 362 |
return docs
|
| 363 |
|
|
@@ -369,16 +373,31 @@ class CustomPGVector(VectorStore):
|
|
| 369 |
) -> List[Any]:
|
| 370 |
"""Query the collection."""
|
| 371 |
with Session(self._conn) as session:
|
| 372 |
-
results
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
)
|
| 377 |
-
.order_by(sqlalchemy.asc("distance"))
|
| 378 |
-
.limit(k)
|
| 379 |
-
.all()
|
| 380 |
)
|
| 381 |
-
|
|
|
|
| 382 |
return results
|
| 383 |
|
| 384 |
def similarity_search_by_vector(
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
import pandas as pd
|
| 3 |
import asyncio
|
| 4 |
import contextlib
|
| 5 |
import enum
|
|
|
|
| 28 |
from pgvector.sqlalchemy import Vector
|
| 29 |
from sqlalchemy import delete
|
| 30 |
from sqlalchemy.orm import Session, declarative_base, relationship
|
| 31 |
+
from sqlalchemy import text
|
| 32 |
|
| 33 |
|
| 34 |
class DistanceStrategy(str, enum.Enum):
|
|
|
|
| 349 |
docs = [
|
| 350 |
(
|
| 351 |
Document(
|
| 352 |
+
page_content=result.abstract,
|
| 353 |
metadata={
|
| 354 |
+
"id": result.id,
|
| 355 |
+
"title": result.title,
|
| 356 |
+
"authors": result.authors,
|
| 357 |
+
"doi": result.doi,
|
| 358 |
+
"keywords": results.keywords,
|
| 359 |
+
"distance": results.distance,
|
| 360 |
},
|
| 361 |
),
|
| 362 |
result.distance if self.embedding_function is not None else None,
|
| 363 |
)
|
| 364 |
+
for result in results.itertuples()
|
| 365 |
]
|
| 366 |
return docs
|
| 367 |
|
|
|
|
| 373 |
) -> List[Any]:
|
| 374 |
"""Query the collection."""
|
| 375 |
with Session(self._conn) as session:
|
| 376 |
+
results = session.execute(
|
| 377 |
+
text(
|
| 378 |
+
f"""
|
| 379 |
+
select
|
| 380 |
+
a.id,
|
| 381 |
+
a.title,
|
| 382 |
+
a.doi,
|
| 383 |
+
a.abstract,
|
| 384 |
+
string_agg(distinct keyword."name", ',') as keywords,
|
| 385 |
+
string_agg(distinct author."name", ',') as authors,
|
| 386 |
+
abstract_embedding <-> '{str(embedding)}' as distance
|
| 387 |
+
from article a
|
| 388 |
+
left join article_keyword ON article_keyword.article_id = a.id
|
| 389 |
+
left join keyword on article_keyword.keyword_id = keyword.id
|
| 390 |
+
left join article_author ON article_author.article_id = a.id
|
| 391 |
+
left join author on author.id = article_author.author_id
|
| 392 |
+
where abstract != 'NaN'
|
| 393 |
+
GROUP BY a.id
|
| 394 |
+
ORDER BY distance
|
| 395 |
+
LIMIT {k};
|
| 396 |
+
"""
|
| 397 |
)
|
|
|
|
|
|
|
|
|
|
| 398 |
)
|
| 399 |
+
results = results.fetchall()
|
| 400 |
+
results = pd.DataFrame(results, columns=["id", "title", "doi", "abstract", "keywords", "authors", "distance"])
|
| 401 |
return results
|
| 402 |
|
| 403 |
def similarity_search_by_vector(
|