import os
import gradio as gr
import requests
import json
import base64
import logging
import io
import time
from typing import List, Dict, Any, Union, Tuple, Optional
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Gracefully import libraries with fallbacks
try:
    from PIL import Image
    HAS_PIL = True
except ImportError:
    logger.warning("PIL not installed. Image processing will be limited.")
    HAS_PIL = False
try:
    import PyPDF2
    HAS_PYPDF2 = True
except ImportError:
    logger.warning("PyPDF2 not installed. PDF processing will be limited.")
    HAS_PYPDF2 = False
try:
    import markdown
    HAS_MARKDOWN = True
except ImportError:
    logger.warning("Markdown not installed. Markdown processing will be limited.")
    HAS_MARKDOWN = False
try:
    import openai
    HAS_OPENAI = True
except ImportError:
    logger.warning("OpenAI package not installed. OpenAI models will be unavailable.")
    HAS_OPENAI = False
try:
    from groq import Groq
    HAS_GROQ = True
except ImportError:
    logger.warning("Groq client not installed. Groq API will be unavailable.")
    HAS_GROQ = False
try:
    import cohere
    HAS_COHERE = True
except ImportError:
    logger.warning("Cohere package not installed. Cohere models will be unavailable.")
    HAS_COHERE = False
try:
    from huggingface_hub import InferenceClient
    HAS_HF = True
except ImportError:
    logger.warning("HuggingFace hub not installed. HuggingFace models will be limited.")
    HAS_HF = False
# API keys from environment
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "")
HF_API_KEY = os.environ.get("HF_API_KEY", "")
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "")
GOOGLEAI_API_KEY = os.environ.get("GOOGLEAI_API_KEY", "")
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
POE_API_KEY = os.environ.get("POE_API_KEY", "")
# Print application startup message with timestamp
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
print(f"===== Application Startup at {current_time} =====\n")
# ==========================================================
# MODEL DEFINITIONS
# ==========================================================
# OPENROUTER MODELS
# Free-tier OpenRouter models, grouped by context window size
OPENROUTER_MODELS = [
    # 1M+ Context Models
    {"category": "1M+ Context", "models": [
        ("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
        ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
        ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
        ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
        ("Google: Gemini Flash 1.5 8B Experimental", "google/gemini-flash-1.5-8b-exp", 1000000),
    ]},
    
    # 100K-1M Context Models
    {"category": "100K+ Context", "models": [
        ("DeepSeek: DeepSeek R1 Zero", "deepseek/deepseek-r1-zero:free", 163840),
        ("DeepSeek: R1", "deepseek/deepseek-r1:free", 163840),
        ("DeepSeek: DeepSeek V3 Base", "deepseek/deepseek-v3-base:free", 131072),
        ("DeepSeek: DeepSeek V3 0324", "deepseek/deepseek-chat-v3-0324:free", 131072),
        ("Google: Gemma 3 4B", "google/gemma-3-4b-it:free", 131072),
        ("Google: Gemma 3 12B", "google/gemma-3-12b-it:free", 131072),
        ("Nous: DeepHermes 3 Llama 3 8B Preview", "nousresearch/deephermes-3-llama-3-8b-preview:free", 131072),
        ("Qwen: Qwen2.5 VL 72B Instruct", "qwen/qwen2.5-vl-72b-instruct:free", 131072),
        ("DeepSeek: DeepSeek V3", "deepseek/deepseek-chat:free", 131072),
        ("NVIDIA: Llama 3.1 Nemotron 70B Instruct", "nvidia/llama-3.1-nemotron-70b-instruct:free", 131072),
        ("Meta: Llama 3.2 1B Instruct", "meta-llama/llama-3.2-1b-instruct:free", 131072),
        ("Meta: Llama 3.2 11B Vision Instruct", "meta-llama/llama-3.2-11b-vision-instruct:free", 131072),
        ("Meta: Llama 3.1 8B Instruct", "meta-llama/llama-3.1-8b-instruct:free", 131072),
        ("Mistral: Mistral Nemo", "mistralai/mistral-nemo:free", 128000),
    ]},
    
    # 64K-100K Context Models
    {"category": "64K-100K Context", "models": [
        ("Mistral: Mistral Small 3.1 24B", "mistralai/mistral-small-3.1-24b-instruct:free", 96000),
        ("Google: Gemma 3 27B", "google/gemma-3-27b-it:free", 96000),
        ("Qwen: Qwen2.5 VL 3B Instruct", "qwen/qwen2.5-vl-3b-instruct:free", 64000),
        ("DeepSeek: R1 Distill Qwen 14B", "deepseek/deepseek-r1-distill-qwen-14b:free", 64000),
        ("Qwen: Qwen2.5-VL 7B Instruct", "qwen/qwen-2.5-vl-7b-instruct:free", 64000),
    ]},
    
    # 32K-64K Context Models
    {"category": "32K-64K Context", "models": [
        ("Google: LearnLM 1.5 Pro Experimental", "google/learnlm-1.5-pro-experimental:free", 40960),
        ("Qwen: QwQ 32B", "qwen/qwq-32b:free", 40000),
        ("Google: Gemini 2.0 Flash Thinking Experimental", "google/gemini-2.0-flash-thinking-exp-1219:free", 40000),
        ("Bytedance: UI-TARS 72B", "bytedance-research/ui-tars-72b:free", 32768),
        ("Qwerky 72b", "featherless/qwerky-72b:free", 32768),
        ("OlympicCoder 7B", "open-r1/olympiccoder-7b:free", 32768),
        ("OlympicCoder 32B", "open-r1/olympiccoder-32b:free", 32768),
        ("Google: Gemma 3 1B", "google/gemma-3-1b-it:free", 32768),
        ("Reka: Flash 3", "rekaai/reka-flash-3:free", 32768),
        ("Dolphin3.0 R1 Mistral 24B", "cognitivecomputations/dolphin3.0-r1-mistral-24b:free", 32768),
        ("Dolphin3.0 Mistral 24B", "cognitivecomputations/dolphin3.0-mistral-24b:free", 32768),
        ("Mistral: Mistral Small 3", "mistralai/mistral-small-24b-instruct-2501:free", 32768),
        ("Qwen2.5 Coder 32B Instruct", "qwen/qwen-2.5-coder-32b-instruct:free", 32768),
        ("Qwen2.5 72B Instruct", "qwen/qwen-2.5-72b-instruct:free", 32768),
    ]},
    
    # 8K-32K Context Models
    {"category": "8K-32K Context", "models": [
        ("Meta: Llama 3.2 3B Instruct", "meta-llama/llama-3.2-3b-instruct:free", 20000),
        ("Qwen: QwQ 32B Preview", "qwen/qwq-32b-preview:free", 16384),
        ("DeepSeek: R1 Distill Qwen 32B", "deepseek/deepseek-r1-distill-qwen-32b:free", 16000),
        ("Qwen: Qwen2.5 VL 32B Instruct", "qwen/qwen2.5-vl-32b-instruct:free", 8192),
        ("Moonshot AI: Moonlight 16B A3B Instruct", "moonshotai/moonlight-16b-a3b-instruct:free", 8192),
        ("DeepSeek: R1 Distill Llama 70B", "deepseek/deepseek-r1-distill-llama-70b:free", 8192),
        ("Qwen 2 7B Instruct", "qwen/qwen-2-7b-instruct:free", 8192),
        ("Google: Gemma 2 9B", "google/gemma-2-9b-it:free", 8192),
        ("Mistral: Mistral 7B Instruct", "mistralai/mistral-7b-instruct:free", 8192),
        ("Microsoft: Phi-3 Mini 128K Instruct", "microsoft/phi-3-mini-128k-instruct:free", 8192),
        ("Microsoft: Phi-3 Medium 128K Instruct", "microsoft/phi-3-medium-128k-instruct:free", 8192),
        ("Meta: Llama 3 8B Instruct", "meta-llama/llama-3-8b-instruct:free", 8192),
        ("OpenChat 3.5 7B", "openchat/openchat-7b:free", 8192),
        ("Meta: Llama 3.3 70B Instruct", "meta-llama/llama-3.3-70b-instruct:free", 8000),
    ]},
    
    # <8K Context Models
    {"category": "4K Context", "models": [
        ("AllenAI: Molmo 7B D", "allenai/molmo-7b-d:free", 4096),
        ("Rogue Rose 103B v0.2", "sophosympatheia/rogue-rose-103b-v0.2:free", 4096),
        ("Toppy M 7B", "undi95/toppy-m-7b:free", 4096),
        ("Hugging Face: Zephyr 7B", "huggingfaceh4/zephyr-7b-beta:free", 4096),
        ("MythoMax 13B", "gryphe/mythomax-l2-13b:free", 4096),
    ]},
    # Vision-capable Models
    {"category": "Vision Models", "models": [
        ("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
        ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
        ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
        ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
        ("Google: Gemini Flash 1.5 8B Experimental", "google/gemini-flash-1.5-8b-exp", 1000000),
        ("Google: Gemma 3 4B", "google/gemma-3-4b-it:free", 131072),
        ("Google: Gemma 3 12B", "google/gemma-3-12b-it:free", 131072),
        ("Qwen: Qwen2.5 VL 72B Instruct", "qwen/qwen2.5-vl-72b-instruct:free", 131072),
        ("Meta: Llama 3.2 11B Vision Instruct", "meta-llama/llama-3.2-11b-vision-instruct:free", 131072),
        ("Mistral: Mistral Small 3.1 24B", "mistralai/mistral-small-3.1-24b-instruct:free", 96000),
        ("Google: Gemma 3 27B", "google/gemma-3-27b-it:free", 96000),
        ("Qwen: Qwen2.5 VL 3B Instruct", "qwen/qwen2.5-vl-3b-instruct:free", 64000),
        ("Qwen: Qwen2.5-VL 7B Instruct", "qwen/qwen-2.5-vl-7b-instruct:free", 64000),
        ("Google: LearnLM 1.5 Pro Experimental", "google/learnlm-1.5-pro-experimental:free", 40960),
        ("Google: Gemini 2.0 Flash Thinking Experimental", "google/gemini-2.0-flash-thinking-exp-1219:free", 40000),
        ("Bytedance: UI-TARS 72B", "bytedance-research/ui-tars-72b:free", 32768),
        ("Google: Gemma 3 1B", "google/gemma-3-1b-it:free", 32768),
        ("Qwen: Qwen2.5 VL 32B Instruct", "qwen/qwen2.5-vl-32b-instruct:free", 8192),
        ("AllenAI: Molmo 7B D", "allenai/molmo-7b-d:free", 4096),
    ]},
]
# Flatten OpenRouter model list for easier access
OPENROUTER_ALL_MODELS = []
for category in OPENROUTER_MODELS:
    for model in category["models"]:
        if model not in OPENROUTER_ALL_MODELS:  # Avoid duplicates
            OPENROUTER_ALL_MODELS.append(model)
# VISION MODELS - For tracking which models support images
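# Maps provider name -> list of model names that accept image input (consulted by is_vision_model)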
VISION_MODELS = {
    "OpenRouter": [model[0] for model in OPENROUTER_MODELS[-1]["models"]],  # Last category is Vision Models
    "OpenAI": [
        "gpt-4-vision-preview", "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", 
        "gpt-4-turbo-preview", "gpt-4-0125-preview", "gpt-4-1106-preview",
        "o1-preview", "o1-mini"
    ],
    "HuggingFace": [
        "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/qwen2.5-vl-3b-instruct", 
        "Qwen/qwen2.5-vl-32b-instruct", "Qwen/qwen2.5-vl-72b-instruct"
    ],
    "Groq": ["llama-3.2-11b-vision", "llama-3.2-90b-vision"],
    "Together": ["Llama-3.2-11B-Vision-Instruct", "Llama-3.2-90B-Vision-Instruct"],
    #"OVH": ["llava-next-mistral-7b", "qwen2.5-vl-72b-instruct"],
    #"Cerebras": [],
    "GoogleAI": ["gemini-1.5-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-2.0-pro", "gemini-2.5-pro"]
}
# POE MODELS
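# Keys are Poe's internal bot handles (the inline comments give the public bot names);
# values are approximate context window sizes in tokens.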
POE_MODELS = {
    "claude_3_igloo": 4000,               # Claude-3.5-Sonnet
    "claude_2_1_cedar": 4000,             # Claude-3-Opus
    "claude_2_1_bamboo": 4000,            # Claude-3-Sonnet
    "claude_3_haiku": 4000,               # Claude-3-Haiku
    "claude_3_igloo_200k": 200000,        # Claude-3.5-Sonnet-200k
    "claude_3_opus_200k": 200000,         # Claude-3-Opus-200k
    "claude_3_sonnet_200k": 200000,       # Claude-3-Sonnet-200k
    "claude_3_haiku_200k": 200000,        # Claude-3-Haiku-200k
    "claude_2_short": 4000,               # Claude-2
    "a2_2": 100000,                       # Claude-2-100k
    "a2": 9000,                           # Claude-instant
    "a2_100k": 100000,                    # Claude-instant-100k
    "chinchilla": 4000,                   # GPT-3.5-Turbo
    "gpt3_5": 2000,                       # GPT-3.5-Turbo-Raw
    "chinchilla_instruct": 2000,          # GPT-3.5-Turbo-Instruct
    "agouti": 16000,                      # ChatGPT-16k
    "gpt4_classic": 2000,                 # GPT-4-Classic
    "beaver": 4000,                       # GPT-4-Turbo
    "vizcacha": 128000,                   # GPT-4-Turbo-128k
    "gpt4_o": 4000,                       # GPT-4o
    "gpt4_o_128k": 128000,                # GPT-4o-128k
    "gpt4_o_mini": 4000,                  # GPT-4o-Mini
    "gpt4_o_mini_128k": 128000,           # GPT-4o-Mini-128k
    "acouchy": 8000,                      # Google-PaLM
    "code_llama_13b_instruct": 4000,      # Code-Llama-13b
    "code_llama_34b_instruct": 4000,      # Code-Llama-34b
    "upstage_solar_0_70b_16bit": 2000,    # Solar-Mini
    "gemini_pro_search": 4000,            # Gemini-1.5-Flash-Search
    "gemini_1_5_pro_1m": 2000000,         # Gemini-1.5-Pro-2M
}
# Add vision-capable models to vision models list
POE_VISION_MODELS = [
    "claude_3_igloo", "claude_2_1_cedar", "claude_2_1_bamboo", "claude_3_haiku",
    "claude_3_igloo_200k", "claude_3_opus_200k", "claude_3_sonnet_200k", "claude_3_haiku_200k",
    "gpt4_o", "gpt4_o_128k", "gpt4_o_mini", "gpt4_o_mini_128k", "beaver", "vizcacha"
]
VISION_MODELS["Poe"] = POE_VISION_MODELS
# OPENAI MODELS
OPENAI_MODELS = {
    "gpt-3.5-turbo": 16385,
    "gpt-3.5-turbo-0125": 16385,
    "gpt-3.5-turbo-1106": 16385,
    "gpt-3.5-turbo-instruct": 4096,
    "gpt-4": 8192,
    "gpt-4-0314": 8192,
    "gpt-4-0613": 8192,
    "gpt-4-turbo": 128000,
    "gpt-4-turbo-2024-04-09": 128000,
    "gpt-4-turbo-preview": 128000,
    "gpt-4-0125-preview": 128000,
    "gpt-4-1106-preview": 128000,
    "gpt-4o": 128000,
    "gpt-4o-2024-11-20": 128000,
    "gpt-4o-2024-08-06": 128000,
    "gpt-4o-2024-05-13": 128000,
    "gpt-4o-mini": 128000,
    "gpt-4o-mini-2024-07-18": 128000,
    "o1-preview": 128000,
    "o1-preview-2024-09-12": 128000,
    "o1-mini": 128000,
    "o1-mini-2024-09-12": 128000,
}
# HUGGINGFACE MODELS 
HUGGINGFACE_MODELS = {
    "microsoft/phi-3-mini-4k-instruct": 4096,
    "microsoft/Phi-3-mini-128k-instruct": 131072,
    "HuggingFaceH4/zephyr-7b-beta": 8192,
    "deepseek-ai/DeepSeek-Coder-V2-Instruct": 8192,
    "mistralai/Mistral-7B-Instruct-v0.3": 32768,
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
    "microsoft/Phi-3.5-mini-instruct": 4096,
    "google/gemma-2-2b-it": 2048,
    "openai-community/gpt2": 1024,
    "microsoft/phi-2": 2048,
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": 2048,
    "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct": 2048,
    "VAGOsolutions/Llama-3.1-SauerkrautLM-8b-Instruct": 4096,
    "VAGOsolutions/SauerkrautLM-Nemo-12b-Instruct": 4096,
    "openGPT-X/Teuken-7B-instruct-research-v0.4": 4096,
    "Qwen/Qwen2.5-7B-Instruct": 131072,
    "tiiuae/falcon-7b-instruct": 8192,
    "Qwen/QwQ-32B-preview": 32768,
    "Qwen/Qwen2.5-VL-7B-Instruct": 64000,
    "Qwen/qwen2.5-vl-3b-instruct": 64000,
    "Qwen/qwen2.5-vl-32b-instruct": 8192,
    "Qwen/qwen2.5-vl-72b-instruct": 131072,
}
# GROQ MODELS - We'll populate this dynamically
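# Fallback list used when the Groq API cannot be queried (see fetch_groq_models below)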
DEFAULT_GROQ_MODELS = {
    "deepseek-r1-distill-llama-70b": 8192,
    "deepseek-r1-distill-qwen-32b": 8192,
    "gemma2-9b-it": 8192,
    "llama-3.1-8b-instant": 131072,
    "llama-3.2-1b-preview": 131072,
    "llama-3.2-3b-preview": 131072,
    "llama-3.2-11b-vision-preview": 131072,
    "llama-3.2-90b-vision-preview": 131072,
    "llama-3.3-70b-specdec": 131072,
    "llama-3.3-70b-versatile": 131072,
    "llama-guard-3-8b": 8192,
    "llama3-8b-8192": 8192,
    "llama3-70b-8192": 8192,
    "mistral-saba-24b": 32768,
    "qwen-2.5-32b": 32768,
    "qwen-2.5-coder-32b": 32768,
    "qwen-qwq-32b": 32768,
    "playai-tts": 4096,       # Including TTS models but setting reasonable context limits
    "playai-tts-arabic": 4096,
    "distil-whisper-large-v3-en": 4096,
    "whisper-large-v3": 4096,
    "whisper-large-v3-turbo": 4096
}
# COHERE MODELS
COHERE_MODELS = {
    "command-r-plus-08-2024": 131072,
    "command-r-plus-04-2024": 131072,
    "command-r-plus": 131072,
    "command-r-08-2024": 131072,
    "command-r-03-2024": 131072,
    "command-r": 131072,
    "command": 4096,
    "command-nightly": 131072,
    "command-light": 4096,
    "command-light-nightly": 4096,
    "c4ai-aya-expanse-8b": 8192,
    "c4ai-aya-expanse-32b": 131072,
}
# TOGETHER MODELS in the free tier
TOGETHER_MODELS = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 131072,
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free": 8192,
    "meta-llama/Llama-Vision-Free": 8192,
    "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free": 8192,
}
# Restrict the Together vision list to the free-tier vision model
VISION_MODELS["Together"] = ["meta-llama/Llama-Vision-Free"]
# OVH MODELS - OVH AI Endpoints (free beta)
OVH_MODELS = {
    "ovh/codestral-mamba-7b-v0.1": 131072,
    "ovh/deepseek-r1-distill-llama-70b": 8192,
    "ovh/llama-3.1-70b-instruct": 131072,
    "ovh/llama-3.1-8b-instruct": 131072,
    "ovh/llama-3.3-70b-instruct": 131072,
    "ovh/llava-next-mistral-7b": 8192,
    "ovh/mistral-7b-instruct-v0.3": 32768,
    "ovh/mistral-nemo-2407": 131072,
    "ovh/mixtral-8x7b-instruct": 32768,
    "ovh/qwen2.5-coder-32b-instruct": 32768,
    "ovh/qwen2.5-vl-72b-instruct": 131072,
}
# CEREBRAS MODELS
CEREBRAS_MODELS = {
    "llama3.1-8b": 8192,
    "llama-3.3-70b": 8192,
}
# GOOGLE AI MODELS
GOOGLEAI_MODELS = {
    "gemini-1.0-pro": 32768,
    "gemini-1.5-flash": 1000000,
    "gemini-1.5-pro": 1000000,
    "gemini-2.0-pro": 2000000,
    "gemini-2.5-pro": 2000000,
}
# ANTHROPIC MODELS
ANTHROPIC_MODELS = {
    "claude-3-7-sonnet-20250219": 128000,  # Claude 3.7 Sonnet
    "claude-3-5-sonnet-20241022": 200000,  # Claude 3.5 Sonnet
    "claude-3-5-haiku-20240307": 200000,   # Claude 3.5 Haiku 
    "claude-3-5-sonnet-20240620": 200000,  # Claude 3.5 Sonnet 2024-06-20
    "claude-3-opus-20240229": 200000,      # Claude 3 Opus
    "claude-3-haiku-20240307": 200000,     # Claude 3 Haiku
    "claude-3-sonnet-20240229": 200000,    # Claude 3 Sonnet
}
# Add Anthropic to the vision models list
VISION_MODELS["Anthropic"] = [
    "claude-3-7-sonnet-20250219",
    "claude-3-5-sonnet-20241022", 
    "claude-3-opus-20240229", 
    "claude-3-sonnet-20240229",
    "claude-3-5-haiku-20240307",
    "claude-3-haiku-20240307"
]
# Add all HuggingFace models with "vl", "vision", "visual", or "llava" in their name to the vision list
for model_name in list(HUGGINGFACE_MODELS.keys()):
    if any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava"]):
        if model_name not in VISION_MODELS["HuggingFace"]:
            VISION_MODELS["HuggingFace"].append(model_name)
# ==========================================================
# HELPER FUNCTIONS
# ==========================================================
def fetch_groq_models():
    """Fetch available Groq models with proper error handling"""
    try:
        if not HAS_GROQ or not GROQ_API_KEY:
            logger.warning("Groq client not available or no API key. Using default model list.")
            return DEFAULT_GROQ_MODELS
        client = Groq(api_key=GROQ_API_KEY)
        models = client.models.list()
        
        # Create dictionary of model_id -> context size
        model_dict = {}
        for model in models.data:
            model_id = model.id
            # Map known context sizes or use a default
            if "llama-3" in model_id and "70b" in model_id:
                context_size = 131072
            elif "llama-3" in model_id and "8b" in model_id:
                context_size = 131072
            elif "mixtral" in model_id:
                context_size = 32768
            elif "gemma" in model_id:
                context_size = 8192
            elif "vision" in model_id:
                context_size = 131072
            else:
                context_size = 8192  # Default assumption
                
            model_dict[model_id] = context_size
            
        # Ensure we have models by combining with defaults
        if not model_dict:
            return DEFAULT_GROQ_MODELS
        return {**DEFAULT_GROQ_MODELS, **model_dict}
        
    except Exception as e:
        logger.error(f"Error fetching Groq models: {e}")
        return DEFAULT_GROQ_MODELS
# Initialize Groq models
GROQ_MODELS = fetch_groq_models()
def encode_image_to_base64(image_path):
    """Encode an image file to a base64 data URL"""
    try:
        # Resolve the underlying file path (plain string or Gradio file object)
        if isinstance(image_path, str):
            path = image_path
        elif hasattr(image_path, 'name'):
            path = image_path.name
        else:
            logger.error(f"Unsupported image type: {type(image_path)}")
            return None

        with open(path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

        # Derive the MIME type from the file extension
        file_extension = path.split('.')[-1].lower()
        if file_extension in ("jpg", "jpeg"):
            mime_type = "image/jpeg"
        elif file_extension == "png":
            mime_type = "image/png"
        elif file_extension == "webp":
            mime_type = "image/webp"
        else:
            mime_type = f"image/{file_extension}"
        return f"data:{mime_type};base64,{encoded_string}"
    except Exception as e:
        logger.error(f"Error encoding image: {str(e)}")
        return None
def extract_text_from_file(file_path):
    """Extract text from various file types"""
    try:
        file_extension = file_path.split('.')[-1].lower()
        
        if file_extension == 'pdf':
            if HAS_PYPDF2:
                text = ""
                with open(file_path, 'rb') as file:
                    pdf_reader = PyPDF2.PdfReader(file)
                    for page_num in range(len(pdf_reader.pages)):
                        page = pdf_reader.pages[page_num]
                        text += page.extract_text() + "\n\n"
                return text
            else:
                return "PDF processing is not available (PyPDF2 not installed)"
        
        elif file_extension == 'md':
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        
        elif file_extension == 'txt':
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
                
        else:
            return f"Unsupported file type: {file_extension}"
            
    except Exception as e:
        logger.error(f"Error extracting text from file: {str(e)}")
        return f"Error processing file: {str(e)}"
def prepare_message_with_media(text, images=None, documents=None):
    """Prepare a message with text, images, and document content"""
    # If no media, return text only
    if not images and not documents:
        return text
    
    # Start with text content
    if documents and len(documents) > 0:
        # If there are documents, append their content to the text
        document_texts = []
        for doc in documents:
            if doc is None:
                continue
            # Make sure to handle file objects properly
            doc_path = doc.name if hasattr(doc, 'name') else doc
            doc_text = extract_text_from_file(doc_path)
            if doc_text:
                document_texts.append(doc_text)
        
        # Add document content to text
        if document_texts:
            if not text:
                text = "Please analyze these documents:"
            else:
                text = f"{text}\n\nDocument content:"
            text += "\n\n" + "\n\n".join(document_texts)
            
        # If no images, return text only
        if not images:
            return text
    
    # If we have images, create a multimodal content array
    content = [{"type": "text", "text": text or "Analyze this image:"}]
    
    # Add images if any
    if images:
        # Check if images is a list of image paths or file objects
        if isinstance(images, list):
            for img in images:
                if img is None:
                    continue
                
                encoded_image = encode_image_to_base64(img)
                if encoded_image:
                    content.append({
                        "type": "image_url",
                        "image_url": {"url": encoded_image}
                    })
        else:
            # For single image or Gallery component
            logger.warning(f"Images is not a list: {type(images)}")
            # Try to handle as single image
            encoded_image = encode_image_to_base64(images)
            if encoded_image:
                content.append({
                    "type": "image_url", 
                    "image_url": {"url": encoded_image}
                })
    
    return content
def format_to_message_dict(history):
    """Convert history to proper message format"""
    messages = []
    for item in history:
        if isinstance(item, dict) and "role" in item and "content" in item:
            # Already in the correct format
            messages.append(item)
        elif isinstance(item, list) and len(item) == 2:
            # Convert from old format [user_msg, ai_msg]
            human, ai = item
            if human:
                messages.append({"role": "user", "content": human})
            if ai:
                messages.append({"role": "assistant", "content": ai})
    return messages
def process_uploaded_images(files):
    """Process uploaded image files"""
    file_paths = []
    for file in files:
        if hasattr(file, 'name'):
            file_paths.append(file.name)
    return file_paths
def filter_models(provider, search_term):
    """Filter models based on search term and provider"""
    if provider == "OpenRouter":
        all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
    elif provider == "OpenAI":
        all_models = list(OPENAI_MODELS.keys())
    elif provider == "HuggingFace":
        all_models = list(HUGGINGFACE_MODELS.keys())
    elif provider == "Groq":
        all_models = list(GROQ_MODELS.keys())
    elif provider == "Cohere":
        all_models = list(COHERE_MODELS.keys())
    elif provider == "Together":
        all_models = list(TOGETHER_MODELS.keys())
    elif provider == "OVH":
        all_models = list(OVH_MODELS.keys())
    elif provider == "Cerebras":
        all_models = list(CEREBRAS_MODELS.keys())
    elif provider == "GoogleAI":
        all_models = list(GOOGLEAI_MODELS.keys())
    else:
        return [], None
        
    if not search_term:
        return all_models, all_models[0] if all_models else None
        
    filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
    
    if filtered_models:
        return filtered_models, filtered_models[0]
    else:
        return all_models, all_models[0] if all_models else None
def get_model_info(provider, model_choice):
    """Get model ID and context size based on provider and model name"""
    if provider == "OpenRouter":
        for name, model_id, ctx_size in OPENROUTER_ALL_MODELS:
            if name == model_choice:
                return model_id, ctx_size
    elif provider == "Poe":
        if model_choice in POE_MODELS:
            return model_choice, POE_MODELS[model_choice]
    elif provider == "OpenAI":
        if model_choice in OPENAI_MODELS:
            return model_choice, OPENAI_MODELS[model_choice]
    elif provider == "HuggingFace":
        if model_choice in HUGGINGFACE_MODELS:
            return model_choice, HUGGINGFACE_MODELS[model_choice]
    elif provider == "Groq":
        if model_choice in GROQ_MODELS:
            return model_choice, GROQ_MODELS[model_choice]
    elif provider == "Cohere":
        if model_choice in COHERE_MODELS:
            return model_choice, COHERE_MODELS[model_choice]
    elif provider == "Together":
        if model_choice in TOGETHER_MODELS:
            return model_choice, TOGETHER_MODELS[model_choice]
    elif provider == "Anthropic":
        if model_choice in ANTHROPIC_MODELS:
            return model_choice, ANTHROPIC_MODELS[model_choice]
    elif provider == "GoogleAI":
        if model_choice in GOOGLEAI_MODELS:
            return model_choice, GOOGLEAI_MODELS[model_choice]
    
    return None, 0
def update_context_display(provider, model_name):
    """Update context size display for the selected model"""
    _, ctx_size = get_model_info(provider, model_name)
    return f"{ctx_size:,}" if ctx_size else "Unknown"
def is_vision_model(provider, model_name):
    """Check if a model supports vision/images"""
    # Safety check for None model name
    if model_name is None:
        return False
        
    if provider in VISION_MODELS:
        if model_name in VISION_MODELS[provider]:
            return True
            
        # Also check for common vision indicators in model names
        try:
            if any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava", "gemini"]):
                return True
        except AttributeError:
            # In case model_name is not a string or has no lower method
            return False
    
    return False
def update_model_info(provider, model_name):
    """Generate HTML info display for the selected model"""
    model_id, ctx_size = get_model_info(provider, model_name)
    if not model_id:
        return "<p>Model information not available</p>"
        
    # Check if this is a vision model
    is_vision = is_vision_model(provider, model_name)
    vision_badge = '<span class="vision-badge">Vision</span>' if is_vision else ''
    
    # Only OpenRouter uses a model ID distinct from the display name
    model_id_html = f"<p>Model ID: <code>{model_id}</code></p>" if provider == "OpenRouter" else ""
    
    return f"""
    <div class="model-info">
        <h3>{model_name} {vision_badge}</h3>
        {model_id_html}
        <p>Context Size: {ctx_size:,} tokens</p>
        <p>Provider: {provider}</p>
        {'<p>Features: Supports image understanding</p>' if is_vision else ''}
    </div>
    """
# ==========================================================
# API HANDLERS
# ==========================================================
def call_anthropic_api(payload, api_key_override=None):
    """Make a call to Anthropic API with error handling"""
    try:
        # Try to import Anthropic
        try:
            import anthropic
            from anthropic import Anthropic
        except ImportError:
            raise ImportError("Anthropic package not installed. Install it with: pip install anthropic")
            
        api_key = api_key_override if api_key_override else os.environ.get("ANTHROPIC_API_KEY", "")
        if not api_key:
            raise ValueError("Anthropic API key is required")
            
        client = Anthropic(api_key=api_key)
        
        # Extract parameters from payload
        model = payload.get("model", "claude-3-5-sonnet-20241022")
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
        
        # Format messages for Anthropic
        # Find system message if any
        system_prompt = None
        chat_messages = []
        
        for msg in messages:
            if msg["role"] == "system":
                system_prompt = msg["content"]
            else:
                # Format content
                if isinstance(msg["content"], list):
                    # Handle multimodal content (images)
                    anthropic_content = []
                    for item in msg["content"]:
                        if item["type"] == "text":
                            anthropic_content.append({
                                "type": "text",
                                "text": item["text"]
                            })
                        elif item["type"] == "image_url":
                            # Extract base64 from data URL if present
                            image_url = item["image_url"]["url"]
                            if image_url.startswith("data:"):
                                # Extract media type and base64 data
                                parts = image_url.split(",", 1)
                                media_type = parts[0].split(":")[1].split(";")[0]
                                base64_data = parts[1]
                                
                                anthropic_content.append({
                                    "type": "image",
                                    "source": {
                                        "type": "base64",
                                        "media_type": media_type,
                                        "data": base64_data
                                    }
                                })
                            else:
                                # URL not supported by Anthropic yet
                                anthropic_content.append({
                                    "type": "text",
                                    "text": f"[Image URL: {image_url}]"
                                })
                    chat_messages.append({
                        "role": msg["role"],
                        "content": anthropic_content
                    })
                else:
                    # Simple text content
                    chat_messages.append({
                        "role": msg["role"],
                        "content": msg["content"]
                    })
        
        # Make request to Anthropic (only pass `system` when one was provided)
        request_kwargs = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": chat_messages,
        }
        if system_prompt:
            request_kwargs["system"] = system_prompt
        response = client.messages.create(**request_kwargs)
        
        return response
    except Exception as e:
        logger.error(f"Anthropic API error: {str(e)}")
        raise e
def call_poe_api(payload, api_key_override=None):
    """Make a call to Poe API with error handling"""
    try:
        # Try to import fastapi_poe
        try:
            import fastapi_poe as fp
        except ImportError:
            raise ImportError("fastapi_poe package not installed. Install it with: pip install fastapi_poe")
            
        api_key = api_key_override if api_key_override else os.environ.get("POE_API_KEY", "")
        if not api_key:
            raise ValueError("Poe API key is required")
            
        # Extract parameters from payload
        model = payload.get("model", "chinchilla")  # Default to GPT-3.5-Turbo
        messages = payload.get("messages", [])
        
        # Convert messages to Poe format
        poe_messages = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            
            # Skip system messages as Poe doesn't support them directly
            if role == "system":
                continue
                
            # Convert content format
            if isinstance(content, list):
                # Handle multimodal content (images)
                text_parts = []
                for item in content:
                    if item["type"] == "text":
                        text_parts.append(item["text"])
                
                # Simplified handling: image parts are dropped and only the text is forwarded
                content = "\n".join(text_parts)
                
            # Poe's protocol uses "bot" rather than "assistant" for model turns
            if role == "assistant":
                role = "bot"
            poe_messages.append(fp.ProtocolMessage(role=role, content=content))
        
        # Make synchronous request to Poe
        response_content = ""
        for partial in fp.get_bot_response_sync(messages=poe_messages, bot_name=model, api_key=api_key):
            if hasattr(partial, "text"):
                response_content += partial.text
        
        # Create a response object with a structure similar to other APIs
        response = {
            "id": f"poe-{int(time.time())}",
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": response_content
                    }
                }
            ]
        }
        
        return response
    except Exception as e:
        logger.error(f"Poe API error: {str(e)}")
        raise e
def call_openrouter_api(payload, api_key_override=None):
    """Make a call to OpenRouter API with error handling"""
    try:
        api_key = api_key_override if api_key_override else OPENROUTER_API_KEY
        if not api_key:
            raise ValueError("OpenRouter API key is required")
            
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
                "HTTP-Referer": "https://huggingface.co/spaces/cstr/CrispChat"
            },
            json=payload,
            timeout=180  # Longer timeout for document processing
        )
        return response
    except requests.RequestException as e:
        logger.error(f"OpenRouter API request error: {str(e)}")
        raise e
def call_openai_api(payload, api_key_override=None):
    """Make a call to OpenAI API with error handling"""
    try:
        if not HAS_OPENAI:
            raise ImportError("OpenAI package not installed")
            
        api_key = api_key_override if api_key_override else OPENAI_API_KEY
        if not api_key:
            raise ValueError("OpenAI API key is required")
            
        client = openai.OpenAI(api_key=api_key)
        
        # Extract parameters from payload
        model = payload.get("model", "gpt-3.5-turbo")
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
        stream = payload.get("stream", False)
        top_p = payload.get("top_p", 0.9)
        presence_penalty = payload.get("presence_penalty", 0)
        frequency_penalty = payload.get("frequency_penalty", 0)
        
        # Handle response format if specified
        response_format = None
        if payload.get("response_format") == "json_object":
            response_format = {"type": "json_object"}
            
        # Create completion
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
            top_p=top_p,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            response_format=response_format
        )
        
        return response
    except Exception as e:
        logger.error(f"OpenAI API error: {str(e)}")
        raise e
def call_huggingface_api(payload, api_key_override=None):
    """Make a call to HuggingFace API with error handling"""
    try:
        if not HAS_HF:
            raise ImportError("HuggingFace hub not installed")
            
        api_key = api_key_override if api_key_override else HF_API_KEY
        
        # Extract parameters from payload
        model_id = payload.get("model", "mistralai/Mistral-7B-Instruct-v0.3")
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 500)
        
        # Create a prompt from messages
        prompt = ""
        for msg in messages:
            role = msg["role"].upper()
            content = msg["content"]
            
            # Handle multimodal content
            if isinstance(content, list):
                text_parts = []
                for item in content:
                    if item["type"] == "text":
                        text_parts.append(item["text"])
                content = "\n".join(text_parts)
                
            prompt += f"{role}: {content}\n"
            
        prompt += "ASSISTANT: "
        
        # Create client with or without API key
        client = InferenceClient(token=api_key) if api_key else InferenceClient()
        
        # Generate response
        response = client.text_generation(
            prompt,
            model=model_id,
            max_new_tokens=max_tokens,
            temperature=temperature,
            repetition_penalty=1.1
        )
        
        return {"generated_text": str(response)}
    except Exception as e:
        logger.error(f"HuggingFace API error: {str(e)}")
        raise e
def call_groq_api(payload, api_key_override=None):
    """Make a call to Groq API with error handling"""
    try:
        if not HAS_GROQ:
            raise ImportError("Groq client not installed")
            
        api_key = api_key_override if api_key_override else GROQ_API_KEY
        if not api_key:
            raise ValueError("Groq API key is required")
            
        client = Groq(api_key=api_key)
        
        # Extract parameters from payload
        model = payload.get("model", "llama-3.1-8b-instant")
        
        # Clean up messages - remove any unexpected properties
        messages = []
        for msg in payload.get("messages", []):
            clean_msg = {
                "role": msg["role"],
                "content": msg["content"]
            }
            messages.append(clean_msg)
        
        # Basic parameters
        groq_payload = {
            "model": model,
            "messages": messages,
            "temperature": payload.get("temperature", 0.7),
            "max_tokens": payload.get("max_tokens", 1000),
            "stream": payload.get("stream", False),
            "top_p": payload.get("top_p", 0.9)
        }
        
        # Create completion
        response = client.chat.completions.create(**groq_payload)
        
        return response
    except Exception as e:
        logger.error(f"Groq API error: {str(e)}")
        raise e
def call_cohere_api(payload, api_key_override=None):
    """Make a call to Cohere API with error handling"""
    try:
        if not HAS_COHERE:
            raise ImportError("Cohere package not installed")
            
        api_key = api_key_override if api_key_override else COHERE_API_KEY
        if not api_key:
            raise ValueError("Cohere API key is required")
            
        client = cohere.ClientV2(api_key=api_key)
        
        # Extract parameters from payload
        model = payload.get("model", "command-r-plus")
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
        
        # Create chat completion
        response = client.chat(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens
        )
        
        return response
    except Exception as e:
        logger.error(f"Cohere API error: {str(e)}")
        raise e
def extract_ai_response(result, provider):
    """Extract AI response based on provider format"""
    try:
        if provider == "OpenRouter":
            if isinstance(result, dict):
                if "choices" in result and len(result["choices"]) > 0:
                    if "message" in result["choices"][0]:
                        message = result["choices"][0]["message"]
                        if message.get("reasoning") and not message.get("content"):
                            reasoning = message.get("reasoning")
                            lines = reasoning.strip().split('\n')
                            for line in lines:
                                if line and not line.startswith('I should') and not line.startswith('Let me'):
                                    return line.strip()
                            for line in lines:
                                if line.strip():
                                    return line.strip()
                        return message.get("content", "")
                    elif "delta" in result["choices"][0]:
                        return result["choices"][0]["delta"].get("content", "")
                    
        elif provider == "OpenAI":
            if hasattr(result, "choices") and len(result.choices) > 0:
                return result.choices[0].message.content
            
        elif provider == "Anthropic":
            if hasattr(result, "content"):
                # Combine text from all content blocks
                full_text = ""
                for block in result.content:
                    if block.type == "text":
                        full_text += block.text
                return full_text
            return "No content returned from Anthropic"
                
        elif provider == "HuggingFace":
            return result.get("generated_text", "")
                
        elif provider == "Groq":
            if hasattr(result, "choices") and len(result.choices) > 0:
                return result.choices[0].message.content
                
        elif provider == "Cohere":
            # Specific handling for Cohere's response format
            if hasattr(result, "message") and hasattr(result.message, "content"):
                # Extract text from content items
                text_content = ""
                for content_item in result.message.content:
                    if hasattr(content_item, "text") and content_item.text:
                        text_content += content_item.text
                return text_content
            else:
                return "No response content from Cohere"
        elif provider == "Poe":
            if isinstance(result, dict) and "choices" in result and len(result["choices"]) > 0:
                return result["choices"][0]["message"]["content"]
            return "No response content from Poe"
        elif provider == "Together":
            # Handle response from Together's native client
            if hasattr(result, "choices") and len(result.choices) > 0:
                if hasattr(result.choices[0], "message") and hasattr(result.choices[0].message, "content"):
                    return result.choices[0].message.content
                elif hasattr(result.choices[0], "delta") and hasattr(result.choices[0].delta, "content"):
                    return result.choices[0].delta.content
            # Fallback
            return str(result)
                
        elif provider == "OVH":
            if isinstance(result, dict) and "choices" in result and len(result["choices"]) > 0:
                return result["choices"][0]["message"]["content"]
                
        elif provider == "Cerebras":
            if isinstance(result, dict) and "choices" in result and len(result["choices"]) > 0:
                return result["choices"][0]["message"]["content"]
                
        elif provider == "GoogleAI":
            if isinstance(result, dict) and "choices" in result and len(result["choices"]) > 0:
                return result["choices"][0]["message"]["content"]
            
        logger.error(f"Unexpected response structure from {provider}: {result}")
        return f"Error: Could not extract response from {provider} API result"
    except Exception as e:
        logger.error(f"Error extracting AI response: {str(e)}")
        return f"Error: {str(e)}"
def call_together_api(payload, api_key_override=None):
    """Make a call to Together API with error handling using their native client"""
    try:
        # Import Together's native client
        # Note: This might need to be installed with: pip install together
        try:
            from together import Together
        except ImportError:
            raise ImportError("The Together Python package is not installed. Please install it with: pip install together")
        
        api_key = api_key_override if api_key_override else TOGETHER_API_KEY
        if not api_key:
            raise ValueError("Together API key is required")
        
        # Create the Together client
        client = Together(api_key=api_key)
        
        # Extract parameters from payload
        requested_model = payload.get("model", "")
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
        stream = payload.get("stream", False)
        
        # Use one of the free, serverless models
        free_models = [
            "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
            "meta-llama/Llama-Vision-Free",
            "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
        ]
        
        # Default to the first free model
        model = free_models[0]
        
        # Try to match a requested model with a free model if possible
        if requested_model:
            for free_model in free_models:
                if requested_model.lower() in free_model.lower() or free_model.lower() in requested_model.lower():
                    model = free_model
                    break
        
        # Process messages for possible image content
        processed_messages = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            
            # Handle multimodal content for vision models
            if isinstance(content, list) and "vision" in model.lower():
                # Format according to Together's expected multimodal format
                parts = []
                for item in content:
                    if item["type"] == "text":
                        parts.append({"type": "text", "text": item["text"]})
                    elif item["type"] == "image_url":
                        parts.append({
                            "type": "image_url",
                            "image_url": item["image_url"]
                        })
                processed_messages.append({"role": role, "content": parts})
            else:
                # Non-vision models accept plain text only, so flatten any multimodal content
                if isinstance(content, list):
                    content = "\n".join(item["text"] for item in content if item.get("type") == "text")
                processed_messages.append({"role": role, "content": content})
        
        # Create completion with Together's client
        response = client.chat.completions.create(
            model=model,
            messages=processed_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream
        )
        
        return response
    except Exception as e:
        logger.error(f"Together API error: {str(e)}")
        raise e
def call_ovh_api(payload, api_key_override=None):
    """Make a call to OVH AI Endpoints API with error handling"""
    try:
        # Extract parameters from payload
        model = payload.get("model", "ovh/llama-3.1-8b-instruct")
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
                
        headers = {
            "Content-Type": "application/json"
        }
        # The free beta allows anonymous access; if a token is supplied, pass it
        # as a Bearer token (assumed OpenAI-compatible auth), otherwise the key is unused
        if api_key_override:
            headers["Authorization"] = f"Bearer {api_key_override}"
        
        data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        
        # Use a try-except to handle DNS resolution errors and provide a more helpful message
        try:
            # Correct endpoint URL based on documentation
            response = requests.post(
                "https://endpoints.ai.cloud.ovh.net/v1/chat/completions",  # Updated endpoint
                headers=headers,
                json=data,
                timeout=10  # Add timeout to avoid hanging
            )
            
            if response.status_code != 200:
                raise ValueError(f"OVH API returned status code {response.status_code}: {response.text}")
                
            return response.json()
        except requests.exceptions.ConnectionError as e:
            raise ValueError(f"Connection error to OVH API. This may be due to network restrictions in the environment: {str(e)}")
            
    except Exception as e:
        logger.error(f"OVH API error: {str(e)}")
        raise e
def call_cerebras_api(payload, api_key_override=None):
    """Make a call to Cerebras API with error handling"""
    try:
        # Extract parameters from payload
        requested_model = payload.get("model", "")
        
        # Map the full model name to the correct Cerebras model ID
        model_mapping = {
            "cerebras/llama-3.1-8b": "llama3.1-8b",
            "cerebras/llama-3.3-70b": "llama-3.3-70b",
            "llama-3.1-8b": "llama3.1-8b",
            "llama-3.3-70b": "llama-3.3-70b",
            "llama3.1-8b": "llama3.1-8b"
        }
        
        # Default to the 8B model
        model = "llama3.1-8b"
        
        # If the requested model matches any of our mappings, use that instead
        if requested_model in model_mapping:
            model = model_mapping[requested_model]
        elif "3.3" in requested_model or "70b" in requested_model.lower():
            model = "llama-3.3-70b"
        
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
        
        # Try-except block for network issues
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key_override or os.environ.get('CEREBRAS_API_KEY', '')}"
            }
            
            data = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            
            response = requests.post(
                "https://api.cloud.cerebras.ai/v1/chat/completions",
                headers=headers,
                json=data,
                timeout=30  # Increased timeout
            )
            
            if response.status_code != 200:
                raise ValueError(f"Cerebras API returned status code {response.status_code}: {response.text}")
                
            return response.json()
        except requests.exceptions.RequestException as e:
            # More specific error handling for network issues
            if "NameResolution" in str(e):
                raise ValueError(
                    "Unable to connect to the Cerebras API. This might be due to network "
                    "restrictions in your environment. The API requires direct internet access. "
                    "Please try a different provider or check your network settings."
                )
            else:
                raise ValueError(f"Request to Cerebras API failed: {str(e)}")
    except Exception as e:
        logger.error(f"Cerebras API error: {str(e)}")
        raise e
def call_googleai_api(payload, api_key_override=None):
    """Make a call to Google AI (Gemini) API with error handling"""
    try:
        api_key = api_key_override if api_key_override else GOOGLEAI_API_KEY
        if not api_key:
            raise ValueError("Google AI API key is required")
            
        # Call the REST endpoint directly instead of the SDK, which may not be installed
        model = payload.get("model", "gemini-1.5-pro")
        gemini_api_url = f"https://generativelanguage.googleapis.com/v1/models/{model}:generateContent"
        
        # Extract parameters from payload
        messages = payload.get("messages", [])
        temperature = payload.get("temperature", 0.7)
        max_tokens = payload.get("max_tokens", 1000)
        
        # Convert to Google's format
        content_parts = []
        
        # Add all messages
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            
            # Handle different roles
            if role == "system":
                # For system messages, we add it as part of the first user message
                continue
            elif role == "user":
                # For user messages, add as regular content
                if isinstance(content, str):
                    content_parts.append({"text": content})
                else:
                    # Multimodal content: only the text parts are forwarded
                    for item in content:
                        if item["type"] == "text":
                            content_parts.append({"text": item["text"]})
            
        # Form the request data
        data = {
            "contents": [{"parts": content_parts}],
            "generationConfig": {
                "temperature": temperature,
                "maxOutputTokens": max_tokens,
                "topP": payload.get("top_p", 0.95),
            }
        }
        
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": api_key
        }
        
        # Make the request
        response = requests.post(
            gemini_api_url,
            headers=headers,
            json=data,
            timeout=30
        )
        
        if response.status_code != 200:
            error_msg = f"Google AI API error: {response.status_code} - {response.text}"
            logger.error(error_msg)
            raise ValueError(error_msg)
            
        # Parse response and convert to standard format
        result = response.json()
        text_content = ""
        
        # Extract text from response
        if "candidates" in result and len(result["candidates"]) > 0:
            candidate = result["candidates"][0]
            if "content" in candidate and "parts" in candidate["content"]:
                for part in candidate["content"]["parts"]:
                    if "text" in part:
                        text_content += part["text"]
        
        # Create a standardized response format
        return {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": text_content
                    }
                }
            ]
        }
        
    except Exception as e:
        logger.error(f"Google AI API error: {str(e)}")
        raise e
# ==========================================================
# STREAMING HANDLERS
# ==========================================================
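# Each handler consumes a provider's streaming response and yields the chat
# history (Gradio "messages" format) with the assistant reply appended as it
# grows. OpenRouter streams Server-Sent Events: lines prefixed with "data: "
# and terminated by a "[DONE]" sentinel. The SDK-based providers (OpenAI,
# Groq, Together) yield chunk objects exposing choices[0].delta.content.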
def openrouter_streaming_handler(response, history, message):
    """Handle streaming responses from OpenRouter"""
    try:
        updated_history = history + [{"role": "user", "content": message}]
        assistant_response = ""
        
        for line in response.iter_lines():
            if not line:
                continue
                
            line = line.decode('utf-8')
            if not line.startswith('data: '):
                continue
                
            data = line[6:]
            if data.strip() == '[DONE]':
                break
                
            try:
                chunk = json.loads(data)
                if "choices" in chunk and len(chunk["choices"]) > 0:
                    delta = chunk["choices"][0].get("delta", {})
                    if "content" in delta and delta["content"]:
                        # Update the current response
                        assistant_response += delta["content"]
                        yield updated_history + [{"role": "assistant", "content": assistant_response}]
            except json.JSONDecodeError:
                logger.error(f"Failed to parse JSON from chunk: {data}")
    except Exception as e:
        logger.error(f"Error in streaming handler: {str(e)}")
        # Add error message to the current response
        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
def openai_streaming_handler(response, history, message):
    """Handle streaming responses from OpenAI"""
    try:
        updated_history = history + [{"role": "user", "content": message}]
        assistant_response = ""
        
        for chunk in response:
            if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                assistant_response += content
                yield updated_history + [{"role": "assistant", "content": assistant_response}]
    except Exception as e:
        logger.error(f"Error in OpenAI streaming handler: {str(e)}")
        # Add error message to the current response
        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
def groq_streaming_handler(response, history, message):
    """Handle streaming responses from Groq"""
    try:
        updated_history = history + [{"role": "user", "content": message}]
        assistant_response = ""
        
        for chunk in response:
            if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                assistant_response += content
                yield updated_history + [{"role": "assistant", "content": assistant_response}]
    except Exception as e:
        logger.error(f"Error in Groq streaming handler: {str(e)}")
        # Add error message to the current response
        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
def together_streaming_handler(response, history, message):
    """Handle streaming responses from Together"""
    try:
        updated_history = history + [{"role": "user", "content": message}]
        assistant_response = ""
        
        for chunk in response:
            if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                assistant_response += content
                yield updated_history + [{"role": "assistant", "content": assistant_response}]
    except Exception as e:
        logger.error(f"Error in Together streaming handler: {str(e)}")
        # Add error message to the current response
        yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
        
# ==========================================================
# MAIN FUNCTION TO ASK AI
# ==========================================================
def ask_ai(message, history, provider, model_choice, temperature, max_tokens, top_p, 
           frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p, 
           seed, top_a, stream_output, response_format, images, documents, 
           reasoning_effort, system_message, transforms, api_key_override=None):
    """Enhanced AI query function with support for multiple providers"""
    # Validate input
    if not message.strip() and not images and not documents:
        return history
    
    # Create messages from chat history for API requests
    messages = format_to_message_dict(history)
    
    # Add system message if provided
    if system_message and system_message.strip():
        # Remove any existing system message
        messages = [msg for msg in messages if msg.get("role") != "system"]
        # Add new system message at the beginning
        messages.insert(0, {"role": "system", "content": system_message.strip()})
    
    # Prepare message with images and documents if any
    content = prepare_message_with_media(message, images, documents)
    
    # Add current message to API messages
    messages.append({"role": "user", "content": content})
    
    # Common parameters for all providers
    common_params = {
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": stream_output
    }
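    # Providers that only accept a subset of these parameters build narrower
    # payloads in their branches below instead of spreading common_params.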
    
    try:
        # Process based on provider
        if provider == "OpenRouter":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in OpenRouter"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build OpenRouter payload
            payload = {
                "model": model_id,
                "messages": messages,
                **common_params
            }
            
            # Add optional parameters if set
            if repetition_penalty != 1.0:
                payload["repetition_penalty"] = repetition_penalty
            
            if top_k > 0:
                payload["top_k"] = top_k
            
            if min_p > 0:
                payload["min_p"] = min_p
            
            if seed > 0:
                payload["seed"] = seed
            
            if top_a > 0:
                payload["top_a"] = top_a
            
            # Add response format if JSON is requested
            if response_format == "json_object":
                payload["response_format"] = {"type": "json_object"}
            
            # Add reasoning if selected
            if reasoning_effort != "none":
                payload["reasoning"] = {
                    "effort": reasoning_effort
                }
            
            # Add transforms if selected
            if transforms:
                payload["transforms"] = transforms
                
            # Call OpenRouter API
            logger.info(f"Sending request to OpenRouter model: {model_id}")
            
            response = call_openrouter_api(payload, api_key_override)
            
            # Handle streaming response
            if stream_output and response.status_code == 200:
                # Delegate to the shared OpenRouter streaming handler defined above
                return openrouter_streaming_handler(response, history, message)
            
            # Handle normal response
            elif response.status_code == 200:
                result = response.json()
                logger.info(f"Response content: {result}")
                
                # Extract AI response
                ai_response = extract_ai_response(result, provider)
                
                # Add response to history with proper format
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            
            # Handle error response
            else:
                error_message = f"Error: Status code {response.status_code}"
                try:
                    response_data = response.json()
                    error_message += f"\n\nDetails: {json.dumps(response_data, indent=2)}"
                except ValueError:  # response body was not valid JSON
                    error_message += f"\n\nResponse: {response.text}"
                
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
        elif provider == "Poe":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in Poe"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build Poe payload
            payload = {
                "model": model_id,
                "messages": messages
                # Poe doesn't support most parameters directly
            }
            
            # Call Poe API
            logger.info(f"Sending request to Poe model: {model_id}")
            
            try:
                response = call_poe_api(payload, api_key_override)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"Poe API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
        elif provider == "Anthropic":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in Anthropic"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build Anthropic payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            
            # Call Anthropic API
            logger.info(f"Sending request to Anthropic model: {model_id}")
            
            try:
                response = call_anthropic_api(payload, api_key_override)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"Anthropic API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "OpenAI":
            # Process OpenAI similarly to the providers above
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in OpenAI"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build OpenAI payload
            payload = {
                "model": model_id,
                "messages": messages,
                **common_params
            }
            
            # Add response format if JSON is requested
            if response_format == "json_object":
                payload["response_format"] = {"type": "json_object"}
            
            # Call OpenAI API
            logger.info(f"Sending request to OpenAI model: {model_id}")
            
            try:
                response = call_openai_api(payload, api_key_override)
                
                # Handle streaming response
                if stream_output:
                    # Delegate to the shared OpenAI streaming handler defined above
                    return openai_streaming_handler(response, history, message)
                
                # Handle normal response
                else:
                    ai_response = extract_ai_response(response, provider)
                    return history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": ai_response}
                    ]
            except Exception as e:
                error_message = f"OpenAI API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "HuggingFace":
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in HuggingFace"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build HuggingFace payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            
            # Call HuggingFace API
            logger.info(f"Sending request to HuggingFace model: {model_id}")
            
            try:
                response = call_huggingface_api(payload, api_key_override)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"HuggingFace API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "Groq":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in Groq"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build Groq payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "stream": stream_output
            }
            
            # Call Groq API
            logger.info(f"Sending request to Groq model: {model_id}")
            
            try:
                response = call_groq_api(payload, api_key_override)
                
                # Handle streaming response
                if stream_output:
                    # Delegate to the shared Groq streaming handler defined above
                    return groq_streaming_handler(response, history, message)
                
                # Handle normal response
                else:
                    ai_response = extract_ai_response(response, provider)
                    return history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": ai_response}
                    ]
            except Exception as e:
                error_message = f"Groq API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "Cohere":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in Cohere"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
            # Build Cohere payload (doesn't support streaming the same way)
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            
            # Call Cohere API
            logger.info(f"Sending request to Cohere model: {model_id}")
            
            try:
                response = call_cohere_api(payload, api_key_override)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"Cohere API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "Together":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in Together"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
            # Build Together payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream_output
            }
            
            # Call Together API
            logger.info(f"Sending request to Together model: {model_id}")
            
            try:
                response = call_together_api(payload, api_key_override)
                
                # Handle streaming response
                if stream_output:
                    # Delegate to the shared Together streaming handler defined above
                    return together_streaming_handler(response, history, message)
                
                # Handle normal response
                else:
                    ai_response = extract_ai_response(response, provider)
                    return history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": ai_response}
                    ]
            except Exception as e:
                error_message = f"Together API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "OVH":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in OVH"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
            # Build OVH payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            
            # Call OVH API
            logger.info(f"Sending request to OVH model: {model_id}")
            
            try:
                response = call_ovh_api(payload)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"OVH API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "Cerebras":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in Cerebras"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
            # Build Cerebras payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            
            # Call Cerebras API
            logger.info(f"Sending request to Cerebras model: {model_id}")
            
            try:
                response = call_cerebras_api(payload)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"Cerebras API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
                
        elif provider == "GoogleAI":
            # Get model ID from registry
            model_id, _ = get_model_info(provider, model_choice)
            if not model_id:
                error_message = f"Error: Model '{model_choice}' not found in GoogleAI"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
            # Build GoogleAI payload
            payload = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p
            }
            
            # Call GoogleAI API
            logger.info(f"Sending request to GoogleAI model: {model_id}")
            
            try:
                response = call_googleai_api(payload, api_key_override)
                
                # Extract response
                ai_response = extract_ai_response(response, provider)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": ai_response}
                ]
            except Exception as e:
                error_message = f"GoogleAI API Error: {str(e)}"
                logger.error(error_message)
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
        else:
            error_message = f"Error: Unsupported provider '{provider}'"
            return history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": error_message}
            ]
            
    except Exception as e:
        error_message = f"Error: {str(e)}"
        logger.error(f"Exception during API call: {error_message}")
        return history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": error_message}
        ]
def clear_chat():
    """Reset all inputs"""
    return [], "", [], [], 0.7, 1000, 0.8, 0.0, 0.0, 1.0, 40, 0.1, 0, 0.0, False, "default", "none", "", []
# ==========================================================
# UI CREATION
# ==========================================================
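# Layout: chat window and message input in the left column; provider
# selection, API keys, model dropdowns and generation parameters in the
# right column. Helper callbacks defined inside create_app() keep the
# visible model dropdown and the context/vision indicators in sync with
# the selected provider.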
def create_app():
    """Create the CrispChat Gradio application"""
    with gr.Blocks(
        title="CrispChat",
        css="""
            .context-size { 
                font-size: 0.9em;
                color: #666;
                margin-left: 10px;
            }
            footer { display: none !important; }
            .model-selection-row {
                display: flex;
                align-items: center;
            }
            .parameter-grid {
                display: grid;
                grid-template-columns: 1fr 1fr;
                gap: 10px;
            }
            .vision-badge {
                background-color: #4CAF50;
                color: white;
                padding: 3px 6px;
                border-radius: 3px;
                font-size: 0.8em;
                margin-left: 5px;
            }
            .provider-selection {
                margin-bottom: 10px;
                padding: 10px;
                border-radius: 5px;
                background-color: #f5f5f5;
            }
        """
    ) as demo:
        gr.Markdown("""
        # 🤖 CrispChat
        
        Chat with AI models from multiple providers: OpenRouter, OpenAI, HuggingFace, Groq, Cohere, Together, Anthropic, Poe, and Google AI.
        """)
        
        with gr.Row():
            with gr.Column(scale=2):
                # Chatbot interface
                chatbot = gr.Chatbot(
                    height=500, 
                    show_copy_button=True, 
                    show_label=False,
                    avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/0/04/ChatGPT_logo.svg"),
                    elem_id="chat-window",
                    type="messages"  # use the new format
                )
                
                with gr.Row():
                    message = gr.Textbox(
                        placeholder="Type your message here...",
                        label="Message",
                        lines=2,
                        elem_id="message-input",
                        scale=4
                    )
                
                with gr.Row():
                    with gr.Column(scale=3):
                        submit_btn = gr.Button("Send", variant="primary", elem_id="send-btn")
                    
                    with gr.Column(scale=1):
                        clear_btn = gr.Button("Clear Chat", variant="secondary")
                
                # Container for conditionally showing image upload
                with gr.Row(visible=True) as image_upload_container:
                    # Image upload
                    with gr.Accordion("Upload Images (for vision models)", open=False):
                        images = gr.File(
                            label="Uploaded Images",
                            file_types=["image"],
                            file_count="multiple"
                        )
                        
                        image_upload_btn = gr.UploadButton(
                            label="Upload Images",
                            file_types=["image"],
                            file_count="multiple"
                        )
                    
                    # Document upload
                    with gr.Accordion("Upload Documents (PDF, MD, TXT)", open=False):
                        documents = gr.File(
                            label="Uploaded Documents",
                            file_types=[".pdf", ".md", ".txt"], 
                            file_count="multiple"
                        )
            
            with gr.Column(scale=1):
                                
                with gr.Group(elem_classes="provider-selection"):
                    gr.Markdown("### Provider Selection")
                    
                    # Provider selection
                    provider_choice = gr.Radio(
                        choices=["OpenRouter", "OpenAI", "HuggingFace", "Groq", "Cohere", "Together", "Anthropic", "Poe", "GoogleAI"],
                        value="OpenRouter",
                        label="AI Provider"
                    )
                    
                    # API key input with separate fields for each provider
                    with gr.Accordion("API Keys", open=False):
                        gr.Markdown("Enter API keys directly or set them as environment variables")
                        
                        openrouter_api_key = gr.Textbox(
                            placeholder="Enter OpenRouter API key",
                            label="OpenRouter API Key",
                            type="password",
                            value=OPENROUTER_API_KEY if OPENROUTER_API_KEY else ""
                        )
                        poe_api_key = gr.Textbox(
                            placeholder="Enter Poe API key",
                            label="Poe API Key",
                            type="password",
                            value=POE_API_KEY if POE_API_KEY else ""
                        )
                        
                        openai_api_key = gr.Textbox(
                            placeholder="Enter OpenAI API key",
                            label="OpenAI API Key",
                            type="password",
                            value=OPENAI_API_KEY if OPENAI_API_KEY else ""
                        )
                        
                        hf_api_key = gr.Textbox(
                            placeholder="Enter HuggingFace API key",
                            label="HuggingFace API Key",
                            type="password",
                            value=HF_API_KEY if HF_API_KEY else ""
                        )
                        
                        groq_api_key = gr.Textbox(
                            placeholder="Enter Groq API key",
                            label="Groq API Key",
                            type="password",
                            value=GROQ_API_KEY if GROQ_API_KEY else ""
                        )
                        
                        cohere_api_key = gr.Textbox(
                            placeholder="Enter Cohere API key",
                            label="Cohere API Key",
                            type="password",
                            value=COHERE_API_KEY if COHERE_API_KEY else ""
                        )
                        
                        together_api_key = gr.Textbox(
                            placeholder="Enter Together API key",
                            label="Together API Key",
                            type="password",
                            value=TOGETHER_API_KEY if TOGETHER_API_KEY else ""
                        )
                        
                        # Add Anthropic API key
                        anthropic_api_key = gr.Textbox(
                            placeholder="Enter Anthropic API key",
                            label="Anthropic API Key",
                            type="password",
                            value=os.environ.get("ANTHROPIC_API_KEY", "")
                        )
                        
                        googleai_api_key = gr.Textbox(
                            placeholder="Enter Google AI API key",
                            label="Google AI API Key",
                            type="password",
                            value=GOOGLEAI_API_KEY if GOOGLEAI_API_KEY else ""
                        )
                
                with gr.Group():
                    gr.Markdown("### Model Selection")
                    
                    with gr.Row(elem_classes="model-selection-row"):
                        model_search = gr.Textbox(
                            placeholder="Search models...",
                            label="",
                            show_label=False
                        )
                    
                    # Provider-specific model dropdowns
                    openrouter_model = gr.Dropdown(
                        choices=[model[0] for model in OPENROUTER_ALL_MODELS],
                        value=OPENROUTER_ALL_MODELS[0][0] if OPENROUTER_ALL_MODELS else None,
                        label="OpenRouter Model",
                        elem_id="openrouter-model-choice",
                        visible=True
                    )
                    # Add Poe model dropdown
                    poe_model = gr.Dropdown(
                        choices=list(POE_MODELS.keys()),
                        value="chinchilla" if "chinchilla" in POE_MODELS else None,
                        label="Poe Model",
                        elem_id="poe-model-choice",
                        visible=False
                    )
                    
                    openai_model = gr.Dropdown(
                        choices=list(OPENAI_MODELS.keys()),
                        value="gpt-3.5-turbo" if "gpt-3.5-turbo" in OPENAI_MODELS else None,
                        label="OpenAI Model",
                        elem_id="openai-model-choice",
                        visible=False
                    )
                    
                    hf_model = gr.Dropdown(
                        choices=list(HUGGINGFACE_MODELS.keys()),
                        value="mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in HUGGINGFACE_MODELS else None,
                        label="HuggingFace Model",
                        elem_id="hf-model-choice",
                        visible=False
                    )
                    
                    groq_model = gr.Dropdown(
                        choices=list(GROQ_MODELS.keys()),
                        value="llama-3.1-8b-instant" if "llama-3.1-8b-instant" in GROQ_MODELS else None,
                        label="Groq Model",
                        elem_id="groq-model-choice",
                        visible=False
                    )
                    
                    cohere_model = gr.Dropdown(
                        choices=list(COHERE_MODELS.keys()),
                        value="command-r-plus" if "command-r-plus" in COHERE_MODELS else None,
                        label="Cohere Model",
                        elem_id="cohere-model-choice",
                        visible=False
                    )
                    
                    together_model = gr.Dropdown(
                        choices=list(TOGETHER_MODELS.keys()),
                        value="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo" if "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo" in TOGETHER_MODELS else None,
                        label="Together Model",
                        elem_id="together-model-choice",
                        visible=False
                    )
                    
                    # Add Anthropic model dropdown
                    anthropic_model = gr.Dropdown(
                        choices=list(ANTHROPIC_MODELS.keys()),
                        value="claude-3-5-sonnet-20241022" if "claude-3-5-sonnet-20241022" in ANTHROPIC_MODELS else None,
                        label="Anthropic Model",
                        elem_id="anthropic-model-choice",
                        visible=False
                    )
                    
                    googleai_model = gr.Dropdown(
                        choices=list(GOOGLEAI_MODELS.keys()),
                        value="gemini-1.5-pro" if "gemini-1.5-pro" in GOOGLEAI_MODELS else None,
                        label="Google AI Model",
                        elem_id="googleai-model-choice",
                        visible=False
                    )
                    
                    context_display = gr.Textbox(
                        value=update_context_display("OpenRouter", OPENROUTER_ALL_MODELS[0][0]),
                        label="Context Size",
                        interactive=False,
                        elem_classes="context-size"
                    )
                
                with gr.Accordion("Generation Parameters", open=False):
                    with gr.Group(elem_classes="parameter-grid"):
                        temperature = gr.Slider(
                            minimum=0.0, 
                            maximum=2.0, 
                            value=0.7, 
                            step=0.1,
                            label="Temperature"
                        )
                        
                        max_tokens = gr.Slider(
                            minimum=100, 
                            maximum=4000, 
                            value=1000, 
                            step=100,
                            label="Max Tokens"
                        )
                        
                        top_p = gr.Slider(
                            minimum=0.1, 
                            maximum=1.0, 
                            value=0.8, 
                            step=0.1,
                            label="Top P"
                        )
                        
                        frequency_penalty = gr.Slider(
                            minimum=-2.0, 
                            maximum=2.0, 
                            value=0.0, 
                            step=0.1,
                            label="Frequency Penalty"
                        )
                        
                        presence_penalty = gr.Slider(
                            minimum=-2.0, 
                            maximum=2.0, 
                            value=0.0, 
                            step=0.1,
                            label="Presence Penalty"
                        )
                        
                        reasoning_effort = gr.Radio(
                            ["none", "low", "medium", "high"],
                            value="none",
                            label="Reasoning Effort (OpenRouter)"
                        )
                
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        with gr.Column():
                            repetition_penalty = gr.Slider(
                                minimum=0.1, 
                                maximum=2.0, 
                                value=1.0, 
                                step=0.1,
                                label="Repetition Penalty"
                            )
                            
                            top_k = gr.Slider(
                                minimum=1, 
                                maximum=100, 
                                value=40, 
                                step=1,
                                label="Top K"
                            )
                            
                            min_p = gr.Slider(
                                minimum=0.0, 
                                maximum=1.0, 
                                value=0.1, 
                                step=0.05,
                                label="Min P"
                            )
                        
                        with gr.Column():
                            seed = gr.Number(
                                value=0,
                                label="Seed (0 for random)",
                                precision=0
                            )
                            
                            top_a = gr.Slider(
                                minimum=0.0, 
                                maximum=1.0, 
                                value=0.0, 
                                step=0.05,
                                label="Top A"
                            )
                            
                            stream_output = gr.Checkbox(
                                label="Stream Output",
                                value=False
                            )
                    
                    with gr.Row():
                        response_format = gr.Radio(
                            ["default", "json_object"],
                            value="default",
                            label="Response Format"
                        )
                        
                        gr.Markdown("""
                        * **json_object**: Forces the model to respond with valid JSON only.
                        * Only available on certain models - check model support.
                        """)
                
                # Custom instruction options
                with gr.Accordion("Custom Instructions", open=False):
                    system_message = gr.Textbox(
                        placeholder="Enter a system message to guide the model's behavior...",
                        label="System Message",
                        lines=3
                    )
                    
                    transforms = gr.CheckboxGroup(
                        ["prompt_optimize", "prompt_distill", "prompt_compress"],
                        label="Prompt Transforms (OpenRouter specific)"
                    )
                    
                    gr.Markdown("""
                    * **prompt_optimize**: Improve prompt for better responses.
                    * **prompt_distill**: Compress prompt to use fewer tokens without changing meaning.
                    * **prompt_compress**: Aggressively compress prompt to fit larger contexts.
                    """)
                
                # Add a model information section
                with gr.Accordion("About Selected Model", open=False):
                    model_info_display = gr.HTML(
                        value=update_model_info("OpenRouter", OPENROUTER_ALL_MODELS[0][0])
                    )
                    
                    is_vision_indicator = gr.Checkbox(
                        label="Supports Images",
                        value=is_vision_model("OpenRouter", OPENROUTER_ALL_MODELS[0][0]),
                        interactive=False
                    )
        
        # Add usage instructions
        with gr.Accordion("Usage Instructions", open=False):
            gr.Markdown("""
            ## Basic Usage
            1. Type your message in the input box
            2. Select a provider and model
            3. Click "Send" or press Enter
            
            ## Working with Files
            - **Images**: Upload images to use with vision-capable models
            - **Documents**: Upload PDF, Markdown, or text files to analyze their content
            
            ## Provider Information
            - **OpenRouter**: Free access to various models with context window sizes up to 2M tokens
            - **OpenAI**: Requires an API key, includes GPT-3.5 and GPT-4 models
            - **HuggingFace**: Direct access to open models, some models require API key
            - **Groq**: High-performance inference, requires API key
            - **Cohere**: Specialized in language understanding, requires API key
            - **Together**: Access to high-quality open models, requires API key
            - **Anthropic**: Claude models with strong reasoning capabilities, requires API key
            - **Poe**: Access to models hosted on Poe, requires API key
            - **GoogleAI**: Google's Gemini models, requires API key
            
            ## Advanced Parameters
            - **Temperature**: Controls randomness (higher = more creative, lower = more deterministic)
            - **Max Tokens**: Maximum length of the response
            - **Top P**: Nucleus sampling threshold (higher = consider more tokens)
            - **Reasoning Effort**: Some models can show their reasoning process (OpenRouter only)
            """)
        
        # Add a footer with version info
        footer_md = gr.Markdown("""
        ---
        ### CrispChat v1.2
        Built with ❤️ using Gradio and multiple AI provider APIs | Context sizes shown next to model names
        """)
        
        # Define event handlers
        def toggle_model_dropdowns(provider):
            """Show/hide model dropdowns based on provider selection"""
            return {
                openrouter_model: gr.update(visible=(provider == "OpenRouter")),
                openai_model: gr.update(visible=(provider == "OpenAI")),
                hf_model: gr.update(visible=(provider == "HuggingFace")),
                groq_model: gr.update(visible=(provider == "Groq")),
                cohere_model: gr.update(visible=(provider == "Cohere")),
                together_model: gr.update(visible=(provider == "Together")),
                anthropic_model: gr.update(visible=(provider == "Anthropic")),
                poe_model: gr.update(visible=(provider == "Poe")),
                googleai_model: gr.update(visible=(provider == "GoogleAI"))
            }
            
        def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, poe_model, googleai_model):
            """Update context display based on selected provider and model"""
            if provider == "OpenRouter":
                return update_context_display(provider, openrouter_model)
            elif provider == "OpenAI":
                return update_context_display(provider, openai_model)
            elif provider == "HuggingFace":
                return update_context_display(provider, hf_model)
            elif provider == "Groq":
                return update_context_display(provider, groq_model)
            elif provider == "Cohere":
                return update_context_display(provider, cohere_model)
            elif provider == "Together":
                return update_context_display(provider, together_model)
            elif provider == "Anthropic":
                return update_context_display(provider, anthropic_model)
            elif provider == "Poe":
                return update_context_display(provider, poe_model)
            elif provider == "GoogleAI":
                return update_context_display(provider, googleai_model)
            return "Unknown"
            
        def update_model_info_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, poe_model, googleai_model):
            """Update model info based on selected provider and model"""
            if provider == "OpenRouter":
                return update_model_info(provider, openrouter_model)
            elif provider == "OpenAI":
                return update_model_info(provider, openai_model)
            elif provider == "HuggingFace":
                return update_model_info(provider, hf_model)
            elif provider == "Groq":
                return update_model_info(provider, groq_model)
            elif provider == "Cohere":
                return update_model_info(provider, cohere_model)
            elif provider == "Together":
                return update_model_info(provider, together_model)
            elif provider == "Anthropic":
                return update_model_info(provider, anthropic_model)
            elif provider == "Poe":
                return update_model_info(provider, poe_model)
            elif provider == "GoogleAI":
                return update_model_info(provider, googleai_model)
            return "Model information not available
"
            
        def update_vision_indicator(provider, model_choice=None):
            """Update the vision capability indicator"""
            # Simplified - don't call get_current_model since it causes issues
            if model_choice is None:
                # Just check if the provider generally supports vision
                return provider in VISION_MODELS and len(VISION_MODELS[provider]) > 0
            
            return is_vision_model(provider, model_choice)
            
        def update_image_upload_visibility(provider, model_choice=None):
            """Show/hide image upload based on model vision capabilities"""
            # Simplified
            is_vision = update_vision_indicator(provider, model_choice)
            return gr.update(visible=is_vision)
        
        # Model search helpers (filter the dropdown choices as the user types)
        def search_openrouter_models(search_term):
            """Filter OpenRouter models based on search term"""
            all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
            if not search_term:
                return gr.update(choices=all_models, value=all_models[0] if all_models else None)
                
            filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
            
            if filtered_models:
                return gr.update(choices=filtered_models, value=filtered_models[0])
            else:
                return gr.update(choices=all_models, value=all_models[0] if all_models else None)
                
        def search_models_generic(search_term, model_dict, default_model=None):
            """Generic model search function to reduce code duplication"""
            all_models = list(model_dict.keys())
            if not all_models:
                return gr.update(choices=[], value=None)

            if not search_term:
                return gr.update(choices=all_models, value=default_model if default_model in all_models else all_models[0])

            filtered_models = [model for model in all_models if search_term.lower() in model.lower()]

            if filtered_models:
                return gr.update(choices=filtered_models, value=filtered_models[0])
            else:
                return gr.update(choices=all_models, value=default_model if default_model in all_models else all_models[0])

        def search_openai_models(search_term):
            """Filter OpenAI models based on search term"""
            return search_models_generic(search_term, OPENAI_MODELS, "gpt-3.5-turbo")

        def search_hf_models(search_term):
            """Filter HuggingFace models based on search term"""
            return search_models_generic(search_term, HUGGINGFACE_MODELS, "mistralai/Mistral-7B-Instruct-v0.3")

        def search_poe_models(search_term):
            """Filter Poe models based on search term"""
            return search_models_generic(search_term, POE_MODELS, "chinchilla")
        
        def search_groq_models(search_term):
            """Filter Groq models based on search term"""
            return search_models_generic(search_term, GROQ_MODELS, "llama-3.1-8b-instant")

        def search_cohere_models(search_term):
            """Filter Cohere models based on search term"""
            return search_models_generic(search_term, COHERE_MODELS, "command-r-plus")

        def search_together_models(search_term):
            """Filter Together models based on search term"""
            return search_models_generic(search_term, TOGETHER_MODELS, "meta-llama/Llama-3.1-8B-Instruct")

        def search_anthropic_models(search_term):
            """Filter Anthropic models based on search term"""
            return search_models_generic(search_term, ANTHROPIC_MODELS, "claude-3-5-sonnet-20241022")

        def search_googleai_models(search_term):
            """Filter GoogleAI models based on search term"""
            return search_models_generic(search_term, GOOGLEAI_MODELS, "gemini-1.5-pro")
            
        def get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, 
                            together_model, anthropic_model, poe_model, googleai_model):
            """Get the currently selected model based on provider"""
            if provider == "OpenRouter":
                return openrouter_model
            elif provider == "OpenAI":
                return openai_model
            elif provider == "HuggingFace":
                return hf_model
            elif provider == "Groq":
                return groq_model
            elif provider == "Cohere":
                return cohere_model
            elif provider == "Together":
                return together_model
            elif provider == "Anthropic":
                return anthropic_model
            elif provider == "Poe":
                return poe_model
            elif provider == "GoogleAI":
                return googleai_model
            return None
        
        # Process uploaded images
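        # (files are passed through unchanged into the `images` input consumed by submit_message)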
        image_upload_btn.upload(
            fn=lambda files: files,
            inputs=image_upload_btn,
            outputs=images
        )
        
        # Set up provider selection event
        provider_choice.change(
            fn=toggle_model_dropdowns,
            inputs=provider_choice,
            outputs=[
                openrouter_model,
                openai_model,
                hf_model,
                groq_model,
                cohere_model,
                together_model,
                anthropic_model,
                poe_model,
                googleai_model
            ]
        ).then(
            fn=update_context_for_provider,
            inputs=[
                provider_choice, 
                openrouter_model, 
                openai_model, 
                hf_model, 
                groq_model, 
                cohere_model, 
                together_model, 
                anthropic_model, 
                poe_model, 
                googleai_model
            ],
            outputs=context_display
        ).then(
            fn=update_model_info_for_provider,
            inputs=[
                provider_choice, 
                openrouter_model, 
                openai_model, 
                hf_model, 
                groq_model, 
                cohere_model, 
                together_model, 
                anthropic_model, 
                poe_model, 
                googleai_model
            ],
            outputs=model_info_display
        ).then(
            # After a provider change only the provider is passed, so fall back to the provider-level vision check
            fn=lambda provider: update_vision_indicator(provider, None),
            inputs=provider_choice,
            outputs=is_vision_indicator
        ).then(
            # Likewise, image upload visibility falls back to the provider-level check
            fn=lambda provider: update_image_upload_visibility(provider, None),
            inputs=provider_choice,
            outputs=image_upload_container
        )
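        # Each .then() step runs after the previous one finishes, so a provider change refreshes
        # the model dropdowns, context display, model info, vision indicator and image upload
        # visibility in sequence.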
        
        # Set up model search event - return model dropdown updates
        model_search.change(
            fn=lambda provider, search: [
                search_openrouter_models(search) if provider == "OpenRouter" else gr.update(),
                search_openai_models(search) if provider == "OpenAI" else gr.update(),
                search_hf_models(search) if provider == "HuggingFace" else gr.update(),
                search_groq_models(search) if provider == "Groq" else gr.update(),
                search_cohere_models(search) if provider == "Cohere" else gr.update(),
                search_together_models(search) if provider == "Together" else gr.update(),
                search_anthropic_models(search) if provider == "Anthropic" else gr.update(),
                search_poe_models(search) if provider == "Poe" else gr.update(),
                search_googleai_models(search) if provider == "GoogleAI" else gr.update()
            ],
            inputs=[provider_choice, model_search],
            outputs=[
                openrouter_model, openai_model, hf_model, groq_model, 
                cohere_model, together_model, anthropic_model, poe_model, googleai_model
            ]
        )
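        # gr.update() with no arguments leaves a component unchanged, so only the dropdown
        # belonging to the currently selected provider is actually re-filtered.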
        
        # Set up model change events to update context display and model info
        openrouter_model.change(
            fn=lambda model: update_context_display("OpenRouter", model),
            inputs=openrouter_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("OpenRouter", model),
            inputs=openrouter_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("OpenRouter", model),
            inputs=openrouter_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("OpenRouter", model),
            inputs=openrouter_model,
            outputs=image_upload_container
        )
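        # The same context / model info / vision indicator / image upload chain is wired up
        # for every other provider's model dropdown below.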
        # Event handler for Poe model change
        poe_model.change(
            fn=lambda model: update_context_display("Poe", model),
            inputs=poe_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("Poe", model),
            inputs=poe_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("Poe", model),
            inputs=poe_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("Poe", model),
            inputs=poe_model,
            outputs=image_upload_container
        )
        
        openai_model.change(
            fn=lambda model: update_context_display("OpenAI", model),
            inputs=openai_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("OpenAI", model),
            inputs=openai_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("OpenAI", model),
            inputs=openai_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("OpenAI", model),
            inputs=openai_model,
            outputs=image_upload_container
        )
        
        hf_model.change(
            fn=lambda model: update_context_display("HuggingFace", model),
            inputs=hf_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("HuggingFace", model),
            inputs=hf_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("HuggingFace", model),
            inputs=hf_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("HuggingFace", model),
            inputs=hf_model,
            outputs=image_upload_container
        )
        
        groq_model.change(
            fn=lambda model: update_context_display("Groq", model),
            inputs=groq_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("Groq", model),
            inputs=groq_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("Groq", model),
            inputs=groq_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("Groq", model),
            inputs=groq_model,
            outputs=image_upload_container
        )
        
        cohere_model.change(
            fn=lambda model: update_context_display("Cohere", model),
            inputs=cohere_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("Cohere", model),
            inputs=cohere_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("Cohere", model),
            inputs=cohere_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("Cohere", model),
            inputs=cohere_model,
            outputs=image_upload_container
        )
        
        together_model.change(
            fn=lambda model: update_context_display("Together", model),
            inputs=together_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("Together", model),
            inputs=together_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("Together", model),
            inputs=together_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("Together", model),
            inputs=together_model,
            outputs=image_upload_container
        )
        
        anthropic_model.change(
            fn=lambda model: update_context_display("Anthropic", model),
            inputs=anthropic_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("Anthropic", model),
            inputs=anthropic_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("Anthropic", model),
            inputs=anthropic_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("Anthropic", model),
            inputs=anthropic_model,
            outputs=image_upload_container
        )
        
        googleai_model.change(
            fn=lambda model: update_context_display("GoogleAI", model),
            inputs=googleai_model,
            outputs=context_display
        ).then(
            fn=lambda model: update_model_info("GoogleAI", model),
            inputs=googleai_model,
            outputs=model_info_display
        ).then(
            fn=lambda model: update_vision_indicator("GoogleAI", model),
            inputs=googleai_model,
            outputs=is_vision_indicator
        ).then(
            fn=lambda model: update_image_upload_visibility("GoogleAI", model),
            inputs=googleai_model,
            outputs=image_upload_container
        )
        def handle_search(provider, search_term):
            """Handle search based on provider"""
            if provider == "OpenRouter":
                return search_openrouter_models(search_term)
            elif provider == "OpenAI":
                return search_openai_models(search_term)
            elif provider == "HuggingFace":
                return search_hf_models(search_term)
            elif provider == "Groq":
                return search_groq_models(search_term)
            elif provider == "Cohere":
                return search_cohere_models(search_term)
            elif provider == "Together":
                return search_together_models(search_term)
            elif provider == "Anthropic":
                return search_anthropic_models(search_term)
            elif provider == "GoogleAI":
                return search_googleai_models(search_term)
            return None
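        # handle_search mirrors the dispatch performed inline in model_search.change above
        # and returns a single dropdown update for the active provider.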
        
        # Set up submission event
        def submit_message(message, history, provider, 
                openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, poe_model, googleai_model, 
                temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, 
                top_k, min_p, seed, top_a, stream_output, response_format,
                images, documents, reasoning_effort, system_message, transforms, 
                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key, together_api_key, anthropic_api_key, poe_api_key, googleai_api_key):
            """Submit message to selected provider and model"""
            # Get the currently selected model
            model_choice = get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, 
                                    together_model, anthropic_model, poe_model, googleai_model)
            
            # Check if model is selected
            if not model_choice:
                error_message = f"Error: No model selected for provider {provider}"
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": error_message}
                ]
            
            # Select the appropriate API key based on the provider
            api_key_override = None
            if provider == "OpenRouter" and openrouter_api_key:
                api_key_override = openrouter_api_key
            elif provider == "OpenAI" and openai_api_key:
                api_key_override = openai_api_key
            elif provider == "HuggingFace" and hf_api_key:
                api_key_override = hf_api_key
            elif provider == "Groq" and groq_api_key:
                api_key_override = groq_api_key
            elif provider == "Cohere" and cohere_api_key:
                api_key_override = cohere_api_key
            elif provider == "Together" and together_api_key:
                api_key_override = together_api_key
            elif provider == "Anthropic" and anthropic_api_key:
                api_key_override = anthropic_api_key
            elif provider == "Poe" and poe_api_key:
                api_key_override = poe_api_key
            elif provider == "GoogleAI" and googleai_api_key:
                api_key_override = googleai_api_key
            
            # Call the ask_ai function with the appropriate parameters
            return ask_ai(
                message=message,
                history=history,
                provider=provider,
                model_choice=model_choice,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                repetition_penalty=repetition_penalty,
                top_k=top_k,
                min_p=min_p,
                seed=seed,
                top_a=top_a,
                stream_output=stream_output,
                response_format=response_format,
                images=images,
                documents=documents,
                reasoning_effort=reasoning_effort,
                system_message=system_message,
                transforms=transforms,
                api_key_override=api_key_override
            )
        def clean_message(message):
            """Clean the message of <style> tags before sending"""
            if isinstance(message, str):
                import re
                # The original pattern was garbled; this assumes the intent was to strip <style>...</style> blocks
                message = re.sub(r'<style[^>]*>.*?</style>', '', message, flags=re.DOTALL | re.IGNORECASE)
            return message
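        # Example (under the assumed pattern above): clean_message("<style>p{}</style>hi") -> "hi"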
        
        # Submit button click event
        submit_btn.click(
            fn=lambda *args: submit_message(clean_message(args[0]), *args[1:]),
            inputs=[
                message, chatbot, provider_choice, 
                openrouter_model, openai_model, hf_model, groq_model, cohere_model, 
                together_model, anthropic_model, poe_model, googleai_model,
                temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, 
                top_k, min_p, seed, top_a, stream_output, response_format,
                images, documents, reasoning_effort, system_message, transforms,
                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key, 
                together_api_key, anthropic_api_key, poe_api_key, googleai_api_key
            ],
            outputs=chatbot,
            show_progress="minimal",
        ).then(
            fn=lambda: "",  # Clear message box after sending
            inputs=None,
            outputs=message
        )
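        # The textbox is cleared in a follow-up .then() step, so the draft stays visible
        # until the model's response has been added to the chat.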
        
        # Also submit on Enter key
        message.submit(
            fn=lambda *args: submit_message(clean_message(args[0]), *args[1:]),
            inputs=[
                message, chatbot, provider_choice, 
                openrouter_model, openai_model, hf_model, groq_model, cohere_model, 
                together_model, anthropic_model, poe_model, googleai_model,
                temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, 
                top_k, min_p, seed, top_a, stream_output, response_format,
                images, documents, reasoning_effort, system_message, transforms,
                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key, 
                together_api_key, anthropic_api_key, poe_api_key, googleai_api_key
            ],
            outputs=chatbot,
            show_progress="minimal",
        ).then(
            fn=lambda: "",  # Clear message box after sending
            inputs=None,
            outputs=message
        )
        
        # Clear chat button
        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[
                chatbot, message, images, documents, temperature, 
                max_tokens, top_p, frequency_penalty, presence_penalty,
                repetition_penalty, top_k, min_p, seed, top_a, stream_output,
                response_format, reasoning_effort, system_message, transforms
            ]
        )
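        # clear_chat is expected to return defaults for each of these outputs, resetting the
        # conversation together with every generation parameter.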
        
        return demo
# Launch the app
if __name__ == "__main__":
    # Check API keys and print status
    # (provider name, env var name, value) for the startup key check
    api_keys = [
        ("OpenRouter", "OPENROUTER_API_KEY", OPENROUTER_API_KEY),
        ("Poe", "POE_API_KEY", POE_API_KEY),
        ("Anthropic", "ANTHROPIC_API_KEY", ANTHROPIC_API_KEY),
        ("OpenAI", "OPENAI_API_KEY", OPENAI_API_KEY),
        ("Groq", "GROQ_API_KEY", GROQ_API_KEY),
        ("Cohere", "COHERE_API_KEY", COHERE_API_KEY),
        ("Together", "TOGETHER_API_KEY", TOGETHER_API_KEY),
        ("GoogleAI", "GOOGLEAI_API_KEY", GOOGLEAI_API_KEY),
    ]
    missing_keys = []
    for provider_name, env_name, key_value in api_keys:
        if not key_value:
            logger.warning(f"{env_name} environment variable is not set")
            missing_keys.append(provider_name)
        
    if missing_keys:
        print("Missing API keys for the following providers:")
        for key in missing_keys:
            print(f"- {key}")
        print("\nYou can still use the application, but some providers will require API keys.")
        print("You can provide API keys through environment variables or use the API Key Override field.")
        
        if "OpenRouter" in missing_keys:
            print("\nNote: OpenRouter offers free tier access to many models!")
        
        #if "OVH" not in missing_keys and "Cerebras" not in missing_keys:
        #    print("\nNote: OVH AI Endpoints (beta) and Cerebras offer free usage tiers!")
            
    print("\nStarting CrispChat application...")
    demo = create_app()
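    # 0.0.0.0:7860 matches the defaults expected when hosting on Hugging Face Spaces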
    demo.launch(
        server_name="0.0.0.0", 
        server_port=7860, 
        debug=True,
        show_error=True
    )