Spaces:
Sleeping
Sleeping
| import json | |
| import time | |
| import ast | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from gradio_consilium_roundtable import consilium_roundtable | |
# === Constants ===
MODEL_NAME = "katanemo/Arch-Router-1.5B"  # HF hub id of the routing LLM
ARCH_ROUTER = "Arch Router"               # display name of the router participant in the UI
WAIT_DEPARTMENT = 5                       # seconds the simulated department "works" on a task
WAIT_SYSTEM = 5                           # seconds before the router announces it is ready again
# === Load model/tokenizer ===
# Loaded once at import time (module-level). device_map/torch_dtype "auto"
# let transformers choose device placement and precision automatically.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# === Route Definitions ===
# Routes offered to the router model; each "name" must match a key in
# `departments` below so the parsed route maps to a display entry.
route_config = [
    {"name": "code_generation", "description": "Generating code based on prompts"},
    {"name": "bug_fixing", "description": "Fixing errors or bugs in code"},
    {"name": "performance_optimization", "description": "Improving code performance"},
    {"name": "api_help", "description": "Assisting with APIs and libraries"},
    {"name": "programming", "description": "General programming Q&A"},
    {"name": "legal", "description": "Legal"},
    {"name": "healthcare", "description": "Healthcare and medical related"},
]
# Display metadata: route name -> (emoji, human-readable department label).
# FIX: the emoji literals were mojibake (UTF-8 decoded with the wrong codec);
# restored to match the twemoji codepoints used for the avatars in
# init_state (1f4bb, 1f41e, 26a1, 1f50c, 1f4da, 2696, 1fa7a, 2753).
departments = {
    "code_generation": ("💻", "Code Generation"),
    "bug_fixing": ("🐞", "Bug Fixing"),
    "performance_optimization": ("⚡", "Performance Optimization"),
    "api_help": ("🔌", "API Help"),
    "programming": ("📚", "Programming"),
    "legal": ("⚖️", "Legal"),
    "healthcare": ("🩺", "Healthcare"),
    "other": ("❓", "Other / General Inquiry"),
}
# === Prompt Formatting ===
TASK_INSTRUCTION = """
You are a helpful assistant designed to find the best suited route. You are provided with route description within <routes></routes> XML tags:
<routes>
{routes}
</routes>
<conversation>
{conversation}
</conversation>
"""
FORMAT_PROMPT = """
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON format:
{"route": "route_name"}
"""


def format_prompt(route_config, conversation):
    """Assemble the full routing prompt.

    The task instruction is filled with JSON-serialized route descriptions
    and conversation turns, then the fixed output-format rules are appended.
    """
    filled_instruction = TASK_INSTRUCTION.format(
        routes=json.dumps(route_config),
        conversation=json.dumps(conversation),
    )
    return filled_instruction + FORMAT_PROMPT
def parse_route(response_text):
    """Extract the route name from the model's raw response text.

    The model is instructed to answer with JSON such as
    ``{"route": "bug_fixing"}``, possibly surrounded by extra prose.
    The outermost ``{...}`` span is parsed as JSON first (the requested
    format), falling back to a Python literal (handles single quotes).

    Args:
        response_text: decoded text generated by the router model.

    Returns:
        The route name string, or ``"other"`` when no well-formed object
        with a ``"route"`` key can be recovered. Never raises.
    """
    try:
        start = response_text.find("{")
        end = response_text.rfind("}") + 1
        if start == -1 or end == 0:  # no JSON object present at all
            return "other"
        snippet = response_text[start:end]
        try:
            parsed = json.loads(snippet)  # expected strict-JSON answer
        except ValueError:
            # Model sometimes emits Python-style literals (single quotes).
            parsed = ast.literal_eval(snippet)
        if isinstance(parsed, dict):
            return parsed.get("route", "other")
        return "other"
    except Exception as e:
        # Preserve the original never-raise contract: any malformed output
        # degrades to the generic route instead of crashing the UI handler.
        print("Parsing failed:", e)
        return "other"
def init_state():
    """Build the initial state dict consumed by the consilium_roundtable widget.

    Participants are the router plus every department; only the router's
    speech bubble is visible at start. Avatar images are twemoji PNGs
    keyed by participant name (the router gets the Katanemo org avatar).
    """
    participant_codes = [
        ("code_generation", "1f4bb"),
        ("bug_fixing", "1f41e"),
        ("performance_optimization", "26a1"),
        ("api_help", "1f50c"),
        ("programming", "1f4da"),
        ("legal", "2696"),
        ("healthcare", "1fa7a"),
        ("other", "2753"),
    ]
    avatar_emojis = {
        name: f"https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/{code}.png"
        for name, code in participant_codes
    }
    avatar_emojis[ARCH_ROUTER] = "https://avatars.githubusercontent.com/u/112724757?s=200&v=4"
    return {
        "messages": [],
        "participants": [ARCH_ROUTER, *departments],
        "currentSpeaker": None,
        "thinking": [],
        "showBubbles": [ARCH_ROUTER],
        "avatarImages": avatar_emojis,
    }
