"""Gradio demo that expands short prompts with a FlashPack-packed Gemma 3 270M model."""

import gradio as gr
from transformers import AutoTokenizer
from flashpack.integrations.transformers import FlashPackTransformersModelMixin
from transformers import AutoModelForCausalLM, Gemma3ForCausalLM, pipeline as hf_pipeline

# ============================================================
# 1️⃣ Define FlashPack-enabled model class
# ============================================================
class FlashPackGemmaModel(Gemma3ForCausalLM, FlashPackTransformersModelMixin):
    """Gemma 3 causal LM wrapped with FlashPackTransformersModelMixin.

    The base must be a concrete model class. ``AutoModelForCausalLM`` is only a
    factory: it cannot be instantiated directly, and its ``from_pretrained``
    returns an instance of the concrete model class rather than of this
    subclass, so the FlashPack methods (``save_pretrained_flashpack``) would be
    missing from the loaded model.
    """


# ============================================================
# 2️⃣ Load tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
FLASHPACK_REPO = "rahul7star/FlashPack"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# ============================================================
# 3️⃣ Load or create FlashPack model
# ============================================================
try:
    print("📂 Loading model from FlashPack repository...")
    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
except FileNotFoundError:
    # NOTE(review): confirm which exception FlashPack raises for a missing
    # repo/file; anything other than FileNotFoundError propagates from here.
    print("⚠️ FlashPack model not found. Loading from HF Hub and uploading FlashPack...")
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    # Presumably requires write access to FLASHPACK_REPO (an HF token) — confirm.
    model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
    print(f"✅ FlashPack model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")

# ============================================================
# 4️⃣ Build text-generation pipeline
# ============================================================
# NOTE(review): `model` is already instantiated here; confirm that
# device_map="auto" actually places it on an accelerator in this setup.
pipe = hf_pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)


# ============================================================
# 5️⃣ Define prompt enhancement function
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Generate an enhanced version of ``user_prompt`` and append the turn to the chat.

    Args:
        user_prompt: Short prompt typed by the user.
        temperature: Sampling temperature from the UI slider; values <= 0
            select greedy decoding instead of sampling.
        max_tokens: Maximum number of new tokens to generate.
        chat_history: Existing ``type="messages"`` chatbot history, or ``None``.

    Returns:
        The updated history: a list of ``{"role": ..., "content": ...}`` dicts.
    """
    chat_history = chat_history or []

    # Nothing to enhance: leave the conversation untouched.
    if not user_prompt or not user_prompt.strip():
        return chat_history

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]

    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # transformers rejects a non-positive temperature when sampling,
        # so a 0.0 slider value means greedy decoding.
        gen_kwargs["do_sample"] = False

    # Pass the chat messages (not a pre-rendered template string) so the
    # pipeline applies the chat template itself. This keeps the prompt out of
    # the returned text and avoids a second BOS token being added when an
    # already-templated string is re-tokenized.
    outputs = pipe(messages, **gen_kwargs)
    # For chat input, generated_text is the whole conversation; the last
    # entry is the newly generated assistant reply.
    enhanced = outputs[0]["generated_text"][-1]["content"].strip()

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced})
    return chat_history


# ============================================================
# 6️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# ✨ Prompt Enhancer (Gemma 3 270M)
Enter a short prompt, and the model will **expand it with details and creative context**
using the Gemma chat-template interface.
"""
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # Bind UI actions: button click and Enter in the textbox both enhance.
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
---
💡 **Tips:**
- Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
- Increase *Temperature* for more creative output.
"""
    )

# ============================================================
# 7️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)