import gradio as gr
from huggingface_hub import InferenceClient

# Define available models (update with your actual model IDs)
model_list = {
    "Safe LM": "HuggingFaceH4/zephyr-7b-beta",
    "Baseline 1": "HuggingFaceH4/zephyr-7b-beta",
    "Another Model": "HuggingFaceH4/zephyr-7b-beta",
    "LLaMA3.2-1B": "meta-llama/Llama-3.2-1B-Instruct",
    "Mix IFT V2 - Score0 Rephrased": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_mix_rephrased_from_beginning-300B",
    "Mix IFT V2 - Score0 Only": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_only-300B",
    "Mix IFT V2 - All Raw Folders Metadata": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-all_raw_folders_metadata-300B",
    "Mix IFT V2 - All Raw Folders Baseline": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-all_raw_folders_baseline-300B",
    "Mix IFT V2 - Score0 Only MBS16 GBS1024": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_only-300B-mbs16-gbs1024-16feb-lr2e-05-gbs16",
}

# Dictionary to track which models support chat completion vs. text generation
model_tasks = {
    "HuggingFaceH4/zephyr-7b-beta": "chat-completion",  # This model supports chat completion
    # Add other models that support chat completion
}
# Default to text-generation for models not specified above
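# A minimal lookup sketch (model_tasks is informational and not referenced further
# below); an assumed usage, defaulting to text generation for unlisted models, would be:
#   task = model_tasks.get(model_id, "text-generation")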
def respond(message, history, system_message, max_tokens, temperature, top_p, selected_model):
    try:
        # Get the model ID for the selected model
        model_id = model_list.get(selected_model, "HuggingFaceH4/zephyr-7b-beta")

        # Create an InferenceClient for the selected model
        client = InferenceClient(model_id)

        # Always use text generation for locuslab models
        if "locuslab" in model_id:
            # Format the prompt manually for text generation
            # Simple formatting that works with most models
            formatted_prompt = ""

            # Add minimal formatting for better results with research models
            if len(history) > 0:
                # Include minimal context from history
                last_exchanges = history[-1:]  # Just use the last exchange
                for user_msg, assistant_msg in last_exchanges:
                    if user_msg:
                        formatted_prompt += f"{user_msg}\n"

            # Add current message - keep it simple
            formatted_prompt += f"{message}"

            response = ""

            # Use text generation instead of chat completion
            print(f"Using text generation with prompt: {formatted_prompt}")
            for token in client.text_generation(
                formatted_prompt,
                max_new_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,  # Enable sampling for more creative responses
            ):
                response += token
                yield response
        else:
            # Try chat completion for standard models
            try:
                messages = [{"role": "system", "content": system_message}]
                for user_msg, assistant_msg in history:
                    if user_msg:  # Only add non-empty messages
                        messages.append({"role": "user", "content": user_msg})
                    if assistant_msg:  # Only add non-empty messages
                        messages.append({"role": "assistant", "content": assistant_msg})
                messages.append({"role": "user", "content": message})

                response = ""

                # Stream the response from the client
                for token_message in client.chat_completion(
                    messages,
                    max_tokens=max_tokens,
                    stream=True,
                    temperature=temperature,
                    top_p=top_p,
                ):
                    # Safe extraction of token with error handling
                    try:
                        token = token_message.choices[0].delta.content
                        if token is not None:  # Handle potential None values
                            response += token
                            yield response
                    except (AttributeError, IndexError) as e:
                        # Handle cases where token structure might be different
                        print(f"Error extracting token: {e}")
                        continue
            except Exception as e:
                # If chat completion fails, fall back to text generation
                print(f"Chat completion failed: {e}. Falling back to text generation.")
                formatted_prompt = f"{system_message}\n\n"
                for user_msg, assistant_msg in history:
                    if user_msg:
                        formatted_prompt += f"User: {user_msg}\n"
                    if assistant_msg:
                        formatted_prompt += f"Assistant: {assistant_msg}\n"
                formatted_prompt += f"User: {message}\nAssistant:"

                response = ""

                # Use text generation instead of chat completion
                for token in client.text_generation(
                    formatted_prompt,
                    max_new_tokens=max_tokens,
                    stream=True,
                    temperature=temperature,
                    top_p=top_p,
                ):
                    response += token
                    yield response
    except Exception as e:
        # Return detailed error message if the model call fails
        error_message = str(e)
        print(f"Error calling model API: {error_message}")
        yield f"Error: {error_message}. Please try a different model or adjust parameters."
# Custom CSS for styling
css = """
body {
    background-color: #f0f5fb; /* Light pastel blue background */
}
.gradio-container {
    background-color: white;
    border-radius: 16px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.05);
    max-width: 90%;
    margin: 15px auto;
    padding-bottom: 20px;
}
/* Header styling with diagonal shield */
.app-header {
    position: relative;
    overflow: hidden;
}
.app-header::before {
    content: "🛡️";
    position: absolute;
    font-size: 100px;
    opacity: 0.1;
    right: -20px;
    top: -30px;
    transform: rotate(15deg);
    pointer-events: none;
}
/* Simple styling for buttons */
#send-btn {
    background-color: white !important;
    color: #333 !important;
    border: 2px solid #e6c200 !important;
}
#send-btn:hover {
    background-color: #fff9e6 !important;
}
#clear-btn {
    background-color: white !important;
    color: #333 !important;
    border: 2px solid #e6c200 !important;
}
#clear-btn:hover {
    background-color: #fff9e6 !important;
}
/* Hide elements */
footer {
    display: none !important;
}
.footer {
    display: none !important;
}
"""
with gr.Blocks(css=css) as demo:
    # Custom header with branding
    gr.HTML("""
    <div class="app-header" style="background: linear-gradient(135deg, #4a90e2, #75c6ef); padding: 15px; border-radius: 16px 16px 0 0; color: white; border-bottom: 3px solid #e6c200;">
        <h1 style="font-size: 32px; font-weight: 600; margin: 0; display: flex; align-items: center; font-family: 'Palatino', serif;">
            <span style="margin-right: 10px; font-size: 32px;">🛡️</span>
            <span style="font-weight: 700; margin-right: 1px;">Safe</span>
            <span style="font-weight: 400; letter-spacing: 1px;">Playground</span>
        </h1>
    </div>
    """)

    # Status message for API errors
    status_message = gr.Markdown("", elem_id="status-message")

    with gr.Row():
        # Left sidebar: Model selector
        with gr.Column(scale=1):
            gr.Markdown("## Models")
            model_dropdown = gr.Dropdown(
                choices=list(model_list.keys()),
                label="Select Model",
                value="Safe LM",
                elem_classes=["model-select"]
            )

            # Settings
            gr.Markdown("### Settings")
            system_message = gr.Textbox(
                value="You are a friendly and safe assistant.",
                label="System Message",
                lines=2
            )
            max_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=100, step=1,
                label="Max New Tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                label="Top-p (nucleus sampling)"
            )

        # Main area: Chat interface
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                show_label=True,
                height=400
            )
            with gr.Row():
                user_input = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your Message",
                    show_label=False,
                    scale=9
                )
                send_button = gr.Button(
                    "Send",
                    scale=1,
                    elem_id="send-btn"
                )
            with gr.Row():
                clear_button = gr.Button("Clear Chat", elem_id="clear-btn")
    # Define functions for chatbot interactions
    def user(user_message, history):
        # Add emoji to user message
        user_message_with_emoji = f"👤 {user_message}"
        return "", history + [[user_message_with_emoji, None]]

    def bot(history, system_message, max_tokens, temperature, top_p, selected_model):
        # Ensure there's history
        if not history or len(history) == 0:
            return history

        # Get the last user message from history
        user_message = history[-1][0]

        # Remove emoji for processing if present
        if user_message.startswith("👤 "):
            user_message = user_message[2:].strip()

        # Process previous history to clean emojis
        clean_history = []
        for h_user, h_bot in history[:-1]:
            if h_user and h_user.startswith("👤 "):
                h_user = h_user[2:].strip()
            if h_bot and h_bot.startswith("🛡️ "):
                h_bot = h_bot[2:].strip()
            clean_history.append([h_user, h_bot])

        # Call respond function with the message
        response_generator = respond(
            user_message,
            clean_history,  # Pass clean history
            system_message,
            max_tokens,
            temperature,
            top_p,
            selected_model
        )

        # Update history as responses come in, adding emoji
        for response in response_generator:
            history[-1][1] = f"🛡️ {response}"
            yield history
    # Wire up the event chain
    user_input.submit(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_message, max_tokens_slider, temperature_slider, top_p_slider, model_dropdown],
        [chatbot],
        queue=True
    )

    send_button.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_message, max_tokens_slider, temperature_slider, top_p_slider, model_dropdown],
        [chatbot],
        queue=True
    )

    # Clear the chat history
    def clear_history():
        return []

    clear_button.click(clear_history, None, chatbot, queue=False)


if __name__ == "__main__":
    demo.launch()
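# Note: demo.launch() uses Gradio's defaults here. When running locally one could
# instead pass standard Gradio launch options, e.g. demo.launch(share=True) for a
# temporary public URL, or server_name/server_port to bind a specific address;
# none of these are configured by this Space itself.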