def route_and_visualize(user_input_text, rt_state, chat_history):
    """Generator event handler for one chat turn: route the prompt, animate the UI.

    Yields (roundtable_state, chat_history, roundtable_state, textbox_update)
    several times so Gradio streams the stages: route detection, department
    "processing", completion, and finally re-enabling the input textbox.

    NOTE(review): several message strings below contain mojibaked characters
    (e.g. "π", "β³", "β") that were presumably emoji in the original
    encoding — left byte-identical here.
    """
    chat_history = chat_history or []
    # NOTE(review): a None rt_state is rebuilt with only "messages", losing the
    # keys init_state() provides (participants, avatars) — confirm intended.
    rt_state = rt_state or {"messages": []}
    chat_history.append(("User", user_input_text))
    # Step 1: Disable input and show route detection
    rt_state["messages"] = [{"speaker": ARCH_ROUTER, "text": "π Identifying route, please wait..."}]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 2: Prepare prompt and get route
    conversation = [{"role": "user", "content": user_input_text}]
    route_prompt = format_prompt(route_config, conversation)
    # The whole routing prompt is wrapped as a single user turn for the
    # model's chat template.
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": route_prompt}],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_new_tokens=512)
    # Decode only the newly generated tokens (everything after the prompt).
    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
    print("MODEL RAW:", response)
    route = parse_route(response)
    # Unknown/unchosen routes fall back to the "other" department entry.
    emoji, dept_name = departments.get(route, departments["other"])
    # Step 3: Show route identified (replace the placeholder router message)
    rt_state["messages"][0] = {
        "speaker": ARCH_ROUTER,
        "text": f"π Identified department: **{dept_name}**. Forwarding task..."
    }
    chat_history.append((ARCH_ROUTER, f"π Identified department: {dept_name}. Forwarding task..."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 4: Show processing — department bubble appears alongside the router's
    time.sleep(3)
    rt_state["messages"].extend([
        {"speaker": route, "text": f"{emoji} {dept_name} simulation is processing your request in {WAIT_DEPARTMENT} secs..."},
        {"speaker": ARCH_ROUTER, "text": "β³ Waiting for department to respond..."}
    ])
    rt_state["showBubbles"] = [ARCH_ROUTER, route]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 5: Simulate delay and complete (rewrite the two messages just added)
    time.sleep(WAIT_DEPARTMENT)
    rt_state["messages"][-2]["text"] = f"β {dept_name} completed the task."
    rt_state["messages"][-1]["text"] = f"β {dept_name} department has completed the task."
    chat_history.append((ARCH_ROUTER, f"β {dept_name} department completed the task."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 6: Reset visible bubbles back to just the router
    rt_state["showBubbles"] = [ARCH_ROUTER]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 7: System ready — final yield re-enables the textbox
    time.sleep(WAIT_SYSTEM)
    rt_state["messages"].append({"speaker": ARCH_ROUTER, "text": "Arch Router is ready to discuss."})
    yield rt_state, chat_history, rt_state, gr.update(interactive=True)
# === Gradio UI ===
with gr.Blocks(title="Arch Router Simulation: Smart Department Dispatcher", theme=gr.themes.Ocean()) as demo:
    # NOTE(review): "π§" and "β" in the markdown below look like mojibaked
    # characters (likely an emoji and a dash) — left byte-identical.
    gr.Markdown(
        """
## π§ Arch Router Simulation: Smart Department Dispatcher
**This is a demo simulation of <a href="https://huggingface.co/katanemo/Arch-Router-1.5B" target="_blank">katanemo/Arch-Router-1.5B</a>.**
**Kindly refer official documentation for more details**
* See how Arch Router identifies the best route **(or Domain β the high-level category)** based on user prompt and take desired **Action (specific type of operation user wants to perform)** by forwarding it to respective department.
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Server-side copies of the roundtable state and chat transcript;
            # both are threaded through route_and_visualize on every turn.
            rt_state = gr.State(init_state())
            chat_state = gr.State([])
            roundtable = consilium_roundtable(value=init_state())
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History", max_height=300)
            textbox = gr.Textbox(placeholder="Describe your issue...", label="Ask Arch Router")
            submit_btn = gr.Button("Submit")
    example_inputs = [
        "How do I optimize this loop in Python?",
        "Generate a function to sort an array in python",
        "Help me anonymize patient health records before storing them",
        "I'm getting a TypeError in following code",
        "Do I need to include attribution for MIT-licensed software?",
        "How do I connect to external API from this code?"
    ]
    # Trigger submission via Enter or Button — both events share one handler.
    for trigger in (textbox.submit, submit_btn.click):
        trigger(
            route_and_visualize,
            inputs=[textbox, rt_state, chat_state],
            outputs=[roundtable, chatbot, rt_state, textbox],
            concurrency_limit=1  # one routing run at a time; handler is a streaming generator
        )
    # Example block
    gr.Examples(
        examples=example_inputs,
        inputs=textbox,
        label="Try one of these examples"
    )

if __name__ == "__main__":
    demo.launch()