Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| from peft import PeftModel | |
| from huggingface_hub import snapshot_download | |
# --- Page setup ---------------------------------------------------------------
# Kept as the first Streamlit call of the script; older Streamlit releases
# raise if set_page_config runs after any other st.* command.
st.set_page_config(
    page_title="Fine-tuned Phi-2 Demo",
    page_icon="🤖",
    layout="wide",
)

# --- Header -------------------------------------------------------------------
_intro_md = """
This app demonstrates a fine-tuned version of Microsoft's Phi-2 model.
Enter your prompt below and see the model generate text!
"""
st.title("Fine-tuned Phi-2 Model Demo")
st.markdown(_intro_md)
@st.cache_resource(show_spinner=False)
def _build_pipeline():
    """Download the fine-tuned checkpoint and assemble the generation pipeline.

    Cached with ``st.cache_resource`` so the download and model load happen
    once per server process. Streamlit re-executes this whole script on every
    widget interaction / form submit; without the cache each rerun would load
    Phi-2 from scratch. The progress ``st.write`` messages are emitted on the
    first (uncached) call; Streamlit replays them on later cache hits.

    Returns:
        A ``transformers`` text-generation pipeline wrapping the base
        ``microsoft/phi-2`` model with the LoRA adapter from
        ``Adityak204/phi2-finetuned-checkpoint`` applied.
    """
    st.write("Downloading checkpoint files...")
    # The snapshot provides both the LoRA adapter weights and the tokenizer.
    model_repo = "Adityak204/phi2-finetuned-checkpoint"
    checkpoint_path = snapshot_download(repo_id=model_repo)

    st.write("Loading base model...")
    # NOTE(review): float16 is used unconditionally; on CPU-only hardware half
    # precision can be slow or unsupported depending on the PyTorch version —
    # confirm against the hardware this Space runs on.
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    st.write("Loading LoRA adapter weights...")
    model = PeftModel.from_pretrained(base_model, checkpoint_path, device_map="auto")

    st.write("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    # Reuse EOS as the padding token (presumably the tokenizer ships without a
    # pad token — confirm against the checkpoint's tokenizer config).
    tokenizer.pad_token = tokenizer.eos_token

    st.write("Creating text generation pipeline...")
    return pipeline(task="text-generation", model=model, tokenizer=tokenizer)


def load_model():
    """Return the (cached) text-generation pipeline, showing progress in a status box.

    The status container is opened here, outside the cached helper, so the
    final "complete" update is applied on every script run — whether the
    pipeline was just built or came straight from the cache.

    Returns:
        The text-generation pipeline produced by ``_build_pipeline``.
    """
    with st.status("Loading model from Hugging Face...", expanded=True) as status:
        text_pipe = _build_pipeline()
        status.update(
            label="Model loaded successfully!", state="complete", expanded=False
        )
    return text_pipe


# Load the model and pipeline; after the first run this reuses the cached pipeline.
pipe = load_model()
# --- Input form ---------------------------------------------------------------
# Widgets are attached through the form object, so their values are batched
# and only delivered to the script when the submit button is pressed.
generation_form = st.form("generation_form")
prompt = generation_form.text_area(
    "Enter your prompt:",
    placeholder="Example: What is the square root of 81?",
    height=150,
)
max_length = generation_form.slider(
    "Maximum length:",
    min_value=32,
    max_value=1024,
    step=32,
    value=256,
    help="Maximum number of tokens to generate",
)
generate_button = generation_form.form_submit_button("Generate")
# --- Generation ---------------------------------------------------------------
# Runs only when the form was submitted with a non-blank prompt.
# The slider value is passed as `max_new_tokens` so it bounds only the newly
# generated continuation, as the slider's help text promises. `max_length`
# would also count the prompt's tokens, leaving little or no room for new
# tokens when the prompt is long.
if generate_button and prompt.strip():
    with st.spinner("Generating..."):
        # Server-side trace of the incoming prompt (stdout, not the UI).
        print("Prompt received = ", prompt)
        generated_text = pipe(
            prompt,
            max_new_tokens=max_length,
            do_sample=True,
            num_return_sequences=1,
        )[0]["generated_text"]

    # Display the result.
    st.subheader("Generated Output:")
    st.write(generated_text)
    # Text area holding the raw output so it is easy to select and copy.
    st.text_area("Copy output:", value=generated_text, height=150)
elif generate_button:
    # Submitted with an empty or whitespace-only prompt: explain why nothing
    # was generated instead of silently doing nothing.
    st.warning("Please enter a prompt before generating.")
# --- Sidebar: model provenance -----------------------------------------------
_about_md = """
This is a fine-tuned version of Microsoft's Phi-2 model.
The model was fine-tuned using QLoRA techniques to improve its performance on specific tasks.
Original model: [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
Fine-tuned model: [Adityak204/phi2-finetuned-checkpoint](https://huggingface.co/Adityak204/phi2-finetuned-checkpoint)
"""
st.sidebar.header("About This Model")
st.sidebar.write(_about_md)
# --- Footer -------------------------------------------------------------------
# A horizontal rule followed by a one-line attribution, rendered in this order.
for _footer_md in (
    "---",
    "Powered by a fine-tuned version of Microsoft's Phi-2 model",
):
    st.markdown(_footer_md)