Spaces:
Sleeping
Sleeping
| #!/usr/bin/env -S poetry run python | |
| import os | |
| import json | |
| import streamlit as st | |
| from openai import OpenAI | |
| from dotenv import load_dotenv | |
# Load environment variables from the .env file.
load_dotenv()

# Fail fast with a clear message when the API key is missing, instead of
# failing later on the first request.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("The OPENAI_API_KEY environment variable is not set.")

# Hand the validated key to the client explicitly so the value that was
# checked above is the value that is actually used.
client = OpenAI(api_key=api_key)
def load_user_data(user_id):
    """Load the stored data of one user from disk.

    The data is read from ``data/user_data/user_data_<user_id>.json`` under
    the current working directory.

    Args:
        user_id: Identifier interpolated into the file name.
            NOTE(review): the value is used unchecked; ``main`` passes raw
            text-input content here -- consider validating it.

    Returns:
        The decoded JSON object as a dict. An empty dict is returned when the
        file is missing, cannot be read or decoded, or does not contain a
        JSON object.
    """
    file_path = os.path.join(os.getcwd(), "data", "user_data", f"user_data_{user_id}.json")
    try:
        # Open directly instead of checking for existence first: the file may
        # disappear between the check and the open.
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        # A user without stored data is a normal situation, not an error.
        return {}
    except json.JSONDecodeError:
        st.write("Error decoding JSON.")
        return {}
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return {}
    if not isinstance(data, dict):
        # Callers use ``.get`` on the result, so anything but an object is unusable.
        st.write("Error decoding JSON.")
        return {}
    return data
def save_user_data(user_id, data):
    """Persist the data of one user as JSON.

    The data is written to ``data/user_data/user_data_<user_id>.json``
    (relative to the current working directory); missing directories are
    created.

    Args:
        user_id: Identifier interpolated into the file name.
        data: JSON-serialisable object to store.

    Raises:
        TypeError: If ``data`` is not JSON-serialisable. The previously stored
            file, if any, is left untouched in that case.
        OSError: If the file cannot be written.
    """
    file_path = os.path.join("data", "user_data", f"user_data_{user_id}.json")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    # Write to a sibling temp file first and swap it in afterwards, so a
    # failed or interrupted dump can never truncate the existing data file.
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(data, file)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, file_path)
def parseBill(data):
    """Convert a raw bill dict into the structure stored for a user.

    Args:
        data: Decoded bill JSON. The keys read are ``billDate``, ``billNo``,
            ``amountDue``, ``extraCharge``, ``taxItem`` (list of dicts with
            ``cat``/``amt``) and ``subscribers`` (list of dicts with
            ``logicalResource`` and ``billSummaryItem``, whose entries carry
            ``cat``/``name``/``amt``).

    Returns:
        A dict with the bill header fields, the bill-level costs, one row per
        subscriber cost item, the list of entity names ("Entities": tax
        categories plus item categories and item names) and the list of item
        names ("Costuri"). Both lists are sorted so the output is stable.

    Summary items missing any of ``cat``, ``name`` or ``amt`` are skipped.
    """
    tax_items = data.get("taxItem") or []
    subscribers = data.get("subscribers") or []

    total_bill_costs = [
        {"categorie": tax.get("cat"), "amount": tax.get("amt")} for tax in tax_items
    ]

    categories = {tax.get("cat") for tax in tax_items}
    names = set()
    subscriber_costs = []

    for sub in subscribers:
        phone_number = sub.get("logicalResource")
        for item in sub.get("billSummaryItem") or []:
            # Read all three fields up front: an item lacking any of them is
            # dropped as a whole instead of raising later on a missing "amt".
            try:
                category, cost_name, amount = item["cat"], item["name"], item["amt"]
            except KeyError:
                continue
            categories.update((category, cost_name))
            names.add(cost_name)
            subscriber_costs.append({
                "Numar telefon": phone_number,
                "Categorie cost": category,
                "Cost": cost_name,
                "Valoare": amount,
            })

    # A missing/null category must not end up in the entity list: consumers
    # split every entity into words.
    categories.discard(None)
    names.discard(None)

    return {
        "Data factura": data.get("billDate"),
        "Serie numar factura": data.get("billNo"),
        "Total de plata": data.get("amountDue"),
        "Costuri suplimentare": data.get("extraCharge"),
        "Total plata factura": total_bill_costs,
        "Costuri utilizatori": subscriber_costs,
        "Entities": sorted(categories, key=str),
        "Costuri": sorted(names, key=str),
    }
