# safe-playground/app.py
import gradio as gr
from huggingface_hub import InferenceClient
# Define available models (update with your actual model IDs)
model_list = {
"Safe LM": "HuggingFaceH4/zephyr-7b-beta",
"Baseline 1": "HuggingFaceH4/zephyr-7b-beta",
"Another Model": "HuggingFaceH4/zephyr-7b-beta",
"LLaMA3.2-1B": "meta-llama/Llama-3.2-1B-Instruct",
"Mix IFT V2 - Score0 Rephrased": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_mix_rephrased_from_beginning-300B",
"Mix IFT V2 - Score0 Only": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_only-300B",
"Mix IFT V2 - All Raw Folders Metadata": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-all_raw_folders_metadata-300B",
"Mix IFT V2 - All Raw Folders Baseline": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-all_raw_folders_baseline-300B",
"Mix IFT V2 - Score0 Only MBS16 GBS1024": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_only-300B-mbs16-gbs1024-16feb-lr2e-05-gbs16"
}
# Dictionary to track which models support chat completion vs. text generation
model_tasks = {
"HuggingFaceH4/zephyr-7b-beta": "chat-completion", # This model supports chat completion
# Add other models that support chat completion
}
# Default to text-generation for models not specified above
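# A minimal helper sketch (illustrative only -- the name task_for is an
# assumption, and respond() below branches on the model ID directly rather
# than calling this): resolve the task a model supports, falling back to
# plain text generation for anything not listed in model_tasks.
def task_for(model_id: str) -> str:
    return model_tasks.get(model_id, "text-generation")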
def respond(message, history, system_message, max_tokens, temperature, top_p, selected_model):
try:
# Get the model ID for the selected model
model_id = model_list.get(selected_model, "HuggingFaceH4/zephyr-7b-beta")
# Create an InferenceClient for the selected model
client = InferenceClient(model_id)
# Always use text generation for locuslab models
if "locuslab" in model_id:
# Build a plain-text prompt by hand; these research models are not
# chat-tuned, so keep the formatting minimal
formatted_prompt = ""
if len(history) > 0:
# Include minimal context from history
last_exchanges = history[-1:] # Just use the last exchange
for user_msg, _ in last_exchanges:  # only the user turn is reused
if user_msg:
formatted_prompt += f"{user_msg}\n"
# Add current message - keep it simple
formatted_prompt += f"{message}"
response = ""
# Use text generation instead of chat completion
print(f"Using text generation with prompt: {formatted_prompt}")
for token in client.text_generation(
formatted_prompt,
max_new_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
do_sample=True # Enable sampling for more creative responses
):
response += token
yield response
else:
# Try chat completion for standard models
try:
messages = [{"role": "system", "content": system_message}]
for user_msg, assistant_msg in history:
if user_msg: # Only add non-empty messages
messages.append({"role": "user", "content": user_msg})
if assistant_msg: # Only add non-empty messages
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
response = ""
# Stream the response from the client
for token_message in client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
# Safe extraction of token with error handling
try:
token = token_message.choices[0].delta.content
if token is not None: # Handle potential None values
response += token
yield response
except (AttributeError, IndexError) as e:
# Handle cases where token structure might be different
print(f"Error extracting token: {e}")
continue
except Exception as e:
# If chat completion fails, fall back to text generation
print(f"Chat completion failed: {e}. Falling back to text generation.")
formatted_prompt = f"{system_message}\n\n"
for user_msg, assistant_msg in history:
if user_msg:
formatted_prompt += f"User: {user_msg}\n"
if assistant_msg:
formatted_prompt += f"Assistant: {assistant_msg}\n"
formatted_prompt += f"User: {message}\nAssistant:"
response = ""
# Stream the fallback response via plain text generation
for token in client.text_generation(
formatted_prompt,
max_new_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
response += token
yield response
except Exception as e:
# Return detailed error message if the model call fails
error_message = str(e)
print(f"Error calling model API: {error_message}")
yield f"Error: {error_message}. Please try a different model or adjust parameters."
# Custom CSS for styling
css = """
body {
background-color: #f0f5fb; /* Light pastel blue background */
}
.gradio-container {
background-color: white;
border-radius: 16px;
box-shadow: 0 2px 10px rgba(0,0,0,0.05);
max-width: 90%;
margin: 15px auto;
padding-bottom: 20px;
}
/* Header styling with diagonal shield */
.app-header {
position: relative;
overflow: hidden;
}
.app-header::before {
content: "🛡️";
position: absolute;
font-size: 100px;
opacity: 0.1;
right: -20px;
top: -30px;
transform: rotate(15deg);
pointer-events: none;
}
/* Simple styling for buttons */
#send-btn {
background-color: white !important;
color: #333 !important;
border: 2px solid #e6c200 !important;
}
#send-btn:hover {
background-color: #fff9e6 !important;
}
#clear-btn {
background-color: white !important;
color: #333 !important;
border: 2px solid #e6c200 !important;
}
#clear-btn:hover {
background-color: #fff9e6 !important;
}
/* Hide elements */
footer {
display: none !important;
}
.footer {
display: none !important;
}
"""
with gr.Blocks(css=css) as demo:
# Custom header with branding
gr.HTML("""
<div class="app-header" style="background: linear-gradient(135deg, #4a90e2, #75c6ef); padding: 15px; border-radius: 16px 16px 0 0; color: white; border-bottom: 3px solid #e6c200;">
<h1 style="font-size: 32px; font-weight: 600; margin: 0; display: flex; align-items: center; font-family: 'Palatino', serif;">
<span style="margin-right: 10px; font-size: 32px;">🛡️</span>
<span style="font-weight: 700; margin-right: 1px;">Safe</span>
<span style="font-weight: 400; letter-spacing: 1px;">Playground</span>
</h1>
</div>
""")
# Status message for API errors
status_message = gr.Markdown("", elem_id="status-message")
with gr.Row():
# Left sidebar: Model selector
with gr.Column(scale=1):
gr.Markdown("## Models")
model_dropdown = gr.Dropdown(
choices=list(model_list.keys()),
label="Select Model",
value="Safe LM",
elem_classes=["model-select"]
)
# Settings
gr.Markdown("### Settings")
system_message = gr.Textbox(
value="You are a friendly and safe assistant.",
label="System Message",
lines=2
)
max_tokens_slider = gr.Slider(
minimum=1, maximum=2048, value=100, step=1,
label="Max New Tokens"
)
temperature_slider = gr.Slider(
minimum=0.1, maximum=4.0, value=0.7, step=0.1,
label="Temperature"
)
top_p_slider = gr.Slider(
minimum=0.1, maximum=1.0, value=0.95, step=0.05,
label="Top-p (nucleus sampling)"
)
# Main area: Chat interface
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Conversation",
show_label=True,
height=400
)
with gr.Row():
user_input = gr.Textbox(
placeholder="Type your message here...",
label="Your Message",
show_label=False,
scale=9
)
send_button = gr.Button(
"Send",
scale=1,
elem_id="send-btn"
)
with gr.Row():
clear_button = gr.Button("Clear Chat", elem_id="clear-btn")
# Define functions for chatbot interactions
def user(user_message, history):
# Add emoji to user message
user_message_with_emoji = f"👤 {user_message}"
return "", history + [[user_message_with_emoji, None]]
def bot(history, system_message, max_tokens, temperature, top_p, selected_model):
# Ensure there's history
if not history:
return history
# Get the last user message from history
user_message = history[-1][0]
# Remove emoji for processing if present
if user_message.startswith("👤 "):
    user_message = user_message.removeprefix("👤 ")
# Process previous history to clean emojis
clean_history = []
for h_user, h_bot in history[:-1]:
if h_user and h_user.startswith("👤 "):
    h_user = h_user.removeprefix("👤 ")
if h_bot and h_bot.startswith("🛡️ "):
    # removeprefix handles the full 3-character "🛡️ " prefix (emoji +
    # variation selector + space), which a fixed [2:] slice only matched
    # by accident of the trailing strip()
    h_bot = h_bot.removeprefix("🛡️ ")
clean_history.append([h_user, h_bot])
# Call respond function with the message
response_generator = respond(
user_message,
clean_history, # Pass clean history
system_message,
max_tokens,
temperature,
top_p,
selected_model
)
# Update history as responses come in, adding emoji
for response in response_generator:
history[-1][1] = f"🛡️ {response}"
yield history
# Wire up the event chain
user_input.submit(
user,
[user_input, chatbot],
[user_input, chatbot],
queue=False
).then(
bot,
[chatbot, system_message, max_tokens_slider, temperature_slider, top_p_slider, model_dropdown],
[chatbot],
queue=True
)
send_button.click(
user,
[user_input, chatbot],
[user_input, chatbot],
queue=False
).then(
bot,
[chatbot, system_message, max_tokens_slider, temperature_slider, top_p_slider, model_dropdown],
[chatbot],
queue=True
)
# Clear the chat history
def clear_history():
return []
clear_button.click(clear_history, None, chatbot, queue=False)
if __name__ == "__main__":
demo.launch()