Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from llama_index.core import ( | |
| VectorStoreIndex, | |
| Settings, | |
| StorageContext, | |
| load_index_from_storage, | |
| ) | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.llms.groq import Groq | |
| import pandas as pd | |
| from llama_index.core import Document | |
# Directory where the vector index is persisted between runs.
PERSIST_DIR = "./storage"
# Local path of the sentence-transformers embedding model.
EMBED_MODEL = "./all-MiniLM-L6-v2"
# Groq-hosted chat model identifier.
LLM_MODEL = "llama3-8b-8192"
# CSV catalogue of assessments that is indexed at start-up.
CSV_FILE_PATH = "shl_assessments.csv"
# Use .get() so that a missing secret falls through to the environment
# variable; subscripting st.secrets raises KeyError before `or` is evaluated.
GROQ_API_KEY = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")
def load_data_from_csv(csv_path):
    """Load assessment rows from a CSV file.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A list with one dict per CSV row, keyed by column name.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
        ValueError: If the file cannot be parsed as CSV or lacks one of the
            required columns.
        Exception: For any other failure while loading the data.
    """
    required_columns = [
        "Assessment Name",
        "URL",
        "Remote Testing Support",
        "Adaptive/IRT Support",
        "Duration (min)",
        "Test Type",
    ]
    try:
        df = pd.read_csv(csv_path)
        if not all(col in df.columns for col in required_columns):
            raise ValueError(
                f"CSV file must contain columns: {', '.join(required_columns)}"
            )
        return df.to_dict(orient="records")
    # Each handler re-raises the same exception type with a friendlier
    # message, chaining the original so the root cause stays in the traceback.
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Error: CSV file not found at {csv_path}"
        ) from err
    except ValueError as err:
        raise ValueError(f"Error reading CSV: {err}") from err
    except Exception as err:
        raise Exception(
            f"An unexpected error occurred while loading CSV data: {err}"
        ) from err
def load_groq_llm():
    """Create the Groq LLM client used to answer chat queries.

    The API key is read from Streamlit secrets first and from the
    ``GROQ_API_KEY`` environment variable second.

    Returns:
        A ``Groq`` LLM configured with ``LLM_MODEL`` and temperature 0.1.

    Raises:
        ValueError: If no API key is found in either location.
    """
    try:
        api_key = st.secrets.get("GROQ_API_KEY")
    except (KeyError, FileNotFoundError):
        # No secrets store is available at all; fall back to the environment.
        api_key = None
    api_key = api_key or os.getenv("GROQ_API_KEY")
    if not api_key:
        # Fail here with a clear message instead of handing None to the client.
        raise ValueError(
            "GROQ_API_KEY not found in Streamlit secrets or environment variables."
        )
    return Groq(model=LLM_MODEL, api_key=api_key, temperature=0.1)
def load_embeddings():
    """Return the HuggingFace embedding model for "all-MiniLM-L6-v2"."""
    model_name = "all-MiniLM-L6-v2"
    embedder = HuggingFaceEmbedding(model_name=model_name)
    return embedder
def build_index(data):
    """Build the vector index from assessment records and persist it.

    Args:
        data: Iterable of dicts holding the keys "Assessment Name", "URL",
            "Remote Testing Support", "Adaptive/IRT Support",
            "Duration (min)" and "Test Type".

    Returns:
        The ``VectorStoreIndex`` built from ``data``; it is also written to
        ``PERSIST_DIR``.
    """
    # Configure the embedding model and the LLM before indexing. The original
    # body returned an embedding object here, which left everything below
    # unreachable and the index was never built. load_embeddings() is the
    # same loader load_chat_engine() uses, so build and query embeddings match.
    Settings.embed_model = load_embeddings()
    Settings.llm = load_groq_llm()
    documents = [
        Document(
            text=(
                f"Name: {item['Assessment Name']}, "
                f"URL: {item['URL']}, "
                f"Remote Testing: {item['Remote Testing Support']}, "
                f"Adaptive/IRT: {item['Adaptive/IRT Support']}, "
                f"Duration: {item['Duration (min)']}, "
                f"Type: {item['Test Type']}"
            )
        )
        for item in data
    ]
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index
def load_chat_engine():
    """Return a context-mode chat engine over the persisted index.

    Returns:
        The chat engine, or ``None`` when ``PERSIST_DIR`` does not exist.
    """
    index_is_persisted = os.path.exists(PERSIST_DIR)
    if index_is_persisted:
        # Global settings must be in place before the index is deserialised.
        Settings.embed_model = load_embeddings()
        Settings.llm = load_groq_llm()
        persisted = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        restored_index = load_index_from_storage(persisted)
        return restored_index.as_chat_engine(chat_mode="context", verbose=True)
    return None
