import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
from huggingface_hub import snapshot_download

# Page configuration
st.set_page_config(page_title="Fine-tuned Phi-2 Demo", page_icon="🤖", layout="wide")

# Streamlit UI
st.title("Fine-tuned Phi-2 Model Demo")
st.markdown(
    """
    This app demonstrates a fine-tuned version of Microsoft's Phi-2 model.
    Enter your prompt below and see the model generate text!
    """
)


@st.cache_resource
def load_model():
    with st.status("Loading model from Hugging Face...", expanded=True) as status:
        st.write("Downloading checkpoint files...")
        # Download the checkpoint files from the fine-tuned HF repo
        model_repo = "Adityak204/phi2-finetuned-checkpoint"
        checkpoint_path = snapshot_download(repo_id=model_repo)

        st.write("Loading base model...")
        # Load the base model in half precision, sharded across available devices
        base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        st.write("Loading LoRA adapter weights...")
        # Attach the LoRA adapter weights on top of the frozen base model
        model = PeftModel.from_pretrained(
            base_model, checkpoint_path, device_map="auto"
        )

        st.write("Loading tokenizer...")
        # Load the tokenizer; Phi-2 defines no pad token, so reuse EOS
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        tokenizer.pad_token = tokenizer.eos_token

        st.write("Creating text generation pipeline...")
        # Create text generation pipeline
        pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

        status.update(
            label="Model loaded successfully!", state="complete", expanded=False
        )

    return pipe


# Load the model and pipeline (cached so it isn't reloaded on every rerun)
pipe = load_model()

# Input form
with st.form("generation_form"):
    prompt = st.text_area(
        "Enter your prompt:",
        height=150,
        placeholder="Example: What is the square root of 81?",
    )
    max_new_tokens = st.slider(
        "Maximum new tokens:",
        min_value=32,
        max_value=1024,
        value=256,
        step=32,
        help="Maximum number of tokens to generate",
    )
    generate_button = st.form_submit_button("Generate")

# Generate text when the button is clicked
if generate_button and prompt:
    with st.spinner("Generating..."):
        print("Prompt received = ", prompt)
        # max_new_tokens counts only generated tokens, matching the slider's
        # help text (max_length would also count the prompt)
        generated_text = pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            num_return_sequences=1,
        )[0]["generated_text"]

    # Display the result
    st.subheader("Generated Output:")
    st.write(generated_text)

    # Copyable output
    st.text_area("Copy output:", value=generated_text, height=150)

# Sidebar with model information
with st.sidebar:
    st.header("About This Model")
    st.write(
        """
        This is a fine-tuned version of Microsoft's Phi-2 model.
        The model was fine-tuned using QLoRA techniques to improve its
        performance on specific tasks.

        Original model: [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)

        Fine-tuned model: [Adityak204/phi2-finetuned-checkpoint](https://huggingface.co/Adityak204/phi2-finetuned-checkpoint)
        """
    )

# Footer
st.markdown("---")
st.markdown("Powered by a fine-tuned version of Microsoft's Phi-2 model")
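
# --- Usage note (a minimal sketch, not part of the app itself) ---
# Assuming this file is saved as app.py and the dependencies are installed
# (streamlit, torch, transformers, peft, huggingface_hub; `accelerate` is
# also required for device_map="auto"), the app can be launched with:
#
#   pip install streamlit torch transformers peft huggingface_hub accelerate
#   streamlit run app.py
#
# The first run downloads microsoft/phi-2 and the adapter checkpoint, so it
# takes a while; later runs reuse the @st.cache_resource-cached pipeline.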