import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
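# NOTE: this app assumes `bitsandbytes` (required by BitsAndBytesConfig for
# 4-bit quantization) and `accelerate` (required for device_map="auto") are
# installed alongside streamlit, torch, and transformers.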
# Set page configuration
st.set_page_config(
    page_title="Apertus-8B Chat",
    page_icon="🤖",
    layout="wide",
)
# Add a title to the app
st.title("🤖 Chat with Apertus-8B-Instruct")
st.caption("A Streamlit app running swiss-ai/Apertus-8B-Instruct-2509")
# --- MODEL LOADING ---
@st.cache_resource
def load_model():
    """Loads the model and tokenizer with 4-bit quantization."""
    model_id = "swiss-ai/Apertus-8B-Instruct-2509"

    # Configure 4-bit NF4 quantization to reduce memory usage
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",  # Automatically map model layers to available hardware (GPU/CPU)
    )
    return tokenizer, model
# Load the model and display a spinner while doing so
with st.spinner("Loading Apertus-8B model... This might take a moment."):
    tokenizer, model = load_model()
# --- CHAT INTERFACE ---
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Accept user input
if prompt := st.chat_input("What would you like to ask?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
    # --- GENERATION ---
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Build the model input from the full chat history using the
            # tokenizer's chat template, so the instruct model sees properly
            # formatted conversation turns instead of a raw prompt string
            inputs = tokenizer.apply_chat_template(
                st.session_state.messages,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            ).to(model.device)

            # Generate a response
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
            )

            # Decode only the newly generated tokens; slicing off the prompt
            # tokens is more robust than stripping the prompt text from the
            # decoded output, which fails whenever the model rephrases it
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True,
            ).strip()
            st.markdown(response)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
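# To try the app locally (assuming a CUDA GPU with roughly 5-6 GB of free
# memory for the 4-bit quantized 8B weights):
#   streamlit run app.py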