import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
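
# NOTE: assumes streamlit, torch, transformers, accelerate, and bitsandbytes
# are installed; 4-bit loading via bitsandbytes generally requires a CUDA GPU.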

st.set_page_config(
    page_title="Apertus-8B Chat",
    page_icon="🤖",
    layout="wide",
)

st.title("🤖 Chat with Apertus-8B-Instruct")
st.caption("A Streamlit app running swiss-ai/Apertus-8B-Instruct-2509")


@st.cache_resource
def load_model():
    """Loads the model and tokenizer with 4-bit quantization."""
    model_id = "swiss-ai/Apertus-8B-Instruct-2509"
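
    # NF4 quantization stores weights in ~0.5 bytes each, so the 8B model
    # should fit in roughly 5-6 GB of VRAM instead of ~16 GB in bf16
    # (rough estimates, not measured figures).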
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    return tokenizer, model
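

# st.cache_resource keeps one model instance per server process; without it,
# Streamlit would reload the weights on every rerun of this script.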

with st.spinner("Loading Apertus-8B model... This might take a moment."):
    tokenizer, model = load_model()
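
# st.session_state persists across reruns, so the conversation survives each
# interaction even though Streamlit re-executes the whole script.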
if "messages" not in st.session_state:
    st.session_state.messages = []
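
# Re-render the transcript on every rerun so earlier turns stay visible.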
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What would you like to ask?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
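
    # The sampling settings below (temperature 0.7, top-k 50, top-p 0.95) are
    # common chat defaults, not values tuned for this model.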
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Format the full conversation with the model's chat template so
            # the instruct model sees proper turn markers and prior context,
            # not just the raw text of the latest user message.
            inputs = tokenizer.apply_chat_template(
                st.session_state.messages,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
            )

            # Decode only the newly generated tokens; slicing by input length
            # is more robust than stripping the prompt out of the decoded
            # string with a text replace.
            prompt_length = inputs["input_ids"].shape[-1]
            response = tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})
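
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py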