import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline # ============================================================ # 1️⃣ Load model and tokenizer # ============================================================ MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it" # Use GPU if available device = 0 if torch.cuda.is_available() else -1 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained(MODEL_ID) pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, device=device, # 0 for GPU, -1 for CPU ) # ============================================================ # 2️⃣ Define the generation function (chat-template style) # ============================================================ def enhance_prompt(user_prompt, temperature, max_tokens, chat_history): chat_history = chat_history or [] # Build messages using proper roles messages = [ {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"}, {"role": "user", "content": user_prompt} ] # Use tokenizer chat template to build the input prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # Generate output output = pipe( prompt, max_new_tokens=int(max_tokens), temperature=float(temperature), do_sample=True, )[0]["generated_text"].strip() print(output) print(output[0]) # Append conversation to history chat_history.append({"role": "user", "content": user_prompt}) chat_history.append({"role": "assistant", "content": output}) return chat_history import re def extract_later_part(user_prompt, generated_text): """ Cleans the model output and extracts only the enhanced (later) portion. Removes prompt echoes and system tags like , , etc. """ # Step 1: Clean up internal tags cleaned = re.sub(r"<.*?>", "", generated_text) # Remove , , etc. cleaned = cleaned.strip() # Step 2: Normalize spaces cleaned = re.sub(r"\s+", " ", cleaned) # Step 3: Try removing the original prompt if repeated user_prompt_clean = user_prompt.strip().lower() cleaned_lower = cleaned.lower() if cleaned_lower.startswith(user_prompt_clean): cleaned = cleaned[len(user_prompt):].strip(",. ").strip() return cleaned # ===================== Prompt Enhancer Function ===================== def enhance_prompt1(user_prompt, temperature, max_tokens, chat_history): messages = [ {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"}, {"role": "user", "content": user_prompt} ] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) output = pipe(prompt, max_new_tokens=256) raw_output = output[0]['generated_text'] print("=== RAW MODEL OUTPUT ===") print(raw_output) # Extract the cleaned, later portion later_part = extract_later_part(user_prompt, raw_output) print("=== EXTRACTED CLEANED OUTPUT ===") print(later_part) # Append to chat history for Gradio chat_history = chat_history or [] chat_history.append({"role": "user", "content": user_prompt}) chat_history.append({"role": "assistant", "content": later_part}) return chat_history # ============================================================ # 3️⃣ Gradio UI # ============================================================ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # ✨ Prompt Enhancer (Gemma 3 270M) Enter a short prompt, and the model will **expand it with details and creative context** using the Gemma chat-template interface. """ ) with gr.Row(): chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages") with gr.Column(scale=1): user_prompt = gr.Textbox( placeholder="Enter a short prompt...", label="Your Prompt", lines=3, ) temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature") max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens") send_btn = gr.Button("🚀 dev dont click", variant="primary") clear_btn = gr.Button("🧹 Clear Chat") add_btn = gr.Button("🚀 Enchance Prompt", variant="primary") # Bind UI actions #send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot) user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot) clear_btn.click(lambda: [], None, chatbot) add_btn.click(enhance_prompt1, [user_prompt, temperature, max_tokens, chatbot], chatbot) gr.Markdown( """ --- 💡 **Tips:** - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair") - Increase *Temperature* for more creative output. """ ) # ============================================================ # 4️⃣ Launch # ============================================================ if __name__ == "__main__": demo.launch(show_error=True)