def reset_index():
    """Delete the persisted index and restore the initial chat state.

    Removes ``PERSIST_DIR`` (deletion errors are ignored), resets the message
    history to the greeting, clears the ``index_built`` flag and drops any
    cached chat engine. Failures are reported through ``st.error``.

    Returns:
        Always ``None``.
    """
    greeting = {
        "role": "assistant",
        "content": "Hello! I'm your SHL assessment assistant. How can I help you?",
    }
    try:
        shutil.rmtree(PERSIST_DIR, ignore_errors=True)
        st.success("Knowledge index reset successfully!")
        st.session_state.messages = [greeting]
        st.session_state["index_built"] = False
        if 'chat_engine' in st.session_state:
            del st.session_state['chat_engine']
    except Exception as exc:
        st.error(f"Error resetting index: {exc}")
    return None
# Icons and bullet are written as escape sequences so they survive any
# re-encoding of this file (they had been corrupted into mojibake).
_ASSISTANT_ICON = "\U0001F916"  # robot face
_USER_ICON = "\U0001F464"  # bust in silhouette
_CHAT_READY_ICON = "\U0001F4AC"  # speech balloon
_BULLET = "\u2022"

_GREETING = "Hello! I'm your SHL assessment assistant. How can I help you?"


def _inject_theme_css():
    """Inject the dark-theme CSS overrides into the page."""
    st.markdown("""
    <style>
    :root {
        --primary: #6eb5ff;
        --background: #000000;
        --card: #f0f2f6;
        --text: #ffffff;
    }
    .stApp {
        background-color: var(--background) !important;
        color: var(--text) !important;
    }
    .stMarkdown, .stTextInput, .stChatMessage, .stChatInputContainer, .css-10trblm, .css-1cpxqw2 {
        color: var(--text) !important;
    }
    </style>
    """, unsafe_allow_html=True)


def _render_chat_text(icon, text):
    """Render one chat line in white, escaping HTML special characters."""
    import html  # function-scope import keeps this block self-contained

    # The text comes from the user or the LLM and is emitted with
    # unsafe_allow_html=True, so it must not be able to inject markup.
    safe_text = html.escape(str(text), quote=False)
    st.markdown(
        f"<span style='color: white;'>{icon} {safe_text}</span>",
        unsafe_allow_html=True,
    )


def _build_formatted_prompt(prompt):
    """Append the answer-format instructions to the user's question."""
    return f"""
{prompt}
Please provide a list of all matching SHL assessments (minimum 1, maximum 10).
For each matching assessment, follow this exact format:
{_BULLET} Assessment Name: [Name]
URL: [URL]
Remote Testing Support: [Yes/No]
Adaptive/IRT Support: [Yes/No]
Duration: [Duration in minutes]
Test Type: [Test Type]
If there are no matches, clearly state that. Respond in a clean, readable bullet-point format.Do not use any "+" signs. Do not return JSON or markdown tables. Do not bold anything.
"""


def _ensure_index_built():
    """Build the index and chat engine once per session, reporting errors."""
    if st.session_state["index_built"]:
        return
    try:
        with st.spinner("Loading data and building index..."):
            assessment_data = load_data_from_csv(CSV_FILE_PATH)
            if assessment_data:
                build_index(assessment_data)
                st.session_state['chat_engine'] = load_chat_engine()
                st.session_state["index_built"] = True
            else:
                st.error("Failed to load assessment data. Please check the CSV file.")
    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the app.
        st.error(f"Error initializing application: {e}")


def main():
    """Run the Streamlit chat application.

    Configures the page and theme, initialises session state, builds the
    index on the first run of a session, replays the stored chat history and
    answers a newly submitted prompt through the chat engine.
    """
    st.set_page_config(
        page_title="SHL Assessment Chatbot",
        layout="wide",
        initial_sidebar_state="collapsed",
    )
    _inject_theme_css()

    load_dotenv()
    os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
    os.environ["TORCH_DISABLE_STREAMLIT_WATCHER"] = "1"
    os.environ["LLAMA_INDEX_DISABLE_OPENAI"] = "1"

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": _GREETING}]
    if "index_built" not in st.session_state:
        st.session_state["index_built"] = False

    _ensure_index_built()

    # --- Chat Interface ---
    chat_engine = st.session_state.get('chat_engine')
    if not chat_engine:
        # Initialisation failed; the error has already been displayed.
        return

    for msg in st.session_state.messages:
        icon = _ASSISTANT_ICON if msg["role"] == "assistant" else _USER_ICON
        with st.chat_message(msg["role"]):
            _render_chat_text(icon, msg["content"])

    prompt = st.chat_input("Ask me about SHL assessments...")
    if not prompt:
        st.info(f"{_CHAT_READY_ICON} Chat is ready! Ask me anything about SHL assessments.")
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        _render_chat_text(_USER_ICON, prompt)
    with st.chat_message("assistant"):
        try:
            response = chat_engine.chat(_build_formatted_prompt(prompt))
            _render_chat_text(_ASSISTANT_ICON, response.response)
            # Store the raw text; escaping happens only at render time.
            st.session_state.messages.append(
                {"role": "assistant", "content": response.response}
            )
        except Exception as e:
            st.error(f"An error occurred during chat: {e}")


if __name__ == "__main__":
    main()