Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,10 @@ import logging
|
|
| 7 |
import io
|
| 8 |
import time
|
| 9 |
from typing import List, Dict, Any, Union, Tuple, Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Configure logging
|
| 12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
@@ -220,7 +224,7 @@ OPENAI_MODELS = {
|
|
| 220 |
"o1-mini-2024-09-12": 128000,
|
| 221 |
}
|
| 222 |
|
| 223 |
-
# HUGGINGFACE MODELS
|
| 224 |
HUGGINGFACE_MODELS = {
|
| 225 |
"microsoft/phi-3-mini-4k-instruct": 4096,
|
| 226 |
"microsoft/Phi-3-mini-128k-instruct": 131072,
|
|
@@ -509,9 +513,7 @@ def filter_models(provider, search_term):
|
|
| 509 |
if filtered_models:
|
| 510 |
return filtered_models, filtered_models[0]
|
| 511 |
else:
|
| 512 |
-
return
|
| 513 |
-
|
| 514 |
-
return all_models, all_models[0] if all_models else None
|
| 515 |
|
| 516 |
def get_model_info(provider, model_choice):
|
| 517 |
"""Get model ID and context size based on provider and model name"""
|
|
@@ -1688,14 +1690,14 @@ def create_app():
|
|
| 1688 |
# Define event handlers
|
| 1689 |
def toggle_model_dropdowns(provider):
|
| 1690 |
"""Show/hide model dropdowns based on provider selection"""
|
| 1691 |
-
return
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
|
| 1696 |
-
|
| 1697 |
-
|
| 1698 |
-
|
| 1699 |
|
| 1700 |
def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
|
| 1701 |
"""Update context display based on selected provider and model"""
|
|
@@ -1728,33 +1730,68 @@ def create_app():
|
|
| 1728 |
elif provider == "GLHF":
|
| 1729 |
return update_model_info(provider, glhf_model)
|
| 1730 |
return "<p>Model information not available</p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1731 |
|
| 1732 |
-
def filter_provider_models(provider, search_term):
|
| 1733 |
-
"""Filter models for the selected provider"""
|
| 1734 |
if provider == "OpenRouter":
|
| 1735 |
all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1736 |
elif provider == "OpenAI":
|
| 1737 |
all_models = list(OPENAI_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1738 |
elif provider == "HuggingFace":
|
| 1739 |
all_models = list(HUGGINGFACE_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1740 |
elif provider == "Groq":
|
| 1741 |
all_models = list(GROQ_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1742 |
elif provider == "Cohere":
|
| 1743 |
all_models = list(COHERE_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1744 |
elif provider == "GLHF":
|
| 1745 |
all_models = list(GLHF_MODELS.keys())
|
| 1746 |
-
|
| 1747 |
-
|
|
|
|
|
|
|
| 1748 |
|
| 1749 |
-
|
| 1750 |
-
return all_models, all_models[0] if all_models else None
|
| 1751 |
|
| 1752 |
-
|
| 1753 |
-
|
| 1754 |
-
if filtered_models:
|
| 1755 |
-
return filtered_models, filtered_models[0]
|
| 1756 |
-
else:
|
| 1757 |
-
return all_models, all_models[0] if all_models else None
|
| 1758 |
|
| 1759 |
def refresh_groq_models_list():
|
| 1760 |
"""Refresh the list of Groq models"""
|
|
@@ -1800,14 +1837,12 @@ def create_app():
|
|
| 1800 |
outputs=model_info_display
|
| 1801 |
)
|
| 1802 |
|
| 1803 |
-
# Set up model search event
|
|
|
|
| 1804 |
model_search.change(
|
| 1805 |
-
fn=
|
| 1806 |
inputs=[provider_choice, model_search],
|
| 1807 |
-
outputs=[
|
| 1808 |
-
gr.update(choices=None, value=None),
|
| 1809 |
-
gr.update(choices=None, value=None)
|
| 1810 |
-
]
|
| 1811 |
)
|
| 1812 |
|
| 1813 |
# Set up model change events
|
|
@@ -1871,6 +1906,25 @@ def create_app():
|
|
| 1871 |
outputs=model_info_display
|
| 1872 |
)
|
| 1873 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1874 |
# Set up submission event
|
| 1875 |
def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
|
| 1876 |
temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
|
|
@@ -1963,11 +2017,40 @@ def create_app():
|
|
| 1963 |
|
| 1964 |
# Launch the app
|
| 1965 |
if __name__ == "__main__":
|
| 1966 |
-
# Check API keys
|
|
|
|
|
|
|
| 1967 |
if not OPENROUTER_API_KEY:
|
| 1968 |
logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
|
| 1969 |
-
|
| 1970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1971 |
demo = create_app()
|
| 1972 |
demo.launch(
|
| 1973 |
server_name="0.0.0.0",
|
|
|
|
| 7 |
import io
|
| 8 |
import time
|
| 9 |
from typing import List, Dict, Any, Union, Tuple, Optional
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
# Load environment variables from .env file
|
| 13 |
+
load_dotenv()
|
| 14 |
|
| 15 |
# Configure logging
|
| 16 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
|
| 224 |
"o1-mini-2024-09-12": 128000,
|
| 225 |
}
|
| 226 |
|
| 227 |
+
# HUGGINGFACE MODELS
|
| 228 |
HUGGINGFACE_MODELS = {
|
| 229 |
"microsoft/phi-3-mini-4k-instruct": 4096,
|
| 230 |
"microsoft/Phi-3-mini-128k-instruct": 131072,
|
|
|
|
| 513 |
if filtered_models:
|
| 514 |
return filtered_models, filtered_models[0]
|
| 515 |
else:
|
| 516 |
+
return all_models, all_models[0] if all_models else None
|
|
|
|
|
|
|
| 517 |
|
| 518 |
def get_model_info(provider, model_choice):
|
| 519 |
"""Get model ID and context size based on provider and model name"""
|
|
|
|
| 1690 |
# Define event handlers
|
| 1691 |
def toggle_model_dropdowns(provider):
|
| 1692 |
"""Show/hide model dropdowns based on provider selection"""
|
| 1693 |
+
return [
|
| 1694 |
+
gr.update(visible=(provider == "OpenRouter")),
|
| 1695 |
+
gr.update(visible=(provider == "OpenAI")),
|
| 1696 |
+
gr.update(visible=(provider == "HuggingFace")),
|
| 1697 |
+
gr.update(visible=(provider == "Groq")),
|
| 1698 |
+
gr.update(visible=(provider == "Cohere")),
|
| 1699 |
+
gr.update(visible=(provider == "GLHF"))
|
| 1700 |
+
]
|
| 1701 |
|
| 1702 |
def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
|
| 1703 |
"""Update context display based on selected provider and model"""
|
|
|
|
| 1730 |
elif provider == "GLHF":
|
| 1731 |
return update_model_info(provider, glhf_model)
|
| 1732 |
return "<p>Model information not available</p>"
|
| 1733 |
+
|
| 1734 |
+
# Handling model search function - Fixed compared to previous implementation
|
| 1735 |
+
def search_models(provider, search_term):
|
| 1736 |
+
"""Filter models for the selected provider based on search term"""
|
| 1737 |
+
filtered_models = []
|
| 1738 |
|
|
|
|
|
|
|
| 1739 |
if provider == "OpenRouter":
|
| 1740 |
all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
|
| 1741 |
+
if search_term:
|
| 1742 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
| 1743 |
+
else:
|
| 1744 |
+
filtered_models = all_models
|
| 1745 |
+
|
| 1746 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
| 1747 |
+
|
| 1748 |
elif provider == "OpenAI":
|
| 1749 |
all_models = list(OPENAI_MODELS.keys())
|
| 1750 |
+
if search_term:
|
| 1751 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
| 1752 |
+
else:
|
| 1753 |
+
filtered_models = all_models
|
| 1754 |
+
|
| 1755 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
| 1756 |
+
|
| 1757 |
elif provider == "HuggingFace":
|
| 1758 |
all_models = list(HUGGINGFACE_MODELS.keys())
|
| 1759 |
+
if search_term:
|
| 1760 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
| 1761 |
+
else:
|
| 1762 |
+
filtered_models = all_models
|
| 1763 |
+
|
| 1764 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
| 1765 |
+
|
| 1766 |
elif provider == "Groq":
|
| 1767 |
all_models = list(GROQ_MODELS.keys())
|
| 1768 |
+
if search_term:
|
| 1769 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
| 1770 |
+
else:
|
| 1771 |
+
filtered_models = all_models
|
| 1772 |
+
|
| 1773 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
| 1774 |
+
|
| 1775 |
elif provider == "Cohere":
|
| 1776 |
all_models = list(COHERE_MODELS.keys())
|
| 1777 |
+
if search_term:
|
| 1778 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
| 1779 |
+
else:
|
| 1780 |
+
filtered_models = all_models
|
| 1781 |
+
|
| 1782 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
| 1783 |
+
|
| 1784 |
elif provider == "GLHF":
|
| 1785 |
all_models = list(GLHF_MODELS.keys())
|
| 1786 |
+
if search_term:
|
| 1787 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
| 1788 |
+
else:
|
| 1789 |
+
filtered_models = all_models
|
| 1790 |
|
| 1791 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
|
|
|
| 1792 |
|
| 1793 |
+
# Default return in case of unknown provider
|
| 1794 |
+
return gr.update(choices=[], value=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1795 |
|
| 1796 |
def refresh_groq_models_list():
|
| 1797 |
"""Refresh the list of Groq models"""
|
|
|
|
| 1837 |
outputs=model_info_display
|
| 1838 |
)
|
| 1839 |
|
| 1840 |
+
# Set up model search event - FIXED VERSION
|
| 1841 |
+
# Important: We need to return a proper Gradio component update for each dropdown
|
| 1842 |
model_search.change(
|
| 1843 |
+
fn=search_models,
|
| 1844 |
inputs=[provider_choice, model_search],
|
| 1845 |
+
outputs=[openrouter_model] # This will be handled by the JS forwarding logic
|
|
|
|
|
|
|
|
|
|
| 1846 |
)
|
| 1847 |
|
| 1848 |
# Set up model change events
|
|
|
|
| 1906 |
outputs=model_info_display
|
| 1907 |
)
|
| 1908 |
|
| 1909 |
+
# Add custom JavaScript for routing model search to visible dropdown
|
| 1910 |
+
gr.HTML("""
|
| 1911 |
+
<script>
|
| 1912 |
+
// To be triggered after page load
|
| 1913 |
+
document.addEventListener('DOMContentLoaded', function() {
|
| 1914 |
+
// Find dropdowns
|
| 1915 |
+
const providerRadio = document.querySelector('input[name="provider_choice"]');
|
| 1916 |
+
const searchInput = document.getElementById('model_search');
|
| 1917 |
+
|
| 1918 |
+
if (providerRadio && searchInput) {
|
| 1919 |
+
// When provider changes, clear the search
|
| 1920 |
+
providerRadio.addEventListener('change', function() {
|
| 1921 |
+
searchInput.value = '';
|
| 1922 |
+
});
|
| 1923 |
+
}
|
| 1924 |
+
});
|
| 1925 |
+
</script>
|
| 1926 |
+
""")
|
| 1927 |
+
|
| 1928 |
# Set up submission event
|
| 1929 |
def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
|
| 1930 |
temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
|
|
|
|
| 2017 |
|
| 2018 |
# Launch the app
|
| 2019 |
if __name__ == "__main__":
|
| 2020 |
+
# Check API keys and print status
|
| 2021 |
+
missing_keys = []
|
| 2022 |
+
|
| 2023 |
if not OPENROUTER_API_KEY:
|
| 2024 |
logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
|
| 2025 |
+
missing_keys.append("OpenRouter")
|
| 2026 |
|
| 2027 |
+
if not OPENAI_API_KEY:
|
| 2028 |
+
logger.warning("WARNING: OPENAI_API_KEY environment variable is not set")
|
| 2029 |
+
missing_keys.append("OpenAI")
|
| 2030 |
+
|
| 2031 |
+
if not GROQ_API_KEY:
|
| 2032 |
+
logger.warning("WARNING: GROQ_API_KEY environment variable is not set")
|
| 2033 |
+
missing_keys.append("Groq")
|
| 2034 |
+
|
| 2035 |
+
if not COHERE_API_KEY:
|
| 2036 |
+
logger.warning("WARNING: COHERE_API_KEY environment variable is not set")
|
| 2037 |
+
missing_keys.append("Cohere")
|
| 2038 |
+
|
| 2039 |
+
if not GLHF_API_KEY:
|
| 2040 |
+
logger.warning("WARNING: GLHF_API_KEY environment variable is not set")
|
| 2041 |
+
missing_keys.append("GLHF")
|
| 2042 |
+
|
| 2043 |
+
if missing_keys:
|
| 2044 |
+
print("Missing API keys for the following providers:")
|
| 2045 |
+
for key in missing_keys:
|
| 2046 |
+
print(f"- {key}")
|
| 2047 |
+
print("\nYou can still use the application, but some providers will require API keys.")
|
| 2048 |
+
print("You can provide API keys through environment variables or use the API Key Override field.")
|
| 2049 |
+
|
| 2050 |
+
if "OpenRouter" in missing_keys:
|
| 2051 |
+
print("\nNote: OpenRouter offers free tier access to many models!")
|
| 2052 |
+
|
| 2053 |
+
print("\nStarting Multi-Provider CrispChat application...")
|
| 2054 |
demo = create_app()
|
| 2055 |
demo.launch(
|
| 2056 |
server_name="0.0.0.0",
|