# Study Buddy Chatbot — Streamlit app (scraped from a Hugging Face Space;
# the "Spaces: Sleeping" lines were page residue, not part of the program).
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Page header.
st.title("π Study Buddy Chatbot")
st.write("Ask a question or type a topic, and I'll help you learn interactively!")

# The chat transcript must survive Streamlit's rerun-on-interaction model,
# so it lives in session_state rather than a plain module variable.
if "conversation" not in st.session_state:
    st.session_state.conversation = []
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the Zephyr-7B chat model and its tokenizer.

    Decorated with ``st.cache_resource`` so the multi-GB model is loaded once
    per server process and shared across sessions, instead of re-loaded for
    every new browser session (the original only cached per-session via
    session_state).

    Returns:
        tuple: ``(tokenizer, model)`` ready for text generation.
    """
    MODEL_NAME = "HuggingFaceH4/zephyr-7b-alpha"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,   # half precision: ~halves weight memory
        device_map="auto",           # let accelerate place layers on GPU/CPU
        low_cpu_mem_usage=True,      # stream weights rather than a full CPU copy
    )
    return tokenizer, model
# Populate session_state with the model exactly once per session; the
# spinner keeps the UI responsive during the (slow) first load.
model_missing = "tokenizer" not in st.session_state or "model" not in st.session_state
if model_missing:
    with st.spinner("Loading AI model (this may take a minute)..."):
        st.session_state.tokenizer, st.session_state.model = load_model()
def get_response(user_input):
    """Generate a coaching reply to *user_input*, conditioned on recent chat.

    Reads the tokenizer, model, and conversation history from
    ``st.session_state``.

    Args:
        user_input: The student's latest message (plain text).

    Returns:
        str: The decoded model reply, with the echoed prompt tokens stripped.
    """
    tokenizer = st.session_state.tokenizer
    model = st.session_state.model
    # Only the most recent exchanges go into the prompt, to bound its length.
    history = "\n".join(st.session_state.conversation[-6:])  # Last 6 exchanges
    prompt = (
        f"You are a knowledgeable study coach. Engage the student in conversation. "
        f"Ask open-ended questions to deepen understanding. Provide feedback and encourage explanations.\n\n"
        f"Previous conversation:\n{history}\n\n"
        f"Student: {user_input}\n"
        f"Coach: "
    )
    # Keep the full tokenizer output: generate() needs the attention mask too.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            inputs.input_ids,
            # fix: attention_mask was omitted, which makes generate() guess
            # padding positions and emit warnings.
            attention_mask=inputs.attention_mask,
            max_new_tokens=250,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
            # fix: causal LMs without a pad token (like this one) warn or fail
            # without an explicit pad_token_id.
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(
        output[0, inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return response
# ---- User interface ----
user_input = st.text_input("Type your question or topic:")
if user_input:
    with st.spinner("Thinking..."):
        response = get_response(user_input)
    # Record both sides of the exchange so the next prompt has context.
    st.session_state.conversation.append(f"Student: {user_input}")
    st.session_state.conversation.append(f"Coach: {response}")
# Show the last 10 messages, labeling each by its recorded role.
st.subheader("Conversation History")
for message in st.session_state.conversation[-10:]:
    # fix: role was inferred from list parity, which breaks if the slice ever
    # starts on a Coach message; detect the speaker from the stored prefix.
    # fix: str.replace('Student: ', '') stripped the tag ANYWHERE in the text,
    # corrupting replies that mention "Student: " — slice only the prefix.
    if message.startswith("Student: "):
        st.markdown(f"**You**: {message[len('Student: '):]}")
    elif message.startswith("Coach: "):
        st.markdown(f"**Coach**: {message[len('Coach: '):]}")
    else:
        # Untagged entry (shouldn't happen) — render as-is rather than mangle.
        st.markdown(message)
# Reset the transcript and refresh the page so the cleared state renders.
if st.button("Clear Conversation"):
    st.session_state.conversation = []
    # fix: st.experimental_rerun() was removed in Streamlit 1.27; prefer the
    # stable st.rerun(), falling back for older Streamlit versions.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()