studybot / app.py
amaltese's picture
Update app.py
fbb08a4 verified
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Page header. The emoji is the "books" glyph (U+1F4DA); it is written as an
# escape because the previous literal contained its UTF-8 bytes decoded with
# the wrong codec, which rendered as garbage in the title.
st.title("\U0001F4DA Study Buddy Chatbot")
st.write("Ask a question or type a topic, and I'll help you learn interactively!")

# Initialize session state for conversation history.
# The list holds "Student: ..." / "Coach: ..." strings appended in pairs by
# the UI code further down.
if "conversation" not in st.session_state:
    st.session_state.conversation = []
# Load the tokenizer and model; the function is wrapped in st.cache_resource
# so the loaded objects can be cached rather than rebuilt on each call.
@st.cache_resource
def load_model():
    """Return the ``(tokenizer, model)`` pair for the Zephyr 7B alpha checkpoint.

    The model is requested in float16 with ``device_map="auto"`` and
    ``low_cpu_mem_usage=True``.
    """
    checkpoint = "HuggingFaceH4/zephyr-7b-alpha"
    load_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint, **load_options),
    )
# Only load the model when needed: skip the load if this session already holds
# both the tokenizer and the model.
has_tokenizer = "tokenizer" in st.session_state
has_model = "model" in st.session_state
if not (has_tokenizer and has_model):
    with st.spinner("Loading AI model (this may take a minute)..."):
        loaded_tokenizer, loaded_model = load_model()
        st.session_state.tokenizer = loaded_tokenizer
        st.session_state.model = loaded_model
def get_response(user_input):
    """Generate the coach's reply to ``user_input``.

    Builds a plain-text prompt from a fixed coaching instruction, the last six
    entries of ``st.session_state.conversation`` and the new student message,
    samples up to 250 new tokens from the model stored in session state, and
    returns only the newly generated text.

    Args:
        user_input: The student's question or topic.

    Returns:
        The coach's reply with surrounding whitespace removed, cut off at the
        point where the model starts writing a further "Student:" turn.
    """
    # Get tokenizer and model from session state
    tokenizer = st.session_state.tokenizer
    model = st.session_state.model

    # Format conversation history for context. Each entry is a single message,
    # so the last 6 entries are the last 3 student/coach exchanges.
    history = "\n".join(st.session_state.conversation[-6:])
    prompt = (
        f"You are a knowledgeable study coach. Engage the student in conversation. "
        f"Ask open-ended questions to deepen understanding. Provide feedback and encourage explanations.\n\n"
        f"Previous conversation:\n{history}\n\n"
        f"Student: {user_input}\n"
        f"Coach: "
    )

    # Keep the whole encoding (input ids AND attention mask) and hand both to
    # generate(); passing input ids alone forces generate() to guess the mask.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=250,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
        )

    # Decode only the tokens generated after the prompt.
    response = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)

    # The prompt is a plain transcript, so the model can keep going and invent
    # the student's next line. Drop anything from such a line onwards so it is
    # neither shown nor stored in the history, and trim stray whitespace.
    return response.split("\nStudent:", 1)[0].strip()
# User interface.
# The question is collected through a form with clear_on_submit=True so that
# it is handled exactly once per submission. A bare text_input keeps its value
# across reruns, which made every rerun (including the one triggered by the
# "Clear Conversation" button) answer the same question again.
with st.form("question_form", clear_on_submit=True):
    user_input = st.text_input("Type your question or topic:")
    submitted = st.form_submit_button("Send")

if submitted and user_input:
    with st.spinner("Thinking..."):
        response = get_response(user_input)
    # Add to conversation history
    st.session_state.conversation.append(f"Student: {user_input}")
    st.session_state.conversation.append(f"Coach: {response}")

# Display the last 10 stored messages. The speaker is taken from the role tag
# each entry was stored with, and only that leading tag is removed; the same
# text appearing inside a message is left intact.
st.subheader("Conversation History")
for message in st.session_state.conversation[-10:]:
    if message.startswith("Student: "):
        st.markdown(f"**You**: {message[len('Student: '):]}")
    elif message.startswith("Coach: "):
        st.markdown(f"**Coach**: {message[len('Coach: '):]}")
    else:
        # Entry without a known role tag: show it unchanged.
        st.markdown(message)

# Add a clear conversation button
if st.button("Clear Conversation"):
    st.session_state.conversation = []
    # st.rerun() replaces the deprecated st.experimental_rerun(); fall back to
    # the old name on Streamlit versions that do not have st.rerun yet.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()