# Hugging Face Spaces status banner ("Spaces: Sleeping") captured when this
# file was exported from the Space page — not part of the application code.
| import gradio as gr | |
| import sqlite3 | |
| import json | |
| import pandas as pd | |
| from openai import OpenAI | |
| import traceback | |
| from typing import Dict, List, Tuple, Any | |
| import re | |
| from datetime import datetime | |
| import threading | |
| import queue | |
| import html | |
| import sys | |
| import os | |
# Force stdout to use UTF-8 encoding to handle Unicode characters (the UI
# prints emoji). `sys.stdout.encoding` can be None (non-console streams) or
# reported as 'UTF-8'/'utf-8' depending on platform, so compare
# case-insensitively instead of the original exact string match.
if (sys.stdout.encoding or '').lower() != 'utf-8':
    try:
        # Python 3.7+: swap the encoding in place without detaching the stream.
        sys.stdout.reconfigure(encoding='utf-8')
    except AttributeError:
        # Fallback for exotic stdout objects that are not TextIOWrapper.
        sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
class DatabaseQueryAgent:
    """Natural-language-to-SQL agent backed by multiple OpenRouter models.

    Two generator models (llama, mistral) each propose a SQL query for the
    user's question; a third model (gemma) referees them. Schema metadata is
    snapshotted from the SQLite database at construction time.
    """

    def __init__(self, db_path: str = "innovativeskills.db"):
        # Path to the SQLite database file queried by this agent.
        self.db_path = db_path
        # OpenAI-compatible client; created lazily by setup_client() once an
        # API key is available.
        self.client = None
        # Available models: OpenRouter model identifiers (all free tier).
        self.models = {
            "llama": "meta-llama/llama-3.3-70b-instruct:free",
            "mistral": "mistralai/mistral-7b-instruct:free",
            "gemma": "google/gemma-2-9b-it:free"  # Verification model
        }
        # Initialize database connection and cache schema/metadata snapshots.
        self.init_db_connection()
