# phi2-finetuned / app.py
# Author: Adityak204
# Last commit: "Removed template" (91c71d1)
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
from huggingface_hub import snapshot_download
# Page-level settings: browser tab title and icon, plus the wide layout.
st.set_page_config(
    layout="wide",
    page_icon="🤖",
    page_title="Fine-tuned Phi-2 Demo",
)

# Heading and a short introduction shown at the top of the page.
st.title("Fine-tuned Phi-2 Model Demo")
st.markdown(
    """
This app demonstrates a fine-tuned version of Microsoft's Phi-2 model.
Enter your prompt below and see the model generate text!
"""
)
@st.cache_resource
def load_model():
    """Build the text-generation pipeline for the fine-tuned Phi-2 model.

    The steps below are each announced inside an ``st.status`` panel:
      1. download the adapter checkpoint repo from the Hugging Face Hub;
      2. load the ``microsoft/phi-2`` base model in float16;
      3. attach the LoRA adapter weights from the downloaded checkpoint;
      4. load the tokenizer from the same checkpoint, using EOS as the pad token;
      5. wrap model and tokenizer in a ``text-generation`` pipeline.

    ``st.cache_resource`` keeps the returned pipeline, so later reruns reuse it
    instead of repeating the download and load.

    Returns:
        The ``transformers`` text-generation pipeline.
    """
    adapter_repo = "Adityak204/phi2-finetuned-checkpoint"
    with st.status("Loading model from Hugging Face...", expanded=True) as progress:
        st.write("Downloading checkpoint files...")
        adapter_dir = snapshot_download(repo_id=adapter_repo)

        st.write("Loading base model...")
        # NOTE(review): float16 is requested unconditionally. On a CPU-only host,
        # half precision may be slow or unsupported by some kernels. Consider
        # choosing the dtype from torch.cuda.is_available() (mind the memory cost).
        phi2 = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        st.write("Loading LoRA adapter weights...")
        tuned = PeftModel.from_pretrained(phi2, adapter_dir, device_map="auto")

        st.write("Loading tokenizer...")
        tok = AutoTokenizer.from_pretrained(adapter_dir)
        # Reuse EOS for padding (presumably the tokenizer ships without a pad token).
        tok.pad_token = tok.eos_token

        st.write("Creating text generation pipeline...")
        generator = pipeline("text-generation", model=tuned, tokenizer=tok)

        progress.update(
            label="Model loaded successfully!", expanded=False, state="complete"
        )
    return generator
# Obtain the generation pipeline. load_model is wrapped in @st.cache_resource,
# so reruns reuse the already-built pipeline.
pipe = load_model()

# User inputs. Widgets inside st.form are submitted together via the button.
with st.form("generation_form"):
    prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Example: What is the square root of 81?",
        height=150,
    )
    max_length = st.slider(
        "Maximum length:",
        min_value=32,
        max_value=1024,
        step=32,
        value=256,
        help="Maximum number of tokens to generate",
    )
    generate_button = st.form_submit_button("Generate")
# Run generation only after the form is submitted with a non-blank prompt.
if generate_button and prompt.strip():
    with st.spinner("Generating..."):
        print("Prompt recieved = ", prompt)
        # The slider is documented as "Maximum number of tokens to generate",
        # so pass it as `max_new_tokens`. `max_length` would also count the
        # prompt tokens, leaving little or no room for output on long prompts.
        generated_text = pipe(
            prompt,
            max_new_tokens=max_length,
            do_sample=True,
            num_return_sequences=1,
        )[0]["generated_text"]

    # Display the result
    st.subheader("Generated Output:")
    st.write(generated_text)

    # Plain text area holding the raw output so it can be selected and copied.
    st.text_area("Copy output:", value=generated_text, height=150)
# Sidebar with model information
with st.sidebar:
    st.header("About This Model")
    # Blank lines separate Markdown paragraphs; without them the model links
    # would be joined into one run-on paragraph when rendered.
    st.write(
        """
This is a fine-tuned version of Microsoft's Phi-2 model.
The model was fine-tuned using QLoRA techniques to improve its performance on specific tasks.

Original model: [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)

Fine-tuned model: [Adityak204/phi2-finetuned-checkpoint](https://huggingface.co/Adityak204/phi2-finetuned-checkpoint)
"""
    )

# Footer
st.markdown("---")
st.markdown("Powered by a fine-tuned version of Microsoft's Phi-2 model")