import os

import gradio as gr
import pinecone
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferWindowMemory
from langchain.vectorstores import Pinecone
from torch import cuda

# Secrets and endpoint configuration come from the environment.
LLAMA_2_7B_CHAT_HF_FRANC_V0_9 = os.environ.get("LLAMA_2_7B_CHAT_HF_FRANC_V0_9")
HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# Set up the Pinecone vector store.
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment=PINECONE_ENVIRONMENT,
)
index_name = 'stadion-6237'
index = pinecone.Index(index_name)
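# Optional sanity check against the live index (pinecone-client 2.x call;
# uncomment to print the vector count and dimension at startup):
# print(index.describe_index_stats())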

# Embedding model for queries (must match the model used to build the index).
embedding_model_id = 'sentence-transformers/paraphrase-mpnet-base-v2'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
embedding_model = HuggingFaceEmbeddings(
    model_name=embedding_model_id,
    model_kwargs={'device': device},
    encode_kwargs={'device': device, 'batch_size': 32},
)

text_key = 'text'  # metadata field holding the raw chunk text
vector_store = Pinecone(index, embedding_model.embed_query, text_key)
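# Hypothetical retrieval check (example query, not from the app;
# uncomment to inspect what the retriever returns):
# for doc in vector_store.similarity_search("How should I fuel a long run?", k=4):
#     print(doc.page_content[:80])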

# Llama 2 chat special tokens for wrapping the instruction and system prompt.
B_INST, E_INST = "[INST] ", " [/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def get_prompt_template(instruction, system_prompt):
    """Wrap an instruction and system prompt in Llama 2 chat format."""
    system_prompt = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + system_prompt + instruction + E_INST
    return prompt_template

template = get_prompt_template(
    """Use the following context to answer the question at the end.
Context:
{context}
Question: {question}""",
    """Reply in 10 sentences or less.
Do not use emotes.""",
)
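# The assembled template is:
# [INST] <<SYS>>
# Reply in 10 sentences or less.
# Do not use emotes.
# <</SYS>>
#
# Use the following context to answer the question at the end.
# Context:
# {context}
# Question: {question} [/INST]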

endpoint_url = LLAMA_2_7B_CHAT_HF_FRANC_V0_9
llm = HuggingFaceEndpoint(
    endpoint_url=endpoint_url,
    huggingfacehub_api_token=HUGGING_FACE_HUB_TOKEN,
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "temperature": 0.1,  # near-deterministic answers
        "repetition_penalty": 1.1,
        # False so the endpoint returns only the completion, not the echoed
        # prompt; with True the chat reply would include the whole RAG prompt.
        "return_full_text": False,
    },
)

prompt = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)

# Windowed memory of the last 3 turns; currently unused (see the commented-out
# "memory" entry in chain_type_kwargs below).
memory = ConversationBufferWindowMemory(
    k=3,
    memory_key="history",
    input_key="question",
    ai_prefix="Franc",
    human_prefix="Runner",
)

# "stuff" concatenates the top 4 retrieved chunks directly into the prompt.
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=vector_store.as_retriever(search_kwargs={'k': 4}),
    chain_type_kwargs={
        "prompt": prompt,
        # "memory": memory,
    },
)
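# RetrievalQA returns a dict with 'query' and 'result' keys; a hypothetical
# smoke test (uncomment to run once at startup):
# print(rag_chain("How do I pace a 10K?")['result'])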


def generate(message, history):
    """Answer a chat message with the retrieval-augmented chain."""
    reply = rag_chain(message)
    return reply['result'].strip()


gr.ChatInterface(
    generate,
    title="Franc v1.0",
    theme=gr.themes.Soft(),
    submit_btn="Ask Franc",
    retry_btn="Do better, Franc!",
    autofocus=True,
).queue().launch()