| def init_db_connection(self): | |
| """Initialize database connection with UTF-8 encoding""" | |
| try: | |
| conn = sqlite3.connect(self.db_path, check_same_thread=False) | |
| conn.execute("PRAGMA encoding = 'UTF-8';") | |
| cursor = conn.cursor() | |
| # Load table metadata | |
| self.table_metadata = self.get_table_metadata(conn, cursor) | |
| self.column_metadata = self.get_column_metadata(conn, cursor) | |
| self.actual_schema = self.get_actual_schema(conn, cursor) | |
| conn.close() | |
| except Exception as e: | |
| print(f"Database initialization error: {e}") | |
| self.table_metadata = {} | |
| self.column_metadata = {} | |
| self.actual_schema = {} | |
| def get_db_connection(self): | |
| """Get a new database connection with UTF-8 encoding""" | |
| conn = sqlite3.connect(self.db_path, check_same_thread=False) | |
| conn.execute("PRAGMA encoding = 'UTF-8';") | |
| return conn | |
| def get_actual_schema(self, conn, cursor) -> Dict: | |
| """Get actual database schema""" | |
| try: | |
| cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") | |
| tables = [row[0] for row in cursor.fetchall()] | |
| schema = {} | |
| for table in tables: | |
| cursor.execute(f"PRAGMA table_info({table})") | |
| columns = cursor.fetchall() | |
| try: | |
| cursor.execute(f"SELECT * FROM {table} LIMIT 3") | |
| sample_data = cursor.fetchall() | |
| except Exception: | |
| sample_data = [] | |
| try: | |
| cursor.execute(f"SELECT COUNT(*) FROM {table}") | |
| row_count = cursor.fetchone()[0] | |
| except Exception: | |
| row_count = 0 | |
| schema[table] = { | |
| 'columns': [{'name': col[1], 'type': col[2], 'notnull': col[3], 'pk': col[5]} for col in columns], | |
| 'sample_data': sample_data, | |
| 'row_count': row_count | |
| } | |
| return schema | |
| except Exception as e: | |
| print(f"Error getting actual schema: {e}") | |
| return {} | |
| def get_table_metadata(self, conn, cursor) -> Dict: | |
| """Get table metadata""" | |
| try: | |
| query = """ | |
| SELECT table_name, domain, description, row_count | |
| FROM table_catalog | |
| WHERE table_name NOT IN ('table_catalog', 'column_catalog') | |
| """ | |
| results = cursor.execute(query).fetchall() | |
| metadata = {} | |
| for table_name, domain, description, row_count in results: | |
| metadata[table_name] = { | |
| 'domain': domain, | |
| 'description': description, | |
| 'row_count': row_count | |
| } | |
| return metadata | |
| except Exception as e: | |
| print(f"Error loading table metadata: {e}") | |
| return {} | |
| def get_column_metadata(self, conn, cursor) -> Dict: | |
| """Get column metadata""" | |
| try: | |
| query = """ | |
| SELECT table_name, column_name, data_type, is_foreign_key, references_table, description | |
| FROM column_catalog | |
| """ | |
| results = cursor.execute(query).fetchall() | |
| metadata = {} | |
| for table_name, column_name, data_type, is_fk, ref_table, description in results: | |
| if table_name not in metadata: | |
| metadata[table_name] = [] | |
| metadata[table_name].append({ | |
| 'name': column_name, | |
| 'type': data_type, | |
| 'is_foreign_key': bool(is_fk), | |
| 'references': ref_table, | |
| 'description': description | |
| }) | |
| return metadata | |
| except Exception as e: | |
| print(f"Error loading column metadata: {e}") | |
| return {} | |
| def setup_client(self, api_key: str): | |
| """Setup OpenRouter client""" | |
| self.client = OpenAI( | |
| base_url="https://openrouter.ai/api/v1", | |
| api_key=api_key, | |
| ) | |
| def get_relevant_tables_for_query(self, query: str) -> str: | |
| """Analyze query and return relevant table info""" | |
| query_lower = query.lower() | |
| relevant_tables = [] | |
| keywords = { | |
| 'customer': ['customer', 'client', 'buyer', 'user'], | |
| 'order': ['order', 'purchase', 'transaction', 'sale'], | |
| 'product': ['product', 'item', 'inventory', 'stock'], | |
| 'employee': ['employee', 'staff', 'worker', 'personnel'], | |
| 'patient': ['patient', 'medical', 'health'], | |
| 'student': ['student', 'enrollment', 'grade', 'course'], | |
| 'supplier': ['supplier', 'vendor', 'provider'], | |
| 'shipping': ['shipping', 'delivery', 'logistics'], | |
| 'payment': ['payment', 'invoice', 'billing'], | |
| 'account': ['account', 'financial', 'balance'] | |
| } | |
| for concept, search_terms in keywords.items(): | |
| if any(term in query_lower for term in search_terms): | |
| for table_name in self.actual_schema.keys(): | |
| table_lower = table_name.lower() | |
| if any(term in table_lower for term in search_terms): | |
| if table_name not in relevant_tables: | |
| relevant_tables.append(table_name) | |
| if not relevant_tables: | |
| relevant_tables = [name for name, info in self.actual_schema.items() | |
| if info['row_count'] > 10][:10] | |
| schema_info = "" | |
| for table in relevant_tables[:15]: | |
| if table in self.actual_schema: | |
| info = self.actual_schema[table] | |
| columns_str = ", ".join([f"{col['name']}({col['type']})" for col in info['columns']]) | |
| schema_info += f"\nTable: {table}\n" | |
| schema_info += f" Columns: {columns_str}\n" | |
| schema_info += f" Rows: {info['row_count']}\n" | |
| if table in self.table_metadata: | |
| meta = self.table_metadata[table] | |
| schema_info += f" Domain: {meta['domain']}\n" | |
| schema_info += f" Description: {meta['description']}\n" | |
| if info['sample_data']: | |
| schema_info += f" Sample: {info['sample_data'][0] if info['sample_data'] else 'No data'}\n" | |
| return schema_info | |
    def get_system_prompt(self, user_query: str) -> str:
        """Generate system prompt with actual schema.

        Embeds only the schema slice relevant to *user_query* (via
        get_relevant_tables_for_query) to keep the prompt compact, then
        appends the SQL rules and the required JSON response format.
        Literal JSON braces in the template are escaped as {{ }}.
        """
        relevant_schema = self.get_relevant_tables_for_query(user_query)
        return f"""You are an intelligent database query agent that specializes in identifying relevant tables and generating accurate SQL queries.
DATABASE SCHEMA INFORMATION:
{relevant_schema}
CRITICAL SQL RULES:
1. NEVER use reserved words as table aliases (like 'to', 'from', 'where', 'select', etc.)
2. Use descriptive aliases like 'cust', 'ord', 'prod' instead
3. Only JOIN tables if you can identify a logical relationship between them
4. If no clear JOIN relationship exists, use separate SELECT statements or UNION
5. Always use the EXACT column names shown in the schema
6. Do not assume foreign key relationships unless explicitly shown
CRITICAL: You MUST respond with ONLY a valid JSON object. No markdown, no explanations outside the JSON.
Your response must be exactly in this JSON format:
{{
"analysis": "Brief analysis of the query and table selection reasoning",
"identified_tables": ["table1", "table2", "table3"],
"domains_involved": ["domain1", "domain2"],
"sql_query": "SELECT ... FROM ... WHERE ...",
"explanation": "Step-by-step explanation of the query logic",
"confidence": 0.95,
"alternative_queries": ["Alternative SQL if applicable"]
}}
IMPORTANT RULES:
1. Respond with ONLY valid JSON - no markdown formatting
2. Use ONLY the actual table names shown in the schema above
3. Use ONLY the actual column names shown in the schema above
4. Generate syntactically correct SQL queries with proper aliases
5. Focus on tables that actually exist and have relevant data
6. Include confidence scores between 0.0 and 1.0
7. Provide clear explanations
8. Ensure table names in 'identified_tables' match those used in 'sql_query'
9. Check that columns referenced in SQL actually exist in the tables
10. If no perfect match exists, choose the closest relevant tables and explain the compromise
11. Avoid reserved word aliases like 'to', 'from', 'order', 'select'
QUERY ANALYSIS GUIDELINES:
- For customer/order queries: Look for tables with customer-related or order-related names and columns
- For employee queries: Look for tables with employee, staff, or HR-related names
- For product queries: Look for tables with product, inventory, or item-related names
- Always verify column names exist before using them in SQL
- Use proper JOIN syntax when combining tables, but only if logical relationships exist
- Include appropriate WHERE clauses when filtering is implied
- If unsure about relationships, prefer simpler queries or multiple separate queries"""
| def extract_json_from_response(self, response_text: str) -> Dict: | |
| """Extract JSON from response text""" | |
| try: | |
| return json.loads(response_text) | |
| except json.JSONDecodeError: | |
| json_pattern = r'```json\s*(.*?)\s*```' | |
| json_match = re.search(json_pattern, response_text, re.DOTALL) | |
| if json_match: | |
| try: | |
| return json.loads(json_match.group(1)) | |
| except json.JSONDecodeError: | |
| pass | |
| json_pattern = r'\{.*\}' | |
| json_match = re.search(json_pattern, response_text, re.DOTALL) | |
| if json_match: | |
| try: | |
| return json.loads(json_match.group(0)) | |
| except json.JSONDecodeError: | |
| pass | |
| return self.create_fallback_response(response_text) | |
| def create_fallback_response(self, response_text: str) -> Dict: | |
| """Create a fallback response when JSON parsing fails""" | |
| sql_pattern = r'SELECT.*?(?:;|$)' | |
| sql_match = re.search(sql_pattern, response_text, re.IGNORECASE | re.DOTALL) | |
| sql_query = sql_match.group(0).strip(';') if sql_match else "" | |
| identified_tables = [table_name for table_name in self.actual_schema.keys() | |
| if table_name.lower() in response_text.lower()] | |
| domains_involved = [self.table_metadata[table]['domain'] for table in identified_tables | |
| if table in self.table_metadata and self.table_metadata[table]['domain'] not in domains_involved] | |
| return { | |
| "analysis": "Fallback analysis from unparseable response", | |
| "identified_tables": identified_tables[:5], | |
| "domains_involved": domains_involved[:3], | |
| "sql_query": sql_query, | |
| "explanation": "Response could not be parsed as JSON, extracted information where possible", | |
| "confidence": 0.5, | |
| "alternative_queries": [] | |
| } | |
| def validate_sql_query(self, sql_query: str, identified_tables: List[str]) -> Tuple[bool, str]: | |
| """Validate SQL query against schema""" | |
| try: | |
| if not sql_query.strip(): | |
| return False, "Empty SQL query" | |
| for table in identified_tables: | |
| if table not in self.actual_schema: | |
| return False, f"Table '{table}' does not exist in database" | |
| sql_upper = sql_query.upper() | |
| if not sql_upper.strip().startswith('SELECT'): | |
| return False, "Only SELECT queries are allowed" | |
| reserved_words = ['TO', 'FROM', 'WHERE', 'SELECT', 'ORDER', 'GROUP', 'HAVING', 'UNION', 'JOIN', 'ON'] | |
| alias_pattern = r'(?:FROM|JOIN)\s+(\w+)\s+(\w+)' | |
| aliases = re.findall(alias_pattern, sql_query, re.IGNORECASE) | |
| for table, alias in aliases: | |
| if alias.upper() in reserved_words: | |
| return False, f"Cannot use reserved word '{alias}' as table alias" | |
| for table in identified_tables: | |
| if table in sql_query: | |
| table_info = self.actual_schema[table] | |
| available_columns = [col['name'] for col in table_info['columns']] | |
| column_patterns = [ | |
| rf'{re.escape(table)}\.(\w+)', | |
| rf'\b(\w+)\.(\w+)', | |
| rf'SELECT\s+([^FROM]+)' | |
| ] | |
| for pattern in column_patterns: | |
| matches = re.findall(pattern, sql_query, re.IGNORECASE) | |
| for match in matches: | |
| if isinstance(match, tuple): | |
| column = match[1] if len(match) == 2 else match[0] if match else '' | |
| else: | |
| column = match | |
| if column.upper() in ['*', 'COUNT', 'SUM', 'AVG', 'MAX', 'MIN', 'DISTINCT']: | |
| continue | |
| if column and column not in available_columns and f'{table}.{column}' in sql_query: | |
| return False, f"Column '{column}' does not exist in table '{table}'" | |
| return True, "Query validation passed" | |
| except Exception as e: | |
| return False, f"Validation error: {str(e)}" | |
| def call_model(self, model_key: str, prompt: str, user_query: str) -> Dict: | |
| """Call specific model with prompt""" | |
| try: | |
| messages = [ | |
| {"role": "system", "content": prompt}, | |
| {"role": "user", "content": f"Query: {user_query}\n\nRespond with ONLY a valid JSON object following the exact format specified in the system prompt."} | |
| ] | |
| completion = self.client.chat.completions.create( | |
| model=self.models[model_key], | |
| messages=messages, | |
| temperature=0.1, | |
| max_tokens=2000 | |
| ) | |
| response = completion.choices[0].message.content.strip() | |
| parsed_response = self.extract_json_from_response(response) | |
| sql_query = parsed_response.get('sql_query', '') | |
| identified_tables = parsed_response.get('identified_tables', []) | |
| if sql_query: | |
| is_valid, validation_message = self.validate_sql_query(sql_query, identified_tables) | |
| parsed_response['sql_validation'] = { | |
| 'is_valid': is_valid, | |
| 'message': validation_message | |
| } | |
| return { | |
| "success": True, | |
| "response": parsed_response, | |
| "raw_response": response, | |
| "model": model_key | |
| } | |
| except Exception as e: | |
| return { | |
| "success": False, | |
| "error": str(e), | |
| "model": model_key | |
| } | |
    def verify_response(self, api_key: str, original_query: str, llama_response: Dict, mistral_response: Dict) -> Dict:
        """Use Gemma to verify responses.

        Recreates the client, embeds the relevant schema slice plus both
        candidate answers in a referee prompt, and asks the "gemma" model to
        judge them. Returns the same envelope shape as call_model().
        Doubled braces ({{ }}) in the template are literal JSON braces.
        """
        self.setup_client(api_key)
        relevant_schema = self.get_relevant_tables_for_query(original_query)
        verification_prompt = f"""You are a database query verification expert. You have access to the actual database schema and must verify responses against it.
ACTUAL DATABASE SCHEMA:
{relevant_schema}
ORIGINAL QUERY: {original_query}
LLAMA RESPONSE: {json.dumps(llama_response.get('response', {}), indent=2)}
MISTRAL RESPONSE: {json.dumps(mistral_response.get('response', {}), indent=2)}
Verify these responses against the ACTUAL schema above. Check:
1. Do the table names actually exist in the schema?
2. Do the column names actually exist in those tables?
3. Are the table selections appropriate for the query?
4. Is the SQL syntax correct?
5. Are table aliases proper (not reserved words)?
Respond with ONLY a valid JSON object:
{{
"verification_summary": "Overall assessment based on actual schema",
"table_selection_accuracy": "Assessment of table choices against actual schema",
"sql_correctness": "SQL syntax and schema validation",
"consistency_check": "Comparison between responses",
"recommended_response": "llama, mistral, or neither",
"confidence_score": 0.85,
"suggested_improvements": ["improvement1", "improvement2"],
"potential_issues": ["issue1", "issue2"],
"schema_compliance": "Assessment of how well responses match actual schema"
}}"""
        # Delegate to call_model so the verification response gets the same
        # JSON-extraction and error handling as the generator models.
        return self.call_model("gemma", verification_prompt, "Verify the above responses against the actual database schema.")
