from typing import List, Tuple, Dict, TypedDict, Optional, Any
import os

import gradio as gr
from langchain_core.language_models.llms import LLM
from langchain_openai.chat_models import ChatOpenAI
from langchain_aws import ChatBedrock
import boto3

from ask_candid.base.config.rest import OPENAI
from ask_candid.base.config.models import Name2Endpoint
from ask_candid.base.config.data import ALL_INDICES
from ask_candid.utils import format_chat_ag_response
from ask_candid.chat import run_chat

try:
    from feedback import FeedbackApi
except ImportError:
    from demos.feedback import FeedbackApi

ROOT = os.path.dirname(os.path.abspath(__file__))


class LoggedComponents(TypedDict):
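    """Gradio components whose values are captured and submitted with user feedback."""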
    context: List[gr.components.Component]
    found_helpful: gr.components.Component
    will_recommend: gr.components.Component
    comments: gr.components.Component
    email: gr.components.Component


def send_feedback(
    chat_context,
    found_helpful,
    will_recommend,
    comments,
    email,
):
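    """Post the chat context and survey answers to the feedback API.

    Returns the submission count reported in the API response (0 if the field is absent).
    """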
    api = FeedbackApi()
    total_submissions = 0

    try:
        response = api(
            context=chat_context,
            found_helpful=found_helpful,
            will_recommend=will_recommend,
            comments=comments,
            email=email
        )
        total_submissions = response.get("response", 0)
        gr.Info("Thank you for submitting feedback")
    except Exception as ex:
        raise gr.Error(f"Error submitting feedback: {ex}")

    return total_submissions


def select_foundation_model(model_name: str, max_new_tokens: int) -> LLM:
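    """Resolve a UI model name to a configured chat client.

    "gpt-4o" is served through ChatOpenAI; the Anthropic, Meta, and Mistral names are
    served through Amazon Bedrock. Any other name raises a gr.Error shown in the UI.
    """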
    if model_name == "gpt-4o":
        llm = ChatOpenAI(
            model_name=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            api_key=OPENAI["key"],
            temperature=0.0,
            streaming=True,
        )
    elif model_name in {"claude-3.5-haiku", "llama-3.1-70b-instruct", "mistral-large", "mixtral-8x7B"}:
        llm = ChatBedrock(
            client=boto3.client("bedrock-runtime", region_name="us-east-1"),
            model=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            temperature=0.0
        )
    else:
        raise gr.Error(f"Base model `{model_name}` is not supported")
    return llm


def execute(
    thread_id: str,
    user_input: Dict[str, Any],
    history: List[Dict],
    model_name: str,
    max_new_tokens: int,
    indices: Optional[List[str]] = None,
):
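    """Run one chat turn with the selected model and source indices."""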
    return run_chat(
        thread_id=thread_id,
        user_input=user_input,
        history=history,
        llm=select_foundation_model(model_name=model_name, max_new_tokens=max_new_tokens),
        indices=indices
    )


def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
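    """Build the assistant tab: source and model settings, the chatbot, and the message box.

    Returns the components to log for feedback along with the gr.Blocks demo.
    """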
    with gr.Blocks(theme=gr.themes.Soft(), title="Candid's AI assistant") as demo:
        gr.Markdown(
            """
            <h1>Candid's AI assistant</h1>
            <p>
                Please read the <a
                    href='https://info.candid.org/chatbot-reference-guide'
                    target="_blank"
                    rel="noopener noreferrer"
                >guide</a> to get started.
            </p>
            <hr>
            """
        )

        with gr.Accordion(label="Advanced settings", open=False):
            es_indices = gr.CheckboxGroup(
                choices=list(ALL_INDICES),
                value=list(ALL_INDICES),
                label="Sources to include",
                interactive=True,
            )
            llmname = gr.Radio(
                label="Language model",
                value="claude-3.5-haiku",
                choices=list(Name2Endpoint.keys()),
                interactive=True,
            )
            max_new_tokens = gr.Slider(
                value=256 * 3,
                minimum=128,
                maximum=2048,
                step=128,
                label="Max new tokens",
                interactive=True,
            )

        with gr.Column():
            chatbot = gr.Chatbot(
                label="AskCandid",
                elem_id="chatbot",
                bubble_full_width=True,
                avatar_images=(
                    None,
                    os.path.join(ROOT, "static", "candid_logo_yellow.png"),
                ),
                height="45vh",
                type="messages",
                show_label=False,
                show_copy_button=True,
                show_share_button=None,
                show_copy_all_button=False,
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            thread_id = gr.Text(visible=False, value="", label="thread_id")
            gr.ClearButton(components=[msg, chatbot, thread_id], size="sm")

        # pylint: disable=no-member
        chat_msg = msg.submit(
            fn=execute,
            inputs=[thread_id, msg, chatbot, llmname, max_new_tokens, es_indices],
            outputs=[msg, chatbot, thread_id],
        )
        chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")

    logged = LoggedComponents(context=chatbot)
    return logged, demo


def build_feedback(components: LoggedComponents) -> gr.Blocks:
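    """Build the feedback tab and wire its submit button to send_feedback."""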
    with gr.Blocks(theme=gr.themes.Soft(), title="Candid AI demo") as demo:
        gr.Markdown("<h1>Help us improve this tool with your valuable feedback</h1>")

        with gr.Row():
            with gr.Column():
                found_helpful = gr.Radio(
                    [True, False], label="Did you find what you were looking for?"
                )
                will_recommend = gr.Radio(
                    [True, False],
                    label="Will you recommend this Chatbot to others?",
                )
                comment = gr.Textbox(label="Additional comments (optional)", lines=4)
                email = gr.Textbox(label="Your email (optional)", lines=1)
                submit = gr.Button("Submit Feedback")

        components["found_helpful"] = found_helpful
        components["will_recommend"] = will_recommend
        components["comments"] = comment
        components["email"] = email

        # pylint: disable=no-member
        submit.click(
            fn=send_feedback,
            inputs=[
                components["context"],
                components["found_helpful"],
                components["will_recommend"],
                components["comments"],
                components["email"]
            ],
            outputs=None,
            show_api=False,
            api_name=False,
            preprocess=False,
        )
    return demo


def build_app():
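    """Assemble the assistant and feedback tabs into a single TabbedInterface."""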
    logger, candid_chat = build_rag_chat()
    feedback = build_feedback(logger)

    with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
        css_chat = f.read()

    demo = gr.TabbedInterface(
        interface_list=[
            candid_chat,
            feedback
        ],
        tab_names=[
            "Candid's AI assistant",
            "Feedback"
        ],
        title="Candid's AI assistant",
        theme=gr.themes.Soft(),
        css=css_chat,
    )
    return demo
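

# Serve behind basic auth: two username/password pairs read from environment variables.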
if __name__ == "__main__":
    app = build_app()
    app.queue(max_size=5).launch(
        show_api=False,
        auth=[
            (os.getenv("APP_USERNAME"), os.getenv("APP_PASSWORD")),
            (os.getenv("APP_PUBLIC_USERNAME"), os.getenv("APP_PUBLIC_PASSWORD")),
        ],
        auth_message="Login to Candid's AI assistant",
        ssr_mode=False
    )