import streamlit as st
import json
import copy
from typing import Any, Iterable

from moa.agent import MOAgent
from moa.agent.moa import ResponseChunk
from streamlit_ace import st_ace

# Default configuration
default_config = {
    "main_model": "llama3-70b-8192",
    "cycles": 3,
    "layer_agent_config": {}
}

layer_agent_config_def = {
    "layer_agent_1": {
        "system_prompt": "Think through your response step by step. {helper_response}",
        "model_name": "llama3-8b-8192"
    },
    "layer_agent_2": {
        "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
        "model_name": "gemma-7b-it",
        "temperature": 0.7
    },
    "layer_agent_3": {
        "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
        "model_name": "llama3-8b-8192"
    },
}

# Recommended configuration
rec_config = {
    "main_model": "llama3-70b-8192",
    "cycles": 2,
    "layer_agent_config": {}
}

layer_agent_config_rec = {
    "layer_agent_1": {
        "system_prompt": "Think through your response step by step. {helper_response}",
        "model_name": "llama3-8b-8192",
        "temperature": 0.1
    },
    "layer_agent_2": {
        "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
        "model_name": "llama3-8b-8192",
        "temperature": 0.2
    },
    "layer_agent_3": {
        "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
        "model_name": "llama3-8b-8192",
        "temperature": 0.4
    },
    "layer_agent_4": {
        "system_prompt": "You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
        "model_name": "mixtral-8x7b-32768",
        "temperature": 0.5
    },
}

def stream_response(messages: Iterable[ResponseChunk]):
    # messages is a lazy iterable, so its length is unknown up front; a progress
    # bar computed from len(messages) would raise TypeError on a generator.
    # Instead, render each layer's intermediate output once the layer completes.
    layer_outputs = {}
    for message in messages:
        if message['response_type'] == 'intermediate':
            layer = message['metadata']['layer']
            layer_outputs.setdefault(layer, []).append(message['delta'])
        else:
            # A non-intermediate chunk comes from the main agent: display the
            # accumulated layer outputs, then reset for the next cycle.
            for layer, outputs in layer_outputs.items():
                with st.expander(f"Layer {layer} output", expanded=False):
                    st.write(''.join(outputs))
            layer_outputs = {}
            # Yield the main agent's delta so st.write_stream can render it live.
            yield message['delta']

def set_moa_agent(
    main_model: str = default_config['main_model'],
    cycles: int = default_config['cycles'],
    layer_agent_config: dict[str, dict[str, Any]] | None = None,
    main_model_temperature: float = 0.1,
    override: bool = False
):
    # Avoid a mutable default argument: fall back to a fresh copy of the
    # default layer config on each call.
    if layer_agent_config is None:
        layer_agent_config = copy.deepcopy(layer_agent_config_def)

    if override or ("main_model" not in st.session_state):
        st.session_state.main_model = main_model

    if override or ("cycles" not in st.session_state):
        st.session_state.cycles = cycles

    if override or ("layer_agent_config" not in st.session_state):
        st.session_state.layer_agent_config = layer_agent_config

    if override or ("main_temp" not in st.session_state):
        st.session_state.main_temp = main_model_temperature

    cls_ly_conf = copy.deepcopy(st.session_state.layer_agent_config)

    if override or ("moa_agent" not in st.session_state):
        st.session_state.moa_agent = MOAgent.from_config(
            main_model=st.session_state.main_model,
            cycles=st.session_state.cycles,
            layer_agent_config=cls_ly_conf,
            temperature=st.session_state.main_temp
        )

    del cls_ly_conf
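
# Example (hypothetical direct call, mirroring what the sidebar's
# "Use Recommended Config" button does below):
#   set_moa_agent(
#       main_model=rec_config['main_model'],
#       cycles=rec_config['cycles'],
#       layer_agent_config=layer_agent_config_rec,
#       override=True
#   )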

st.set_page_config(
    page_title="Mixture of Agents",
    layout="wide",
    menu_items={'About': "## Mixture-of-Agents\nPowered by Groq"}
)

valid_model_names = [
    'llama3-70b-8192',
    'llama3-8b-8192',
    'gemma-7b-it',
    'gemma2-9b-it',
    'mixtral-8x7b-32768'
]
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| set_moa_agent() | |

# Sidebar configuration
with st.sidebar:
    st.title("MOA Configuration")
    with st.form("Agent Configuration", clear_on_submit=False):
        if st.form_submit_button("Use Recommended Config"):
            set_moa_agent(
                main_model=rec_config['main_model'],
                cycles=rec_config['cycles'],
                layer_agent_config=layer_agent_config_rec,
                override=True
            )
            st.session_state.messages = []
            st.success("Configuration updated successfully!")

        # Config toggling
        show_advanced = st.checkbox("Show Advanced Configurations")
        if show_advanced:
            new_main_model = st.selectbox(
                "Main Model",
                valid_model_names,
                index=valid_model_names.index(st.session_state.main_model)
            )
            new_cycles = st.number_input(
                "Number of Layers",
                min_value=1,
                max_value=10,
                value=st.session_state.cycles
            )
            main_temperature = st.slider(
                "Main Model Temperature",
                min_value=0.0,
                max_value=1.0,
                value=st.session_state.main_temp,
                step=0.05
            )
            new_layer_agent_config = st_ace(
                value=json.dumps(st.session_state.layer_agent_config, indent=2),
                language="json",
                show_gutter=False,
                wrap=True,
                auto_update=True
            )
            if st.form_submit_button("Update Config"):
                try:
                    parsed_config = json.loads(new_layer_agent_config)
                    set_moa_agent(
                        main_model=new_main_model,
                        cycles=new_cycles,
                        layer_agent_config=parsed_config,
                        main_model_temperature=main_temperature,
                        override=True
                    )
                    st.session_state.messages = []
                    st.success("Configuration updated successfully!")
                except json.JSONDecodeError:
                    st.error("Invalid JSON in Layer Agent Config.")
                except Exception as e:
                    st.error(f"Error updating config: {str(e)}")

# Main app layout
st.header("Mixture of Agents")
st.markdown("Real-time response tracking with intermediate and final results.")

with st.expander("Current MOA Configuration", expanded=False):
    st.json(st.session_state.layer_agent_config)

# Chat interface
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if query := st.chat_input("Ask a question"):
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    moa_agent: MOAgent = st.session_state.moa_agent
    with st.chat_message("assistant"):
        # st.write_stream drives the generator and renders the main agent's
        # deltas as they arrive, returning the concatenated response.
        ast_mess = stream_response(moa_agent.chat(query, output_format="json"))
        response = st.write_stream(ast_mess)

    st.session_state.messages.append({"role": "assistant", "content": response})

st.markdown("---")
st.markdown("Powered by [Groq](https://groq.com).")