| def execute_query_in_thread(self, sql_query: str, result_queue: queue.Queue): | |
| """Execute SQL query in a thread""" | |
| try: | |
| if not sql_query.strip().upper().startswith('SELECT'): | |
| result_queue.put((False, "Only SELECT queries are allowed")) | |
| return | |
| sql_query = sql_query.strip().rstrip(';') | |
| conn = self.get_db_connection() | |
| try: | |
| df = pd.read_sql_query(sql_query, conn) | |
| result_queue.put((True, df)) | |
| except Exception as e: | |
| result_queue.put((False, str(e))) | |
| finally: | |
| conn.close() | |
| except Exception as e: | |
| result_queue.put((False, f"Query execution error: {str(e)}")) | |
| def execute_query(self, sql_query: str) -> Tuple[bool, Any]: | |
| """Execute SQL query using thread-safe approach""" | |
| try: | |
| result_queue = queue.Queue() | |
| thread = threading.Thread( | |
| target=self.execute_query_in_thread, | |
| args=(sql_query, result_queue) | |
| ) | |
| thread.start() | |
| thread.join(timeout=30) | |
| if thread.is_alive(): | |
| return False, "Query execution timed out" | |
| if not result_queue.empty(): | |
| return result_queue.get() | |
| else: | |
| return False, "No result returned from query execution" | |
| except Exception as e: | |
| return False, f"Execution error: {str(e)}" | |
    def process_query(self, api_key: str, user_query: str) -> Dict:
        """Process user query end-to-end.

        Pipeline: prompt both generator models (llama, mistral) with the same
        schema-aware system prompt, have gemma verify the pair, then execute
        every generated SQL query that passed local validation. Returns a
        dict with all intermediate artifacts, or {"error": ...} on failure.
        """
        if not api_key:
            return {"error": "Please provide OpenRouter API key"}
        try:
            self.setup_client(api_key)
            system_prompt = self.get_system_prompt(user_query)
            # Ask both generator models independently.
            llama_result = self.call_model("llama", system_prompt, user_query)
            mistral_result = self.call_model("mistral", system_prompt, user_query)
            # Third model referees the two candidate answers.
            verification_result = self.verify_response(api_key, user_query, llama_result, mistral_result)
            execution_results = {}
            for model_name, result in [("llama", llama_result), ("mistral", mistral_result)]:
                if result.get("success") and result.get("response", {}).get("sql_query"):
                    sql_query = result["response"]["sql_query"]
                    validation_info = result["response"].get("sql_validation", {})
                    if sql_query.strip():
                        # Default to True: missing validation info should not
                        # block execution.
                        if validation_info.get("is_valid", True):
                            success, data = self.execute_query(sql_query)
                            execution_results[model_name] = {
                                "success": success,
                                # On success `data` is a DataFrame; serialize it.
                                "data": data.to_dict('records') if success and isinstance(data, pd.DataFrame) else str(data),
                                "row_count": len(data) if success and isinstance(data, pd.DataFrame) else 0,
                                "sql_query": sql_query,
                                "validation": validation_info
                            }
                        else:
                            execution_results[model_name] = {
                                "success": False,
                                "data": f"Query validation failed: {validation_info.get('message', 'Unknown error')}",
                                "row_count": 0,
                                "sql_query": sql_query,
                                "validation": validation_info
                            }
                    else:
                        execution_results[model_name] = {
                            "success": False,
                            "data": "No SQL query generated",
                            "row_count": 0,
                            "sql_query": "",
                            "validation": {"is_valid": False, "message": "Empty query"}
                        }
                else:
                    execution_results[model_name] = {
                        "success": False,
                        "data": "Model failed to generate response",
                        "row_count": 0,
                        "sql_query": "",
                        "validation": {"is_valid": False, "message": "Model error"}
                    }
            return {
                "llama_response": llama_result,
                "mistral_response": mistral_result,
                "verification": verification_result,
                "execution_results": execution_results,
                "timestamp": datetime.now().isoformat(),
                "schema_info": self.get_relevant_tables_for_query(user_query)
            }
        except Exception as e:
            return {"error": f"Processing error: {str(e)}", "traceback": traceback.format_exc()}
