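"""Streamlit app for LLM-assisted data generation and data labeling.

The app sends prompts built with a LangChain PromptTemplate to a model served
through the Hugging Face Inference API (via the OpenAI-compatible client) and
streams the results back into the page.
"""
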
import streamlit as st
import pandas as pd
import os
from datetime import datetime
import random
from pathlib import Path
from openai import OpenAI
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
# Load environment variables from a .env file, if present
load_dotenv()

# Initialize the client against the Hugging Face Inference API (OpenAI-compatible endpoint)
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get('TOKEN2')  # Hugging Face token
)
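# The access token is read from the TOKEN2 environment variable, which
# load_dotenv() can populate from a local .env file.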

# Alternative: use the OpenAI API directly instead of the Hugging Face endpoint
# client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
# Custom CSS for better appearance
st.markdown("""
    <style>
    .stButton > button {
        width: 100%;
        margin-bottom: 10px;
        background-color: #4CAF50;
        color: white;
        border: none;
        padding: 10px;
        border-radius: 5px;
    }
    .task-button {
        background-color: #2196F3 !important;
    }
    .stSelectbox {
        margin-bottom: 20px;
    }
    .output-container {
        padding: 20px;
        border-radius: 5px;
        border: 1px solid #ddd;
        margin: 10px 0;
    }
    .status-container {
        padding: 10px;
        border-radius: 5px;
        margin: 10px 0;
    }
    .sidebar-info {
        padding: 10px;
        background-color: #f0f2f6;
        border-radius: 5px;
        margin: 10px 0;
    }
    </style>
""", unsafe_allow_html=True)
# Create the data directory if it doesn't exist
if not os.path.exists('data'):
    os.makedirs('data')
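
# CSV helpers: read an uploaded file trying several encodings, and save/load
# DataFrames under the local data/ directory.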
def read_csv_with_encoding(file):
    """Try a series of common encodings until the CSV loads."""
    encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']
    for encoding in encodings:
        try:
            return pd.read_csv(file, encoding=encoding)
        except UnicodeDecodeError:
            continue
    raise ValueError("Failed to read file with any supported encoding")
def save_to_csv(data, filename):
    """Save records to data/<filename> and return them as a DataFrame."""
    df = pd.DataFrame(data)
    df.to_csv(f'data/{filename}', index=False)
    return df

def load_from_csv(filename):
    """Load data/<filename>; return an empty DataFrame if it cannot be read."""
    try:
        return pd.read_csv(f'data/{filename}')
    except Exception:
        return pd.DataFrame()
# Define reset function
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
# Main app title
st.title("🤖 LangChain-Based Data Interaction App")

# Sidebar settings
with st.sidebar:
    st.title("⚙️ Settings")

    selected_model = st.selectbox(
        "Select Model",
        ["meta-llama/Meta-Llama-3-8B-Instruct"],
        key='model_select'
    )

    temperature = st.slider(
        "Temperature",
        0.0, 1.0, 0.5,
        help="Controls randomness in generation"
    )

    st.button("🔄 Reset Conversation", on_click=reset_conversation)

    with st.container():
        st.markdown("""
            <div class="sidebar-info">
                <h4>Current Model: {}</h4>
                <p><em>Note: Generated content may be inaccurate or false.</em></p>
            </div>
        """.format(selected_model), unsafe_allow_html=True)
# Main content
col1, col2 = st.columns(2)
with col1:
    if st.button("📝 Data Generation", key="gen_button", help="Generate new data"):
        st.session_state.task_choice = "Data Generation"
with col2:
    if st.button("🏷️ Data Labeling", key="label_button", help="Label existing data"):
        st.session_state.task_choice = "Data Labeling"
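
# The chosen task persists in st.session_state, so the selected workflow stays
# rendered across the reruns Streamlit triggers on every widget interaction.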
if "task_choice" in st.session_state:
    if st.session_state.task_choice == "Data Generation":
        st.header("📝 Data Generation")

        classification_type = st.selectbox(
            "Classification Type",
            ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]
        )

        if classification_type == "Sentiment Analysis":
            labels = ["Positive", "Negative", "Neutral"]
        elif classification_type == "Binary Classification":
            col1, col2 = st.columns(2)
            with col1:
                label_1 = st.text_input("First class", "Positive")
            with col2:
                label_2 = st.text_input("Second class", "Negative")
            labels = [label_1, label_2] if label_1 and label_2 else ["Positive", "Negative"]
        else:
            num_classes = st.slider("Number of classes", 3, 10, 3)
            labels = []
            cols = st.columns(3)
            for i in range(num_classes):
                with cols[i % 3]:
                    label = st.text_input(f"Class {i+1}", f"Class_{i+1}")
                    labels.append(label)

        domain = st.selectbox("Domain", ["Restaurant reviews", "E-commerce reviews", "Custom"])
        if domain == "Custom":
            domain = st.text_input("Specify custom domain")

        col1, col2 = st.columns(2)
        with col1:
            min_words = st.number_input("Min words", 10, 90, 20)
        with col2:
            max_words = st.number_input("Max words", min_words, 90, 50)

        use_few_shot = st.toggle("Use few-shot examples")
        few_shot_examples = []
        if use_few_shot:
            num_examples = st.slider("Number of few-shot examples", 1, 5, 1)
            for i in range(num_examples):
                with st.expander(f"Example {i+1}"):
                    content = st.text_area("Content", key=f"few_shot_content_{i}")
                    label = st.selectbox("Label", labels, key=f"few_shot_label_{i}")
                    if content and label:
                        few_shot_examples.append({"content": content, "label": label})

        num_to_generate = st.number_input("Number of examples", 1, 100, 10)
        user_prompt = st.text_area("Additional instructions (optional)")
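
        # Everything configured above (labels, domain, word limits, example count,
        # few-shot examples, extra instructions) is assembled into a single system
        # prompt for the model below.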
        # Prompt template with word-length constraints for generation
        prompt_template = PromptTemplate(
            input_variables=["classification_type", "domain", "num_examples", "min_words", "max_words", "labels", "user_prompt"],
            template=(
                "You are a professional {classification_type} expert tasked with generating examples for {domain}.\n"
                "Use the following parameters:\n"
                "- Generate exactly {num_examples} examples\n"
                "- Each example MUST be between {min_words} and {max_words} words long\n"
                "- Use these labels: {labels}\n"
                "- Generate the examples in this format: 'Example text. Label: [label]'\n"
                "- Do not include word counts or any additional information\n"
                "Additional instructions: {user_prompt}\n\n"
                "Generate numbered examples:"
            )
        )
        col1, col2 = st.columns(2)
        with col1:
            if st.button("🎯 Generate Examples"):
                with st.spinner("Generating examples..."):
                    system_prompt = prompt_template.format(
                        classification_type=classification_type,
                        domain=domain,
                        num_examples=num_to_generate,
                        min_words=min_words,
                        max_words=max_words,
                        labels=", ".join(labels),
                        user_prompt=user_prompt
                    )
                    # Append any user-provided few-shot examples to the system prompt
                    if few_shot_examples:
                        few_shot_text = "\n\n".join(
                            f"{ex['content']}\nLabel: {ex['label']}" for ex in few_shot_examples
                        )
                        system_prompt += f"\n\nHere are some examples:\n{few_shot_text}"
                    try:
                        stream = client.chat.completions.create(
                            model=selected_model,
                            messages=[{"role": "system", "content": system_prompt}],
                            temperature=temperature,
                            stream=True,
                            max_tokens=3000,
                        )
                        response = st.write_stream(stream)
                        st.session_state.messages.append({"role": "assistant", "content": response})
                    except Exception as e:
                        st.error("An error occurred during generation.")
                        st.error(f"Details: {e}")
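
        # "Regenerate" drops the last assistant response and reruns the same prompt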
        with col2:
            if st.button("🔄 Regenerate"):
                st.session_state.messages = st.session_state.messages[:-1] if st.session_state.messages else []
                with st.spinner("Regenerating examples..."):
                    system_prompt = prompt_template.format(
                        classification_type=classification_type,
                        domain=domain,
                        num_examples=num_to_generate,
                        min_words=min_words,
                        max_words=max_words,
                        labels=", ".join(labels),
                        user_prompt=user_prompt
                    )
                    # Append any user-provided few-shot examples to the system prompt
                    if few_shot_examples:
                        few_shot_text = "\n\n".join(
                            f"{ex['content']}\nLabel: {ex['label']}" for ex in few_shot_examples
                        )
                        system_prompt += f"\n\nHere are some examples:\n{few_shot_text}"
                    try:
                        stream = client.chat.completions.create(
                            model=selected_model,
                            messages=[{"role": "system", "content": system_prompt}],
                            temperature=temperature,
                            stream=True,
                            max_tokens=3000,
                        )
                        response = st.write_stream(stream)
                        st.session_state.messages.append({"role": "assistant", "content": response})
                    except Exception as e:
                        st.error("An error occurred during regeneration.")
                        st.error(f"Details: {e}")
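
    # Data Labeling: classify user-provided texts with the selected model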
    elif st.session_state.task_choice == "Data Labeling":
        st.header("🏷️ Data Labeling")

        classification_type = st.selectbox(
            "Classification Type",
            ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"],
            key="label_class_type"
        )

        if classification_type == "Sentiment Analysis":
            labels = ["Positive", "Negative", "Neutral"]
        elif classification_type == "Binary Classification":
            col1, col2 = st.columns(2)
            with col1:
                label_1 = st.text_input("First class", "Positive", key="label_first")
            with col2:
                label_2 = st.text_input("Second class", "Negative", key="label_second")
            labels = [label_1, label_2] if label_1 and label_2 else ["Positive", "Negative"]
        else:
            num_classes = st.slider("Number of classes", 3, 10, 3, key="label_num_classes")
            labels = []
            cols = st.columns(3)
            for i in range(num_classes):
                with cols[i % 3]:
                    label = st.text_input(f"Class {i+1}", f"Class_{i+1}", key=f"label_class_{i}")
                    labels.append(label)

        use_few_shot = st.toggle("Use few-shot examples for labeling")
        few_shot_examples = []
        if use_few_shot:
            num_few_shot = st.slider("Number of few-shot examples", 1, 5, 1)
            for i in range(num_few_shot):
                with st.expander(f"Few-shot Example {i+1}"):
                    content = st.text_area("Content", key=f"label_few_shot_content_{i}")
                    label = st.selectbox("Label", labels, key=f"label_few_shot_label_{i}")
                    if content and label:
                        few_shot_examples.append(f"{content}\nLabel: {label}")
        num_examples = st.number_input("Number of examples to classify", 1, 100, 1)

        examples_to_classify = []
        if num_examples <= 20:
            for i in range(num_examples):
                example = st.text_area(f"Example {i+1}", key=f"example_{i}")
                if example:
                    examples_to_classify.append(example)
        else:
            examples_text = st.text_area(
                "Enter examples (one per line)",
                height=300,
                help="Enter each example on a new line"
            )
            if examples_text:
                examples_to_classify = [ex.strip() for ex in examples_text.split('\n') if ex.strip()]
                if len(examples_to_classify) > num_examples:
                    examples_to_classify = examples_to_classify[:num_examples]

        user_prompt = st.text_area("Additional instructions (optional)", key="label_instructions")
        # Flatten the few-shot examples and the texts to classify into plain strings
        few_shot_text = "\n\n".join(few_shot_examples) if few_shot_examples else ""
        examples_text = "\n".join([f"{i+1}. {ex}" for i, ex in enumerate(examples_to_classify)])
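
        # The labeling prompt bundles the label set, optional few-shot examples,
        # and the numbered texts to classify into one system message.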
        label_prompt_template = PromptTemplate(
            input_variables=["classification_type", "labels", "few_shot_examples", "examples", "user_prompt"],
            template=(
                "You are a professional {classification_type} expert. Classify the following examples using these labels: {labels}.\n"
                "Instructions:\n"
                "- Return the numbered example followed by its classification in the format: 'Example text. Label: [label]'\n"
                "- Do not provide any additional information or explanations\n"
                "{user_prompt}\n\n"
                "Few-shot examples:\n{few_shot_examples}\n\n"
                "Examples to classify:\n{examples}\n\n"
                "Output:\n"
            )
        )
        col1, col2 = st.columns(2)
        with col1:
            if st.button("🏷️ Label Data"):
                if examples_to_classify:
                    with st.spinner("Labeling data..."):
                        system_prompt = label_prompt_template.format(
                            classification_type=classification_type,
                            labels=", ".join(labels),
                            few_shot_examples=few_shot_text,
                            examples=examples_text,
                            user_prompt=user_prompt
                        )
                        try:
                            stream = client.chat.completions.create(
                                model=selected_model,
                                messages=[{"role": "system", "content": system_prompt}],
                                temperature=temperature,
                                stream=True,
                                max_tokens=3000,
                            )
                            response = st.write_stream(stream)
                            st.session_state.messages.append({"role": "assistant", "content": response})
                        except Exception as e:
                            st.error("An error occurred during labeling.")
                            st.error(f"Details: {e}")
                else:
                    st.warning("Please enter at least one example to classify.")
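
        # "Relabel" drops the last assistant response and reruns the same labeling prompt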
        with col2:
            if st.button("🔄 Relabel"):
                if examples_to_classify:
                    st.session_state.messages = st.session_state.messages[:-1] if st.session_state.messages else []
                    with st.spinner("Relabeling data..."):
                        system_prompt = label_prompt_template.format(
                            classification_type=classification_type,
                            labels=", ".join(labels),
                            few_shot_examples=few_shot_text,
                            examples=examples_text,
                            user_prompt=user_prompt
                        )
                        try:
                            stream = client.chat.completions.create(
                                model=selected_model,
                                messages=[{"role": "system", "content": system_prompt}],
                                temperature=temperature,
                                stream=True,
                                max_tokens=3000,
                            )
                            response = st.write_stream(stream)
                            st.session_state.messages.append({"role": "assistant", "content": response})
                        except Exception as e:
                            st.error("An error occurred during relabeling.")
                            st.error(f"Details: {e}")
                else:
                    st.warning("Please enter at least one example to classify.")
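
# To run locally: save this file (e.g. as app.py) and launch it with
# `streamlit run app.py`, with streamlit, pandas, openai, python-dotenv, and
# langchain-core installed and a Hugging Face token exported as TOKEN2.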