Spaces:
Runtime error
Update app.py
app.py
CHANGED
Old version (lines removed by this commit are prefixed with "-"):

@@ -1,9 +1,31 @@
  import spaces
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  from datetime import datetime
  import os

  # --- Configuration ---
  # Updated model ID
@@ -13,20 +35,20 @@ model_link = f"https://huggingface.co/{model_id}"
  website_link = "https://tesslate.com"
  discord_link = "https://discord.gg/DkzMzwBTaw"

- # --- Text Content ---
  Title = f"""
  <div style="text-align: center; margin-bottom: 20px;">
      <img src="https://huggingface.co/Tesslate/Tessa-T1-14B/resolve/main/tesslate_logo_color.png?download=true" alt="Tesslate Logo" style="height: 80px; margin-bottom: 10px;">
      <h1 style="margin-bottom: 5px;">π Welcome to the Tessa-T1-14B Demo π</h1>
      <p style="font-size: 1.1em;">Experience the power of specialized React reasoning!</p>
-     <p>Model by <a href="{creator_link}" target="_blank">TesslateAI</a> | <a href="{model_link}" target="_blank">View on Hugging Face</a
  </div>
  """

  description = f"""
  Interact with **[{model_id}]({model_link})**, an innovative 14B parameter transformer model fine-tuned from Qwen2.5-Coder-14B-Instruct.
  Tessa-T1 specializes in **React frontend development**, leveraging advanced reasoning to autonomously generate well-structured, semantic React components.
-
  """

  about_tesslate = f"""
@@ -72,16 +94,16 @@ join_us = f"""
  # --- Model and Tokenizer Loading ---
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Using device: {device}")

  # Get the token from environment variables
  hf_token = os.getenv('HF_TOKEN') # Standard env var name for HF token
  if not hf_token:
-     # Try to load from Hugging Face login if available, otherwise raise error
      try:
-
-         hf_token = HfFolder.get_token() # Use HfFolder to get token saved by login
          if not hf_token:
-             # If still not found, try HfApi (less common for user login token)
              hf_token = HfApi().token
          if not hf_token:
              raise ValueError("HF token not found. Please set HF_TOKEN env var or login via `huggingface-cli login`.")
@@ -92,28 +114,65 @@ if not hf_token:
          raise ValueError(f"HF token acquisition failed. Please set the HF_TOKEN environment variable or login via `huggingface-cli login`. Error: {e}")

  print(f"Loading Tokenizer: {model_id}")
- # Initialize tokenizer and model with token authentication
- # trust_remote_code=True is necessary for models with custom code (like Qwen2)
  tokenizer = AutoTokenizer.from_pretrained(
      model_id,
      token=hf_token,
      trust_remote_code=True
  )

- print(f"Loading Model: {model_id}")
- #
  [... the remaining removed lines of the original model-loading call were not captured by the diff view ...]
- )
- print("Model loaded successfully.")

-

  try:
      config_json = model.config.to_dict()
      model_config_info = f"""
  **Model Type:** {config_json.get('model_type', 'N/A')}
  **Architecture:** {config_json.get('architectures', ['N/A'])[0]}
@@ -122,13 +181,16 @@ try:
  **Num Hidden Layers:** {config_json.get('num_hidden_layers', 'N/A')}
  **Num Attention Heads:** {config_json.get('num_attention_heads', 'N/A')}
  **Max Position Embeddings:** {config_json.get('max_position_embeddings', 'N/A')}
- **
  """
  except Exception as e:
-     print(f"Could not retrieve model config: {e}")
-     model_config_info = f"**Error:** Could not load config for {model_id}.

  # --- Helper Function for Tokenizer Info ---
  def format_tokenizer_info(tokenizer_instance):
      try:
          info = [
@@ -152,76 +214,69 @@ def format_tokenizer_info(tokenizer_instance):

  tokenizer_info = format_tokenizer_info(tokenizer)

  # --- Generation Function ---
- @spaces.GPU(duration=180) #
  def generate_response(system_prompt, user_prompt, temperature, max_new_tokens, top_p, repetition_penalty, top_k, min_p):
-     #
-     #
-     #

-     # Use the tokenizer's chat template (Recommended for Qwen2 based models)
      messages = []
      if system_prompt and system_prompt.strip():
-         # Qwen2 template might prefer system prompt directly or integrated differently.
-         # Using the standard 'system' role here, assuming tokenizer handles it.
          messages.append({"role": "system", "content": system_prompt})
      messages.append({"role": "user", "content": user_prompt})

      try:
-         # Let the tokenizer handle the template - crucial for models like Qwen2
          full_prompt = tokenizer.apply_chat_template(
              messages,
              tokenize=False,
-             add_generation_prompt=True
          )
-         print("Applied tokenizer's chat template.")
      except Exception as e:
-         # Fallback only if template application fails catastrophically
          print(f"Warning: Could not use apply_chat_template (Error: {e}). Falling back to basic format. This might degrade performance.")
          prompt_parts = []
          if system_prompt and system_prompt.strip():
              prompt_parts.append(f"System: {system_prompt}")
          prompt_parts.append(f"\nUser: {user_prompt}")
-         prompt_parts.append("\nAssistant:")
          full_prompt = "\n".join(prompt_parts)

-     print(f"\n--- Generating ---")
-     # print(f"Prompt:\n{full_prompt}")
-     print(f"Params: Temp={temperature}, TopK={top_k}, TopP={top_p}, RepPen={repetition_penalty}, MaxNew={max_new_tokens}, MinP={min_p} (MinP ignored
-     print("-" * 20)

-     inputs

-     # Generation arguments
      generation_kwargs = dict(
          **inputs,
          max_new_tokens=int(max_new_tokens),
-         temperature=float(temperature) if float(temperature) > 0 else None,
          top_p=float(top_p),
          top_k=int(top_k),
          repetition_penalty=float(repetition_penalty),
          do_sample=True if float(temperature) > 0 else False,
-         pad_token_id=tokenizer.eos_token_id,
          eos_token_id=tokenizer.eos_token_id
-         # min_p cannot be directly passed here.
      )

-     if temperature == 0:
          generation_kwargs.pop('top_p', None)
          generation_kwargs.pop('top_k', None)
          generation_kwargs['do_sample'] = False

-
-     # Generate response
      with torch.inference_mode():
          outputs = model.generate(**generation_kwargs)

-     # Decode response, skipping special tokens and the input prompt part
      input_length = inputs['input_ids'].shape[1]
      generated_tokens = outputs[0][input_length:]
      response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

-     print(f"--- Response ---\n{response}\n---------------\n")
      return response.strip()

  # --- Gradio Interface ---
@@ -231,7 +286,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs

      with gr.Row():
          with gr.Column(scale=3):
-             # Main Interaction Area
              with gr.Group():
                  system_prompt = gr.Textbox(
                      label="System Prompt (Persona & Instructions)",
@@ -247,6 +301,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs

              with gr.Accordion("π οΈ Generation Parameters", open=True):
                  with gr.Row():
                      temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="π‘οΈ Temperature", info="Controls randomness. 0 = deterministic, >0 = random.")
                      max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=1024, step=32, label="π Max New Tokens", info="Max length of the generated response.")
                  with gr.Row():
@@ -254,29 +309,27 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
                      top_p = gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.01, label="π Top-p (nucleus)", info="Sample from tokens with cumulative probability >= top_p.")
                  with gr.Row():
                      repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.01, label="π¦ Repetition Penalty", info="Penalizes repeating tokens ( > 1).")
-                     # Add min_p slider, but note it's not used in backend currently
                      min_p = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="π Min-p (Not Active)", info="Filters tokens below this probability threshold (Requires custom logic - currently ignored).")

              generate_btn = gr.Button("π Generate Response", variant="primary", size="lg")

          with gr.Column(scale=2):
-
              output = gr.Code(
-                 label=f"π Tessa-T1-14B Output",
-                 language="markdown",
                  lines=25,
-                 show_copy_button=True,
              )

-             # Model & Tokenizer Info in an Accordion
              with gr.Accordion("βοΈ Model & Tokenizer Details", open=False):
                  gr.Markdown("### Model Configuration")
-                 gr.Markdown(model_config_info)
                  gr.Markdown("---")
                  gr.Markdown("### Tokenizer Configuration")
                  gr.Markdown(tokenizer_info)

-
      # About Tesslate Section
      with gr.Row():
          with gr.Accordion("π‘ About Tesslate & Our Mission", open=False):
@@ -285,25 +338,19 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
      # Links Section
      gr.Markdown(join_us)

-     # Examples (
      gr.Examples(
          examples=[
-             # [system_prompt, user_prompt, temperature, max_tokens, top_p, rep_penalty, top_k, min_p]
              [
                  "You are Tessa, an expert AI assistant specialized in React development.",
                  "Create a simple React functional component for a button that alerts 'Hello!' when clicked.",
-                 0.
              ],
              [
                  "You are Tessa, an expert AI assistant specialized in React development.",
                  "Explain the difference between `useState` and `useEffect` hooks in React with simple examples.",
                  0.7, 1024, 0.95, 1.1, 40, 0.05
              ],
-             [
-                 "You are a helpful AI assistant.",
-                 "Write a short explanation of how React's reconciliation algorithm works.",
-                 0.6, 768, 0.9, 1.15, 50, 0.05
-             ],
              [
                  "You are Tessa, an expert AI assistant specialized in React development. Use Tailwind CSS for styling.",
                  "Generate a React component for a responsive card with an image, title, and description, using Tailwind CSS classes.",
@@ -312,7 +359,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
              [
                  "You are a helpful AI assistant.",
                  "What are the pros and cons of using Next.js compared to Create React App?",
-                 0.8, 1024, 0.98, 1.05, 60, 0.05
              ]
          ],
          inputs=[
@@ -323,7 +370,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
              top_p,
              repetition_penalty,
              top_k,
-             min_p
          ],
          outputs=output,
          label="β¨ Example Prompts (Click to Load)"
@@ -339,6 +386,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs

  # Launch the demo
  if __name__ == "__main__":
-     #
-     #
      demo.queue().launch(debug=True, share=False) # Set share=True if deploying on HF Spaces
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
|
|
|
| 3 |
import torch
|
| 4 |
from datetime import datetime
|
| 5 |
import os
|
| 6 |
+
import subprocess # For Flash Attention install
|
| 7 |
+
|
| 8 |
+
# --- Install Flash Attention (specific method for compatibility) ---
|
| 9 |
+
# This method attempts to install flash-attn without building CUDA extensions locally,
|
| 10 |
+
# which can be helpful in restricted environments like ZeroGPU or when build tools are missing.
|
| 11 |
+
print("Attempting to install Flash Attention 2...")
|
| 12 |
+
try:
|
| 13 |
+
subprocess.run(
|
| 14 |
+
'pip install flash-attn --no-build-isolation',
|
| 15 |
+
env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
|
| 16 |
+
shell=True,
|
| 17 |
+
check=True # Raise an error if the command fails
|
| 18 |
+
)
|
| 19 |
+
print("Flash Attention installed successfully using subprocess method.")
|
| 20 |
+
_flash_attn_2_available = True
|
| 21 |
+
except Exception as e:
|
| 22 |
+
print(f"Could not install Flash Attention 2 using subprocess: {e}")
|
| 23 |
+
print("Proceeding without Flash Attention 2. Performance may be impacted.")
|
| 24 |
+
_flash_attn_2_available = False
|
| 25 |
+
|
| 26 |
+
# --- Import Transformers AFTER potential install ---
|
| 27 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 28 |
+
from huggingface_hub import HfApi, HfFolder
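Note on the install step above: passing a bare `env` dict to `subprocess.run` replaces the child's entire environment rather than extending it, and pip's exit code alone does not prove the wheel is importable on the running hardware. A minimal defensive variant is sketched below; the helper name is hypothetical and this is not part of the commit.

```python
# Sketch only (hypothetical helper, not in the committed app.py).
import importlib.util
import os
import subprocess

def try_install_flash_attn() -> bool:
    """Install flash-attn while preserving the parent environment, then verify the import."""
    try:
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},  # keep PATH etc.
            shell=True,
            check=True,
        )
    except subprocess.CalledProcessError as exc:
        print(f"flash-attn install failed: {exc}")
        return False
    # Report success only if the package can actually be found by the import machinery.
    return importlib.util.find_spec("flash_attn") is not None
```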

  # --- Configuration ---
  # Updated model ID
  ...
  website_link = "https://tesslate.com"
  discord_link = "https://discord.gg/DkzMzwBTaw"

+ # --- Text Content (Keep the cool UI elements) ---
  Title = f"""
  <div style="text-align: center; margin-bottom: 20px;">
      <img src="https://huggingface.co/Tesslate/Tessa-T1-14B/resolve/main/tesslate_logo_color.png?download=true" alt="Tesslate Logo" style="height: 80px; margin-bottom: 10px;">
      <h1 style="margin-bottom: 5px;">π Welcome to the Tessa-T1-14B Demo π</h1>
      <p style="font-size: 1.1em;">Experience the power of specialized React reasoning!</p>
+     <p>Model by <a href="{creator_link}" target="_blank">TesslateAI</a> | <a href="{model_link}" target="_blank">View on Hugging Face</a> | Running with 8-bit Quantization</p>
  </div>
  """

  description = f"""
  Interact with **[{model_id}]({model_link})**, an innovative 14B parameter transformer model fine-tuned from Qwen2.5-Coder-14B-Instruct.
  Tessa-T1 specializes in **React frontend development**, leveraging advanced reasoning to autonomously generate well-structured, semantic React components.
+ This demo uses **8-bit quantization** via `bitsandbytes` for reduced memory footprint. **Flash Attention 2** is enabled if available for potentially faster inference.
  """

  about_tesslate = f"""
  ...

  # --- Model and Tokenizer Loading ---
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Using device: {device}")
+ if device == "cpu":
+     print("Warning: Running on CPU. Quantization and Flash Attention require CUDA.")
+     _flash_attn_2_available = False # Cannot use flash attn on CPU

  # Get the token from environment variables
  hf_token = os.getenv('HF_TOKEN') # Standard env var name for HF token
  if not hf_token:
      try:
+         hf_token = HfFolder.get_token()
          if not hf_token:
              hf_token = HfApi().token
          if not hf_token:
              raise ValueError("HF token not found. Please set HF_TOKEN env var or login via `huggingface-cli login`.")
  ...
          raise ValueError(f"HF token acquisition failed. Please set the HF_TOKEN environment variable or login via `huggingface-cli login`. Error: {e}")

  print(f"Loading Tokenizer: {model_id}")
  tokenizer = AutoTokenizer.from_pretrained(
      model_id,
      token=hf_token,
      trust_remote_code=True
  )

+ print(f"Loading Model: {model_id} with 8-bit quantization")
+ # Define quantization configuration
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+ # Determine attn_implementation based on install success and device
+ attn_implementation = "flash_attention_2" if _flash_attn_2_available and device == "cuda" else "sdpa" # sdpa is a fallback
+ print(f"Using attention implementation: {attn_implementation}")
+ # Note: You might see a warning from bitsandbytes about library paths on ZeroGPU, this is often normal.

+ try:
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         token=hf_token,
+         device_map="auto", # Automatically distributes layers, crucial for large quantized models
+         quantization_config=quantization_config,
+         attn_implementation=attn_implementation, # Enable Flash Attention 2 if available
+         trust_remote_code=True
+     )
+     print("Model loaded successfully with 8-bit quantization.")
+ except ImportError as e:
+     print(f"ImportError during model loading: {e}")
+     print("Ensure 'bitsandbytes' and 'accelerate' are installed.")
+     # Optionally fall back to no quantization if bitsandbytes is missing,
+     # but for this request, we assume it's intended.
+     raise e
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     # If Flash Attention was requested but is incompatible, Transformers might raise an error.
+     # Let's try falling back to SDPA (Scaled Dot Product Attention) if FA2 fails at load time.
+     if attn_implementation == "flash_attention_2":
+         print("Flash Attention 2 failed at load time. Trying fallback 'sdpa' attention...")
+         try:
+             attn_implementation = "sdpa"
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 token=hf_token,
+                 device_map="auto",
+                 quantization_config=quantization_config,
+                 attn_implementation=attn_implementation,
+                 trust_remote_code=True
+             )
+             print("Model loaded successfully with 8-bit quantization and SDPA attention.")
+         except Exception as e2:
+             print(f"Fallback to SDPA attention also failed: {e2}")
+             raise e2 # Re-raise the error if fallback fails too
+     else:
+         raise e # Re-raise original error if it wasn't FA2 related
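The fallback path above repeats the full `from_pretrained` call. A small helper keeps the two attempts identical; the sketch below reuses the same arguments the commit passes and is not itself part of the commit. The 4-bit configuration in the trailing comment is an optional alternative, not something this Space enables.

```python
# Sketch: factor the duplicated from_pretrained call into one hypothetical helper.
def load_quantized_model(attn_impl: str):
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        device_map="auto",
        quantization_config=quantization_config,  # BitsAndBytesConfig(load_in_8bit=True)
        attn_implementation=attn_impl,
        trust_remote_code=True,
    )

try:
    model = load_quantized_model(attn_implementation)
except Exception:
    # Fall back to PyTorch's built-in scaled-dot-product attention.
    attn_implementation = "sdpa"
    model = load_quantized_model(attn_implementation)

# If memory is still tight, bitsandbytes also supports 4-bit NF4, e.g.
# BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
#                    bnb_4bit_compute_dtype=torch.bfloat16)
```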
+
+ # Get config info (might need adjustment based on quantized model structure)
  try:
      config_json = model.config.to_dict()
+     # Add quantization info
+     quant_info = model.config.quantization_config.to_dict() if hasattr(model.config, 'quantization_config') else {}
      model_config_info = f"""
  **Model Type:** {config_json.get('model_type', 'N/A')}
  **Architecture:** {config_json.get('architectures', ['N/A'])[0]}
  ...
  **Num Hidden Layers:** {config_json.get('num_hidden_layers', 'N/A')}
  **Num Attention Heads:** {config_json.get('num_attention_heads', 'N/A')}
  **Max Position Embeddings:** {config_json.get('max_position_embeddings', 'N/A')}
+ **Attention Implementation:** `{attn_implementation}`
+ **Quantization:** 8-bit (`load_in_8bit={quant_info.get('load_in_8bit', 'N/A')}`)
  """
  except Exception as e:
+     print(f"Could not retrieve full model config: {e}")
+     model_config_info = f"**Error:** Could not load full config details for {model_id}."
+

  # --- Helper Function for Tokenizer Info ---
+ # (Keep the existing format_tokenizer_info function - no changes needed)
  def format_tokenizer_info(tokenizer_instance):
      try:
          info = [
  ...

  tokenizer_info = format_tokenizer_info(tokenizer)
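The body of `format_tokenizer_info` is unchanged by this commit, so the diff view elides it. Purely for orientation, a helper of this shape could look like the sketch below; the fields and wording are illustrative assumptions, not the Space's actual code.

```python
# Illustrative sketch only; the real function body is not shown in this diff.
def format_tokenizer_info_sketch(tokenizer_instance):
    try:
        info = [
            f"**Tokenizer Class:** {tokenizer_instance.__class__.__name__}",
            f"**Vocab Size:** {tokenizer_instance.vocab_size}",
            f"**Model Max Length:** {tokenizer_instance.model_max_length}",
            f"**EOS Token:** `{tokenizer_instance.eos_token}`",
            f"**Padding Side:** {tokenizer_instance.padding_side}",
        ]
        return "\n".join(info)
    except Exception as e:
        return f"**Error:** Could not read tokenizer details ({e})."
```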

+
  # --- Generation Function ---
+ @spaces.GPU(duration=180) # Keep duration, can be adjusted if needed
  def generate_response(system_prompt, user_prompt, temperature, max_new_tokens, top_p, repetition_penalty, top_k, min_p):
+     # (Keep the existing generate_response function structure)
+     # It correctly uses apply_chat_template and handles generation parameters.
+     # min_p is still noted as ignored by the standard HF generate function.

      messages = []
      if system_prompt and system_prompt.strip():
          messages.append({"role": "system", "content": system_prompt})
      messages.append({"role": "user", "content": user_prompt})

      try:
          full_prompt = tokenizer.apply_chat_template(
              messages,
              tokenize=False,
+             add_generation_prompt=True
          )
+         # print("Applied tokenizer's chat template.") # Less verbose logging
      except Exception as e:
          print(f"Warning: Could not use apply_chat_template (Error: {e}). Falling back to basic format. This might degrade performance.")
          prompt_parts = []
          if system_prompt and system_prompt.strip():
              prompt_parts.append(f"System: {system_prompt}")
          prompt_parts.append(f"\nUser: {user_prompt}")
+         prompt_parts.append("\nAssistant:")
          full_prompt = "\n".join(prompt_parts)

+     # print(f"\n--- Generating ---")
+     # print(f"Prompt:\n{full_prompt}")
+     # print(f"Params: Temp={temperature}, TopK={top_k}, TopP={top_p}, RepPen={repetition_penalty}, MaxNew={max_new_tokens}, MinP={min_p} (MinP ignored)")
+     # print("-" * 20)

+     # Ensure inputs are on the correct device (handled by device_map="auto")
+     # Added truncation safeguard during tokenization
+     inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)

      generation_kwargs = dict(
          **inputs,
          max_new_tokens=int(max_new_tokens),
+         temperature=float(temperature) if float(temperature) > 0 else None,
          top_p=float(top_p),
          top_k=int(top_k),
          repetition_penalty=float(repetition_penalty),
          do_sample=True if float(temperature) > 0 else False,
+         pad_token_id=tokenizer.eos_token_id,
          eos_token_id=tokenizer.eos_token_id
      )

+     if temperature == 0:
          generation_kwargs.pop('top_p', None)
          generation_kwargs.pop('top_k', None)
          generation_kwargs['do_sample'] = False

      with torch.inference_mode():
          outputs = model.generate(**generation_kwargs)

      input_length = inputs['input_ids'].shape[1]
      generated_tokens = outputs[0][input_length:]
      response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

+     # print(f"--- Response ---\n{response}\n---------------\n")
      return response.strip()
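`generate_response` returns the whole completion at once and still drops the `min_p` slider value. Two optional refinements, sketched here as assumptions rather than taken from the commit: `transformers.TextIteratorStreamer` can stream partial text into the `gr.Code` output (Gradio treats generator functions as streaming handlers), and sufficiently recent `transformers` releases accept `min_p` directly in `generate()`.

```python
# Sketch only: a streaming variant of the generation call (not part of this commit).
from threading import Thread
from transformers import TextIteratorStreamer

def generate_response_streaming(generation_kwargs, min_p=None):
    """Yield partial text as it is produced; assumes `model` and `tokenizer` are already loaded."""
    if min_p is not None and generation_kwargs.get("do_sample"):
        # Assumption: the installed transformers version supports the min_p sampling option.
        generation_kwargs["min_p"] = float(min_p)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs["streamer"] = streamer

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # Gradio updates the output component on every yield
    thread.join()
```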

  # --- Gradio Interface ---
  ...

      with gr.Row():
          with gr.Column(scale=3):
              with gr.Group():
                  system_prompt = gr.Textbox(
                      label="System Prompt (Persona & Instructions)",
  ...

              with gr.Accordion("π οΈ Generation Parameters", open=True):
                  with gr.Row():
+                     # --- Set Default Params ---
                      temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="π‘οΈ Temperature", info="Controls randomness. 0 = deterministic, >0 = random.")
                      max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=1024, step=32, label="π Max New Tokens", info="Max length of the generated response.")
                  with gr.Row():
  ...
                      top_p = gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.01, label="π Top-p (nucleus)", info="Sample from tokens with cumulative probability >= top_p.")
                  with gr.Row():
                      repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.01, label="π¦ Repetition Penalty", info="Penalizes repeating tokens ( > 1).")
                      min_p = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="π Min-p (Not Active)", info="Filters tokens below this probability threshold (Requires custom logic - currently ignored).")

              generate_btn = gr.Button("π Generate Response", variant="primary", size="lg")

          with gr.Column(scale=2):
+             # --- Fix: Remove show_copy_button=True ---
+             # gr.Code inherently has a copy button in modern Gradio versions
              output = gr.Code(
+                 label=f"π Tessa-T1-14B (8-bit) Output",
+                 language="markdown",
                  lines=25,
+                 # show_copy_button=True, # REMOVED - This caused the TypeError
              )
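The comment above records that passing `show_copy_button=True` to `gr.Code` raised a `TypeError` on this Space's Gradio version. If the argument should still be used wherever it is supported, one version-tolerant option is to inspect the constructor signature first; this is a sketch, not part of the commit.

```python
# Sketch: pass show_copy_button only when the installed Gradio version accepts it.
import inspect

code_kwargs = dict(label="Tessa-T1-14B (8-bit) Output", language="markdown", lines=25)
if "show_copy_button" in inspect.signature(gr.Code.__init__).parameters:
    code_kwargs["show_copy_button"] = True  # silently skipped on Gradio versions without it
output = gr.Code(**code_kwargs)
```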

              with gr.Accordion("βοΈ Model & Tokenizer Details", open=False):
                  gr.Markdown("### Model Configuration")
+                 gr.Markdown(model_config_info) # Display updated info including quantization/attn
                  gr.Markdown("---")
                  gr.Markdown("### Tokenizer Configuration")
                  gr.Markdown(tokenizer_info)

      # About Tesslate Section
      with gr.Row():
          with gr.Accordion("π‘ About Tesslate & Our Mission", open=False):
  ...
      # Links Section
      gr.Markdown(join_us)

+     # Examples (Keep the relevant examples)
      gr.Examples(
          examples=[
              [
                  "You are Tessa, an expert AI assistant specialized in React development.",
                  "Create a simple React functional component for a button that alerts 'Hello!' when clicked.",
+                 0.7, 512, 0.95, 1.1, 40, 0.05 # Default params match the sliders now
              ],
              [
                  "You are Tessa, an expert AI assistant specialized in React development.",
                  "Explain the difference between `useState` and `useEffect` hooks in React with simple examples.",
                  0.7, 1024, 0.95, 1.1, 40, 0.05
              ],
              [
                  "You are Tessa, an expert AI assistant specialized in React development. Use Tailwind CSS for styling.",
                  "Generate a React component for a responsive card with an image, title, and description, using Tailwind CSS classes.",
  ...
              [
                  "You are a helpful AI assistant.",
                  "What are the pros and cons of using Next.js compared to Create React App?",
+                 0.8, 1024, 0.98, 1.05, 60, 0.05 # Example with slightly different params
              ]
          ],
          inputs=[
  ...
              top_p,
              repetition_penalty,
              top_k,
+             min_p
          ],
          outputs=output,
          label="β¨ Example Prompts (Click to Load)"
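The lines that wire `generate_btn` to `generate_response` are unchanged by this commit and therefore elided from the diff. For orientation only, the wiring presumably looks something like the following; this is an assumption, not the Space's actual code.

```python
# Assumed wiring (the real click handler is not shown in this diff).
generate_btn.click(
    fn=generate_response,
    inputs=[system_prompt, user_prompt, temperature, max_new_tokens,
            top_p, repetition_penalty, top_k, min_p],
    outputs=output,
)
```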
  ...

  # Launch the demo
  if __name__ == "__main__":
+     # The progress bar noise during shard loading is normal output from the `transformers` library
+     # during the download/loading phase before the Gradio app starts serving.
+     # It cannot be suppressed from within this script.
      demo.queue().launch(debug=True, share=False) # Set share=True if deploying on HF Spaces
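One hedged footnote on the new comment about progress bars: in many setups the shard-download and loading bars can in fact be quieted from the script, via environment variables or the helpers below, although behaviour varies by library version.

```python
# Sketch (optional, not part of the commit): reduce download/loading progress noise.
import os

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # ideally set before huggingface_hub is imported

from huggingface_hub.utils import disable_progress_bars
from transformers.utils import logging as hf_logging

disable_progress_bars()            # Hugging Face Hub download bars
hf_logging.disable_progress_bar()  # transformers shard-loading bars
```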