def response_to_markdown(response_dict: Dict) -> str:
    """Render one model's parsed response envelope as a Markdown report."""
    if not response_dict.get("success", False):
        return f"**Error**: {response_dict.get('error', 'Unknown error')}"
    payload = response_dict.get("response", {})
    parts = ["**Query Analysis Results**\n\n"]
    parts.append(f"- **Analysis**: {payload.get('analysis', 'N/A')}\n\n")
    tables = payload.get('identified_tables', [])
    parts.append(f"- **Identified Tables**: {', '.join(tables) if tables else 'None'}\n\n")
    domains = payload.get('domains_involved', [])
    parts.append(f"- **Domains Involved**: {', '.join(domains) if domains else 'None'}\n\n")
    sql_query = payload.get('sql_query', '')
    if sql_query:
        parts.append("- **SQL Query**:\n\n```sql\n" + sql_query + "\n```\n\n")
    else:
        parts.append("- **SQL Query**: None\n\n")
    parts.append(f"- **Explanation**: {payload.get('explanation', 'N/A')}\n\n")
    parts.append(f"- **Confidence**: {payload.get('confidence', 'N/A')}\n\n")
    alternatives = payload.get('alternative_queries', [])
    if alternatives:
        parts.append("- **Alternative Queries**:\n")
        parts.extend(f" - {query}\n" for query in alternatives)
    else:
        parts.append("- **Alternative Queries**: None\n")
    validation = payload.get('sql_validation', {})
    if validation:
        is_valid = validation.get('is_valid', False)
        message = validation.get('message', 'N/A')
        parts.append(f"\n- **SQL Validation**: {'Passed' if is_valid else 'Failed'} - {message}\n")
    return "".join(parts)