def check_related_keys(question, user_id):
    """Return the stored entities that the question appears to mention.

    The "Entities" lists of all bills stored for ``user_id`` are merged; an
    entity is selected when at least one of its whitespace-separated words
    occurs (case-insensitively, as a substring) in ``question``.

    Args:
        question: The user's question text.
        user_id: Identifier of the user whose bills are inspected.

    Returns:
        A sorted list of matching entity names (possibly empty).
    """
    user_data = load_user_data(user_id)
    question_lower = question.lower()  # computed once, not once per word

    entities = set()
    for bill in user_data.get("bills") or []:
        # Ignore non-string entries (e.g. a null category saved earlier):
        # they cannot be split into words.
        entities.update(
            entity for entity in bill.get("Entities") or [] if isinstance(entity, str)
        )

    return sorted(
        entity
        for entity in entities
        if any(word.lower() in question_lower for word in entity.split())
    )
def process_query(query, user_id, model_name, max_input_length=7550):
    """Build the prompt that carries the user's bill data plus the question.

    Args:
        query: The user's question or statement.
        user_id: Identifier of the user whose stored bills are embedded.
        model_name: Selected model; only used to display which model runs.
        max_input_length: Maximum allowed prompt size in characters.

    Returns:
        The prompt string, or ``None`` (after showing a warning) when the
        prompt exceeds ``max_input_length``.
    """
    user_data = load_user_data(user_id)
    bill_info = user_data.get("bills", [])
    related_keys = check_related_keys(query, user_id)

    # The two prompt variants differ only in this clause.
    if related_keys:
        focus = f"dar mai ales cu info legate de: {', '.join(related_keys)}."
    else:
        focus = "mai ales cu info din factura."

    # The explicit space after {bill_info} keeps the data from running
    # straight into the following word.
    context = (
        f"Citeste informatiile despre costrurile in lei facturate din json: {bill_info} "
        f"si raspunde la intrebarea sau afirmatia: '{query}' {focus} "
        "Pentru orice alt subiect raspunde ca nu poti oferi decat informatii despre facturi."
    )

    st.write(f"Context size: {len(context)} characters")
    if len(context) > max_input_length:
        st.warning("Prea multe caractere în context, solicitarea nu va fi trimisă.")
        return None

    # Tell the user which model is about to be used (known models only).
    model_banners = {
        "gpt-4o-mini": "Running model GPT-4o-mini",
        "gpt-4o": "Running model GPT-4o",
    }
    if model_name in model_banners:
        st.write(model_banners[model_name])
    return context
from datetime import datetime


def log_conversation(user_id, user_query, assistant_response, tokens, cost):
    """Append one chat exchange to ``logs/conversation_logs.json``.

    The log file holds a single JSON array; it is read, extended with the new
    entry and rewritten. The ``logs`` directory is created when missing.

    Args:
        user_id: Identifier of the user who asked.
        user_query: The question as typed by the user.
        assistant_response: The model's answer text.
        tokens: Token usage information stored as-is.
        cost: Estimated cost stored as-is.
    """
    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "user_id": user_id,
        "user_query": user_query,
        "assistant_response": assistant_response,
        "tokens": tokens,
        "cost": cost,
    }
    log_file_path = os.path.join("logs", "conversation_logs.json")
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    logs = []
    if os.path.exists(log_file_path):
        try:
            with open(log_file_path, "r", encoding="utf-8") as log_file:
                logs = json.load(log_file)
        except json.JSONDecodeError:
            logs = None
        if not isinstance(logs, list):
            # An empty, truncated or otherwise unusable log must not break the
            # chat: keep the old content aside and start a fresh log.
            os.replace(log_file_path, f"{log_file_path}.corrupt")
            logs = []

    logs.append(log_entry)
    with open(log_file_path, "w", encoding="utf-8") as log_file:
        json.dump(logs, log_file, indent=4)
# Per-token prices used only for the on-screen cost estimate
# ($0.03 and $0.015 per 1,000 tokens respectively).
# NOTE(review): figures carried over unchanged -- verify against current pricing.
_COST_PER_TOKEN = {
    "gpt-4o": 0.03 / 1000,
    "gpt-4o-mini": 0.015 / 1000,
}


def _reset_chat():
    """Start a fresh conversation with only the assistant greeting."""
    st.session_state["messages"] = [
        {"role": "assistant", "content": "Cu ce te pot ajuta?"}
    ]
    st.session_state.context_prompt_added = False


def _bill_identity(bill):
    """Return the value used to spot duplicates: bill number, else bill date."""
    return bill.get("Serie numar factura") or bill.get("Data factura")


def _handle_bill_upload(uploaded_file, user_id):
    """Parse an uploaded bill and store it for ``user_id`` unless already present.

    Returns the updated user data dict when the bill was saved, else ``None``.
    """
    try:
        # getvalue() is independent of the stream position, so a rerun that
        # sees the same upload again still reads the whole file.
        bill_data = json.loads(uploaded_file.getvalue())
    except ValueError:  # invalid JSON or undecodable bytes
        st.error("Fișierul încărcat nu este un JSON valid.")
        return None
    if not isinstance(bill_data, dict):
        st.error("Fișierul încărcat nu este un JSON valid.")
        return None

    parsed_bill = parseBill(bill_data)
    existing_data = load_user_data(user_id)
    known_bills = [_bill_identity(bill) for bill in existing_data.get("bills", [])]
    if _bill_identity(parsed_bill) in known_bills:
        st.warning("Factură existentă.")
        return None

    existing_data.setdefault("bills", []).append(parsed_bill)
    save_user_data(user_id, existing_data)
    st.success("Factura a fost încărcată și salvată cu succes!")
    return existing_data


