Spaces:
Paused
Paused
Try changing the embedding model
Browse files
app.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from langchain.chains import RetrievalQA
|
| 3 |
-
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
|
|
| 4 |
from langchain.llms import OpenAI
|
| 5 |
from langchain.chat_models import ChatOpenAI
|
| 6 |
from langchain.vectorstores import Qdrant
|
|
@@ -16,7 +19,16 @@ PERSIST_DIR_NAME = "nvdajp-book"
|
|
| 16 |
|
| 17 |
|
| 18 |
def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
|
| 19 |
-
embeddings = OpenAIEmbeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
| 21 |
client = QdrantClient(url=db_url, api_key=db_api_key)
|
| 22 |
db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
|
|
@@ -36,7 +48,7 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
|
|
| 36 |
"filter": {"category": option},
|
| 37 |
}
|
| 38 |
)
|
| 39 |
-
|
| 40 |
llm=ChatOpenAI(
|
| 41 |
model=model,
|
| 42 |
temperature=temperature
|
|
@@ -45,6 +57,7 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
|
|
| 45 |
retriever=retriever,
|
| 46 |
return_source_documents=True,
|
| 47 |
)
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def get_related_url(metadata):
|
|
|
|
| 1 |
+
from time import time
|
| 2 |
import gradio as gr
|
| 3 |
from langchain.chains import RetrievalQA
|
| 4 |
+
# from langchain.embeddings import OpenAIEmbeddings
|
| 5 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 6 |
+
from langchain.embeddings import GPT4AllEmbeddings
|
| 7 |
from langchain.llms import OpenAI
|
| 8 |
from langchain.chat_models import ChatOpenAI
|
| 9 |
from langchain.vectorstores import Qdrant
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
|
| 22 |
+
# embeddings = OpenAIEmbeddings()
|
| 23 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
| 24 |
+
model_kwargs = {'device': 'cpu'}
|
| 25 |
+
encode_kwargs = {'normalize_embeddings': False}
|
| 26 |
+
embeddings = HuggingFaceEmbeddings(
|
| 27 |
+
model_name=model_name,
|
| 28 |
+
model_kwargs=model_kwargs,
|
| 29 |
+
encode_kwargs=encode_kwargs,
|
| 30 |
+
)
|
| 31 |
+
# embeddings = GPT4AllEmbeddings()
|
| 32 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
| 33 |
client = QdrantClient(url=db_url, api_key=db_api_key)
|
| 34 |
db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
|
|
|
|
| 48 |
"filter": {"category": option},
|
| 49 |
}
|
| 50 |
)
|
| 51 |
+
result = RetrievalQA.from_chain_type(
|
| 52 |
llm=ChatOpenAI(
|
| 53 |
model=model,
|
| 54 |
temperature=temperature
|
|
|
|
| 57 |
retriever=retriever,
|
| 58 |
return_source_documents=True,
|
| 59 |
)
|
| 60 |
+
return result
|
| 61 |
|
| 62 |
|
| 63 |
def get_related_url(metadata):
|
config.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
|
| 4 |
-
SAAS =
|
| 5 |
|
| 6 |
|
| 7 |
def get_db_config():
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
|
| 4 |
+
SAAS = False
|
| 5 |
|
| 6 |
|
| 7 |
def get_db_config():
|
store.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from langchain.document_loaders import ReadTheDocsLoader
|
| 2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 3 |
-
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
|
|
| 4 |
from langchain.vectorstores import Qdrant
|
| 5 |
# from qdrant_client import QdrantClient
|
| 6 |
from nvda_ug_loader import NVDAUserGuideLoader
|
|
@@ -35,7 +37,16 @@ def get_text_chunk(docs):
|
|
| 35 |
|
| 36 |
|
| 37 |
def store(texts):
|
| 38 |
-
embeddings = OpenAIEmbeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
| 40 |
# client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
|
| 41 |
_ = Qdrant.from_documents(
|
|
|
|
| 1 |
from langchain.document_loaders import ReadTheDocsLoader
|
| 2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 3 |
+
# from langchain.embeddings import OpenAIEmbeddings
|
| 4 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 5 |
+
from langchain.embeddings import GPT4AllEmbeddings
|
| 6 |
from langchain.vectorstores import Qdrant
|
| 7 |
# from qdrant_client import QdrantClient
|
| 8 |
from nvda_ug_loader import NVDAUserGuideLoader
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def store(texts):
|
| 40 |
+
# embeddings = OpenAIEmbeddings()
|
| 41 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
| 42 |
+
model_kwargs = {'device': 'cuda'}
|
| 43 |
+
encode_kwargs = {'normalize_embeddings': False}
|
| 44 |
+
embeddings = HuggingFaceEmbeddings(
|
| 45 |
+
model_name=model_name,
|
| 46 |
+
model_kwargs=model_kwargs,
|
| 47 |
+
encode_kwargs=encode_kwargs,
|
| 48 |
+
)
|
| 49 |
+
# embeddings = GPT4AllEmbeddings()
|
| 50 |
db_url, db_api_key, db_collection_name = DB_CONFIG
|
| 51 |
# client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
|
| 52 |
_ = Qdrant.from_documents(
|