from time import time
from typing import Iterable

# import gradio as gr
import streamlit as st
import torch
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import HuggingFacePipeline
# from langchain.llms import OpenAI
# from langchain.prompts import PromptTemplate
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import DB_CONFIG, DB_E5_CONFIG
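
# DB_CONFIG and DB_E5_CONFIG come from the local config module. Judging by
# the unpacking in get_retrieval_qa below, each is assumed to be a
# (url, api_key, collection_name) tuple for a Qdrant instance, e.g.:
#   DB_E5_CONFIG = ("https://<qdrant-host>:6333", "<api-key>", "<e5-collection>")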


def load_e5_embeddings():
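    """Load the intfloat/multilingual-e5-large embeddings (CUDA if available)."""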
    model_name = "intfloat/multilingual-e5-large"
    model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
    encode_kwargs = {"normalize_embeddings": False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return embeddings


def load_rinna_model():
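    """Load rinna/bilingual-gpt-neox-4b-instruction-ppo in 8-bit on GPU.

    Returns (None, None) on CPU-only hosts; load_in_8bit=True also
    requires the bitsandbytes package.
    """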
    if torch.cuda.is_available():
        model_name = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        return tokenizer, model
    else:
        return None, None


E5_EMBEDDINGS = load_e5_embeddings()
RINNA_TOKENIZER, RINNA_MODEL = load_rinna_model()
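# Streamlit re-executes this script on every interaction, so these heavy
# module-level loads may repeat; caching the loaders with st.cache_resource
# would be one way to keep a single copy per process.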


def _get_config_and_embeddings(collection_name: str | None) -> tuple:
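    """Pick the Qdrant config and embedding model for the chosen collection."""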
    if collection_name is None or collection_name == "E5":
        db_config = DB_E5_CONFIG
        embeddings = E5_EMBEDDINGS
    elif collection_name == "OpenAI":
        db_config = DB_CONFIG
        embeddings = OpenAIEmbeddings()
    else:
        raise ValueError("Unknown collection name")
    return db_config, embeddings


def _get_rinna_llm(temperature: float) -> HuggingFacePipeline | None:
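    """Wrap the preloaded rinna model in a LangChain HuggingFacePipeline.

    Returns None when no GPU model was loaded.
    """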
    if RINNA_MODEL is not None:
        pipe = pipeline(
            "text-generation",
            model=RINNA_MODEL,
            tokenizer=RINNA_TOKENIZER,
            max_new_tokens=1024,
            temperature=temperature,
        )
        llm = HuggingFacePipeline(pipeline=pipe)
    else:
        llm = None
    return llm


def _get_llm_model(
    model_name: str | None,
    temperature: float,
):
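    """Map a UI model name to a LangChain LLM.

    "GPT-3.5"/"GPT-4" (or None) map to ChatOpenAI; "rinna" uses the
    local HuggingFace pipeline, which may be None on CPU-only hosts.
    """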
    if model_name is None:
        model = "gpt-3.5-turbo"
    elif model_name == "rinna":
        model = "rinna"
    elif model_name == "GPT-3.5":
        model = "gpt-3.5-turbo"
    elif model_name == "GPT-4":
        model = "gpt-4"
    else:
        raise ValueError("Unknown model name")
    if model.startswith("gpt"):
        llm = ChatOpenAI(model=model, temperature=temperature)
    else:  # model == "rinna"
        llm = _get_rinna_llm(temperature)
    return llm


def get_retrieval_qa(
    collection_name: str | None,
    model_name: str | None,
    temperature: float,
    option: str | None,
):
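    """Build a RetrievalQA chain over the configured Qdrant collection.

    `option` filters retrieval on the documents' "category" metadata;
    None or "All" searches everything.
    """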
    db_config, embeddings = _get_config_and_embeddings(collection_name)
    db_url, db_api_key, db_collection_name = db_config
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=embeddings
    )
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    llm = _get_llm_model(model_name, temperature)
    # chain_type_kwargs = {"prompt": PROMPT}
    result = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        # chain_type_kwargs=chain_type_kwargs,
    )
    return result
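
# Minimal usage sketch outside the UI (collection, model, and question are
# illustrative values):
#   qa = get_retrieval_qa("E5", "GPT-3.5", 0.0, "All")
#   res = qa("query: How do I change the speech rate in NVDA?")
#   print(res["result"], [d.metadata["url"] for d in res["source_documents"]])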


def get_related_url(metadata) -> Iterable[str]:
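    """Yield an HTML link for each unique source URL in the metadata."""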
    urls = set()
    for m in metadata:
        # p = m['source']
        url = m["url"]
        if url in urls:
            continue
        urls.add(url)
        category = m["category"]
        # print(m)
        yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'


def run_qa(query: str, qa: RetrievalQA) -> tuple[str, str]:
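    """Run the chain and return (answer, HTML with timing and sources)."""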
    now = time()
    try:
        result = qa(query)
    except InvalidRequestError as e:
        return "No answer was found. Please try a different question.", str(e)
    else:
        metadata = [s.metadata for s in result["source_documents"]]
        sec_html = f"<p>Execution time: {(time() - now):.2f}s</p>"
        html = "<div>" + sec_html + "\n".join(get_related_url(metadata)) + "</div>"
        return result["result"], html


def main(
    query: str,
    collection_name: str | None,
    model_name: str | None,
    option: str | None,
    temperature: float,
    e5_option: list[str],
) -> Iterable[tuple[str, tuple[str, str]]]:
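    """Yield (label, (answer, html)) once per selected E5 prefix mode.

    E5 models are trained with "query: " / "passage: " input prefixes,
    so the same question can be re-run with each prefix for comparison.
    """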
    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
    if collection_name == "E5":
        for e5_opt in e5_option:  # renamed to avoid shadowing the `option` parameter
            if e5_opt == "No":
                yield "E5 No", run_qa(query, qa)
            elif e5_opt == "Query":
                yield "E5 Query", run_qa("query: " + query, qa)
            elif e5_opt == "Passage":
                yield "E5 Passage", run_qa("passage: " + query, qa)
            else:
                raise ValueError("Unknown option")
    else:
        yield "OpenAI", run_qa(query, qa)


AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
if RINNA_MODEL is not None:
    AVAILABLE_LLMS.append("rinna")

with st.form("my_form"):
    query = st.text_input(label="query")
    collection_name = st.radio(options=["E5", "OpenAI"], label="Embedding")
    # if collection_name == "E5":  # TODO: only offer this when E5 is selected
    e5_option = st.multiselect("E5 option", ["No", "Query", "Passage"], default="No")
    model_name = st.radio(
        options=AVAILABLE_LLMS,
        label="Model",
        help="rinna is selectable in GPU environments",
    )
    option = st.radio(
        options=["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
        label="Filter",
        help="Restrict the search to one document set?",
    )
    # Use float bounds so the slider yields fractional temperatures rather
    # than the integer steps 0, 1, 2.
    temperature = st.slider(
        label="temperature", min_value=0.0, max_value=2.0, step=0.1
    )
    submitted = st.form_submit_button("Submit")

if submitted:
    with st.spinner("Searching..."):
        results = main(
            query, collection_name, model_name, option, temperature, e5_option
        )
        for type_, (answer, html) in results:
            with st.container():
                st.header(type_)
                st.write(answer)
                st.markdown(html, unsafe_allow_html=True)
                st.divider()
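
# Launch with Streamlit (the file name is assumed):
#   streamlit run app.py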