def main():
    """Render the Streamlit page: user selection, bill upload and the chat.

    State kept in ``st.session_state``: ``user_id``, ``user_data``,
    ``messages`` (the conversation sent to the model) and
    ``context_prompt_added`` (whether the bill context was already injected).
    """
    st.title("Bill info LLM Agent (OpenAI)")
    st.image("https://miro.medium.com/v2/resize:fit:100/format:webp/1*NfE0G4nEj4xX7Z_8dSx83g.png")

    # Sidebar: model choice and the phone number acting as user id.
    model_name = st.sidebar.selectbox("Choose OpenAI Model", ["gpt-4o-mini", "gpt-4o"])
    user_id = st.sidebar.text_input("Introdu numărul de telefon:", placeholder="Incearca 0724077190")

    # Initialise state once; it must survive reruns, otherwise the stored
    # bills disappear after the first interaction.
    if "user_id" not in st.session_state:
        st.session_state.user_id = None
    if "user_data" not in st.session_state:
        st.session_state.user_data = None
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Cu ce te pot ajuta?"}
        ]
    if "context_prompt_added" not in st.session_state:
        st.session_state.context_prompt_added = False

    if user_id and user_id != st.session_state.user_id:
        data = load_user_data(user_id)
        st.session_state.user_id = user_id
        st.session_state.user_data = data or None
        # A different user must not inherit the previous user's conversation,
        # which already contains that user's bill data.
        _reset_chat()
        if data:
            st.success("Utilizator găsit!")
            st.write(f"Numar telefon: {st.session_state.user_id}")
        else:
            st.warning("Nu am găsit date pentru acest ID.")
            st.warning("Încărcați o factură json.")

    # Compact overview of the bills already stored for this user.
    if st.session_state.user_data:
        st.write("Facturi existente (extras):")
        for bill in st.session_state.user_data.get("bills", []):
            st.write({
                "Data factura": bill.get("Data factura"),
                "Serie numar factura": bill.get("Serie numar factura"),
                "Total de plata": bill.get("Total de plata"),
                "Costuri suplimentare": bill.get("Costuri suplimentare"),
            })

    # The uploader is always offered so further bills can be added to a user
    # who already has data; processing still requires a user id.
    uploaded_file = st.file_uploader("Incarca factura", type="json")
    if uploaded_file and st.session_state.user_id:
        updated_data = _handle_bill_upload(uploaded_file, st.session_state.user_id)
        if updated_data is not None:
            st.session_state.user_data = updated_data

    st.write("---")
    st.subheader("Chat")
    for msg in st.session_state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    prompt = st.chat_input("Introduceți întrebarea aici:")
    if not prompt:
        return
    if not st.session_state.user_id:
        st.error("Trebuie să introduci un număr de telefon valid sau să încarci date.")
        return

    # The bill context is injected once, with the first question; later
    # questions are appended verbatim.
    if not st.session_state.context_prompt_added:
        final_prompt = process_query(prompt, st.session_state["user_id"], model_name)
        if final_prompt is None:
            st.stop()
        st.session_state["messages"].append({"role": "user", "content": final_prompt})
        st.session_state.context_prompt_added = True
    else:
        st.session_state["messages"].append({"role": "user", "content": prompt})

    # Display the latest user message in the chat.
    st.chat_message("user").write(st.session_state["messages"][-1]["content"])

    related_keys = check_related_keys(prompt, st.session_state["user_id"])
    st.write("Focus pe entitatile:", related_keys)

    # Send the entire conversation to the selected model.
    completion = client.chat.completions.create(
        model=model_name,
        messages=st.session_state["messages"],
    )
    # The message content may be absent; treat that as an empty answer.
    response_text = (completion.choices[0].message.content or "").strip()
    st.session_state["messages"].append({"role": "assistant", "content": response_text})
    st.chat_message("assistant").write(response_text)

    usage = getattr(completion, "usage", None)
    if usage is None:
        return

    prompt_tokens = usage.prompt_tokens
    completion_tokens = usage.completion_tokens
    total_tokens = usage.total_tokens
    st.write("Prompt tokens:", prompt_tokens)
    st.write("Completion tokens:", completion_tokens)
    st.write("Total tokens:", total_tokens)

    # Estimate the cost of this exchange; unknown models get no estimate.
    cost_per_token = _COST_PER_TOKEN.get(model_name)
    estimated_cost = None if cost_per_token is None else total_tokens * cost_per_token
    st.write("Estimated cost:", estimated_cost)

    log_conversation(
        st.session_state["user_id"],
        prompt,
        response_text,
        {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
        estimated_cost,
    )


if __name__ == "__main__":
    main()