File size: 3,529 Bytes
d37ca0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91c71d1
d37ca0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
from huggingface_hub import snapshot_download

# Page configuration: browser tab title, tab icon and wide layout
st.set_page_config(
    page_title="Fine-tuned Phi-2 Demo",
    page_icon="🤖",
    layout="wide",
)

# Streamlit UI: page heading followed by a short introduction
intro_markdown = """
This app demonstrates a fine-tuned version of Microsoft's Phi-2 model.
Enter your prompt below and see the model generate text!
"""
st.title("Fine-tuned Phi-2 Model Demo")
st.markdown(intro_markdown)


@st.cache_resource
def load_model():
    """Build the text-generation pipeline for the fine-tuned Phi-2 model.

    Steps, each announced inside a Streamlit status container:

    1. Download the checkpoint repository from the Hugging Face Hub.
    2. Load the ``microsoft/phi-2`` base model in float16.
    3. Attach the LoRA adapter weights found in the downloaded checkpoint.
    4. Load the tokenizer from the same checkpoint directory.
    5. Wrap model and tokenizer in a ``text-generation`` pipeline.

    The function is decorated with ``st.cache_resource``, so the returned
    pipeline is cached by Streamlit instead of being rebuilt on every call.

    Returns:
        The ``transformers`` text-generation pipeline.
    """
    adapter_repo = "Adityak204/phi2-finetuned-checkpoint"
    base_model_kwargs = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }

    with st.status("Loading model from Hugging Face...", expanded=True) as progress:
        # Fetch the checkpoint files from the Hub repository
        st.write("Downloading checkpoint files...")
        adapter_dir = snapshot_download(repo_id=adapter_repo)

        # Base model first; the adapter is layered on top of it afterwards
        st.write("Loading base model...")
        phi2_base = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2", **base_model_kwargs
        )

        st.write("Loading LoRA adapter weights...")
        finetuned_model = PeftModel.from_pretrained(
            phi2_base, adapter_dir, device_map="auto"
        )

        st.write("Loading tokenizer...")
        checkpoint_tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
        # The end-of-sequence token doubles as the padding token
        checkpoint_tokenizer.pad_token = checkpoint_tokenizer.eos_token

        st.write("Creating text generation pipeline...")
        text_generator = pipeline(
            task="text-generation",
            model=finetuned_model,
            tokenizer=checkpoint_tokenizer,
        )

        # Collapse the status box and mark it as finished
        progress.update(
            label="Model loaded successfully!", state="complete", expanded=False
        )

    return text_generator


# Obtain the generation pipeline; load_model is cached, so this is cheap on reruns
pipe = load_model()

# Input form: widgets are attached through the form object itself
generation_form = st.form("generation_form")

# Free-text prompt entered by the user
prompt = generation_form.text_area(
    "Enter your prompt:",
    height=150,
    placeholder="Example: What is the square root of 81?",
)

# Length limit, selectable from 32 to 1024 in steps of 32 (default 256)
max_length = generation_form.slider(
    "Maximum length:",
    min_value=32,
    max_value=1024,
    value=256,
    step=32,
    help="Maximum number of tokens to generate",
)

# Submit button of the form
generate_button = generation_form.form_submit_button("Generate")

# Generate text when the form has been submitted
if generate_button:
    if not prompt.strip():
        # An empty or whitespace-only prompt gives the model nothing to
        # continue; tell the user instead of silently doing nothing.
        st.warning("Please enter a prompt before generating.")
    else:
        with st.spinner("Generating..."):
            # Server-side trace of the incoming prompt
            print("Prompt recieved = ", prompt)
            # The slider is documented as "Maximum number of tokens to
            # generate", so it is passed as max_new_tokens. max_length would
            # also count the prompt tokens, leaving little or no room for
            # output when the prompt is long.
            generated_text = pipe(
                prompt,
                max_new_tokens=max_length,
                do_sample=True,
                num_return_sequences=1,
            )[0]["generated_text"]

        # Display the result
        st.subheader("Generated Output:")
        st.write(generated_text)

        # Plain-text copy of the output that can be selected and copied
        st.text_area("Copy output:", value=generated_text, height=150)

# Sidebar with model information (text assembled from explicit pieces so the
# whitespace-only separator lines are visible in the source)
about_model_text = (
    "\n"
    "    This is a fine-tuned version of Microsoft's Phi-2 model.\n"
    "    \n"
    "    The model was fine-tuned using QLoRA techniques to improve its performance on specific tasks.\n"
    "    \n"
    "    Original model: [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)\n"
    "    \n"
    "    Fine-tuned model: [Adityak204/phi2-finetuned-checkpoint](https://huggingface.co/Adityak204/phi2-finetuned-checkpoint)\n"
    "    "
)
st.sidebar.header("About This Model")
st.sidebar.write(about_model_text)

# Footer: a horizontal rule followed by the attribution line
for footer_line in (
    "---",
    "Powered by a fine-tuned version of Microsoft's Phi-2 model",
):
    st.markdown(footer_line)