def verification_to_markdown(verification_dict: Dict) -> str:
    """Render the Gemma verification envelope as a Markdown report."""
    if not verification_dict.get("success", False):
        return f"**Error**: {verification_dict.get('error', 'Unknown error')}"
    payload = verification_dict.get("response", {})
    parts = ["**Verification Results**\n\n"]
    scalar_fields = (
        ("Verification Summary", 'verification_summary'),
        ("Table Selection Accuracy", 'table_selection_accuracy'),
        ("SQL Correctness", 'sql_correctness'),
        ("Consistency Check", 'consistency_check'),
        ("Recommended Response", 'recommended_response'),
        ("Confidence Score", 'confidence_score'),
    )
    for label, key in scalar_fields:
        parts.append(f"- **{label}**: {payload.get(key, 'N/A')}\n\n")
    list_fields = (
        ("Suggested Improvements", 'suggested_improvements'),
        ("Potential Issues", 'potential_issues'),
    )
    for label, key in list_fields:
        entries = payload.get(key, [])
        if entries:
            parts.append(f"- **{label}**:\n")
            parts.extend(f" - {entry}\n" for entry in entries)
        else:
            parts.append(f"- **{label}**: None\n")
    parts.append(f"- **Schema Compliance**: {payload.get('schema_compliance', 'N/A')}\n")
    return "".join(parts)
def create_gradio_interface():
    """Create Gradio interface.

    Builds the Blocks UI, instantiates one shared DatabaseQueryAgent, and
    wires the submit/clear/sample-dropdown events to it. Returns the Blocks
    app. FIX: model display labels now match DatabaseQueryAgent.models
    (meta-llama/llama-3.3-70b-instruct); the old text claimed "Llama 3.1 8B".
    """
    agent = DatabaseQueryAgent()
    sample_queries = [
        "Find all customers from customer tables",
        "Show me employee information from HR tables",
        "Get patient data from healthcare tables",
        "List all products with their details",
        "Find students enrolled in courses",
        "Show financial transaction records",
        "Get shipping information for deliveries",
        "Find all suppliers and their information",
        "Show retail store data",
        "Get manufacturing production records"
    ]

    def process_user_query(api_key, query):
        """Process query and return the six formatted output values."""
        if not query.strip():
            return "Please enter a query", "", "", "", "", ""
        results = agent.process_query(api_key, query)
        if "error" in results:
            return f"**Error**: {results['error']}", "", "", "", "", ""
        # Format responses as Markdown
        llama_markdown = response_to_markdown(results.get("llama_response", {}))
        mistral_markdown = response_to_markdown(results.get("mistral_response", {}))
        verification_markdown = verification_to_markdown(results.get("verification", {}))
        # Format execution results
        exec_results = results.get("execution_results", {})
        execution_formatted = ""
        for model, result in exec_results.items():
            execution_formatted += f"\n=== {model.upper()} EXECUTION ===\n"
            execution_formatted += f"SQL Query: {result.get('sql_query', 'N/A')}\n"
            validation = result.get('validation', {})
            if validation.get('is_valid'):
                execution_formatted += "β Query Validation: PASSED\n"
            else:
                execution_formatted += f"β Query Validation: FAILED - {validation.get('message', 'Unknown error')}\n"
            if result["success"]:
                execution_formatted += f"β Execution: Success! Retrieved {result['row_count']} rows\n"
                if result["row_count"] > 0:
                    sample_data = result['data'][:3] if isinstance(result['data'], list) else []
                    execution_formatted += f"Sample data:\n{json.dumps(sample_data, indent=2)}\n"
                else:
                    execution_formatted += "No data returned (empty result set)\n"
            else:
                execution_formatted += f"β Execution Error: {result['data']}\n"
            execution_formatted += "\n"
        if not execution_formatted:
            execution_formatted = "No queries were executed. Check if valid SQL was generated."
        schema_info = results.get('schema_info', 'No schema information available')
        # Format summary as Markdown
        verification_resp = results.get('verification', {}).get('response', {})
        summary = f"""
**π QUERY ANALYSIS COMPLETE**
ββββββββββββββββββββββββ
**π Models Used**: Llama 3.3 70B, Mistral 7B, Gemma 2 9B (verification)
**β° Processed**: {results.get('timestamp', 'N/A')}
**π― Verification Summary**:
{verification_resp.get('verification_summary', 'N/A')}
**π‘ Recommended Model**: {verification_resp.get('recommended_response', 'N/A')}
**π Confidence**: {verification_resp.get('confidence_score', 'N/A')}
**ποΈ Schema Compliance**: {verification_resp.get('schema_compliance', 'N/A')}
**ποΈ Query Execution Status**:
{len(exec_results)} queries attempted
"""
        return summary, llama_markdown, mistral_markdown, verification_markdown, execution_formatted, schema_info

    with gr.Blocks(
        title="Fixed Intelligent Database Query Agent",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
            margin: 0 auto !important;
        }
        .result-box {
            background-color: #f8f9fa;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            padding: 15px;
        }
        """
    ) as interface:
        gr.HTML("""
        <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
            <h1>π€ Fixed Intelligent Database Query Agent</h1>
            <p>AI-powered agent that intelligently selects relevant tables from 100+ tables and generates optimized SQL queries</p>
            <p><strong>Database:</strong> 100 tables across 10 business domains | <strong>Models:</strong> Llama 3.3 70B + Mistral 7B + Gemma 2 9B</p>
            <p><strong>β FIXED:</strong> Reserved Word Aliases | Enhanced Column Validation | Better SQL Syntax Checking</p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="π OpenRouter API Key",
                    type="password",
                    placeholder="Enter your OpenRouter API key...",
                    info="Get your free API key from openrouter.ai"
                )
                query_input = gr.Textbox(
                    label="π¬ Database Query",
                    placeholder="Enter your natural language query...",
                    lines=3,
                    info="Example: 'Find all customers who placed orders in the last month'"
                )
                with gr.Row():
                    submit_btn = gr.Button("π Process Query", variant="primary", size="lg")
                    clear_btn = gr.Button("ποΈ Clear", variant="secondary")
                gr.HTML("<h3>π Sample Test Queries</h3>")
                sample_dropdown = gr.Dropdown(
                    choices=sample_queries,
                    label="Quick Test Examples",
                    info="Select a sample query to test the agent"
                )
            with gr.Column(scale=2):
                summary_output = gr.Markdown(label="π Analysis Summary")
                with gr.Tabs():
                    with gr.Tab("π¦ Llama 3.3 70B Response"):
                        llama_output = gr.Markdown(label="Llama Response")
                    with gr.Tab("π Mistral 7B Response"):
                        mistral_output = gr.Markdown(label="Mistral Response")
                    with gr.Tab("β Verification (Gemma 2 9B)"):
                        verification_output = gr.Markdown(label="Verification Analysis")
                    with gr.Tab("ποΈ Query Execution Results"):
                        execution_output = gr.Textbox(
                            label="Database Execution Results",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )
                    with gr.Tab("π Database Schema"):
                        schema_output = gr.Textbox(
                            label="Relevant Database Schema",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )
        submit_btn.click(
            fn=process_user_query,
            inputs=[api_key_input, query_input],
            outputs=[summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        # Clear resets the query box plus all six outputs (7 values total).
        clear_btn.click(
            fn=lambda: ("", "", "", "", "", "", ""),
            outputs=[query_input, summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        # Selecting a sample copies it into the query box.
        sample_dropdown.change(
            fn=lambda x: x,
            inputs=[sample_dropdown],
            outputs=[query_input]
        )
        gr.HTML("""
        <div style="margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 8px;">
            <h3>π― How to Use</h3>
            <ol>
                <li><strong>API Key:</strong> Get a free API key from <a href="https://openrouter.ai" target="_blank">openrouter.ai</a></li>
                <li><strong>Query:</strong> Enter your natural language database query</li>
                <li><strong>Process:</strong> The agent will analyze your query across 100+ tables and generate optimized SQL</li>
                <li><strong>Results:</strong> View responses from multiple AI models, verification analysis, and actual query execution results</li>
            </ol>
            <p><strong>Features:</strong></p>
            <ul>
                <li>π§ Multi-model AI analysis (Llama, Mistral, Gemma)</li>
                <li>π Intelligent table selection from 100+ tables</li>
                <li>β SQL validation and syntax checking</li>
                <li>ποΈ Real database query execution with results</li>
                <li>π Cross-model verification and comparison</li>
            </ul>
        </div>
        """)
    return interface
def main():
    """Main function to launch the application."""
    print("π Starting Intelligent Database Query Agent...")
    print("π Loading database schema and metadata...")
    app = create_gradio_interface()
    print("β Database Query Agent Ready!")
    print("π Access the interface at: http://localhost:7860")
    print("π Don't forget to add your OpenRouter API key!")
    # share=True exposes a temporary public URL in addition to localhost.
    app.launch(share=True)


if __name__ == "__main__":
    main()