updated app
- .env.example +14 -0
- CHAT_GUIDE.md +95 -0
- Dockerfile +35 -0
- MODEL_SETUP.md +152 -0
- README_MODELS.md +184 -0
- app.py +300 -24
- app_chat.py +295 -0
- app_original.py +36 -0
- app_reserve.py +296 -0
- chat_app.py +284 -0
- config.json +23 -2
- docker-compose.yml +23 -0
- download_models.py +142 -0
- requirements.txt +5 -0
- startup.py +55 -0
- test_imports.py +46 -0
- test_models.py +182 -0
.env.example
ADDED
@@ -0,0 +1,14 @@
+# Environment variables for Policy Analysis Application
+# Copy this file to .env and update the values
+
+# API Configuration
+API_KEY=your_api_key_here
+
+# Model Configuration
+MODEL=llama3.3-70b-instruct
+
+# Model Pre-loading
+PRELOAD_MODELS=true
+
+# HuggingFace Configuration (optional)
+# HF_TOKEN=your_huggingface_token_here
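A minimal sketch of how these settings are consumed: app.py in this commit reads them with plain `os.getenv`, so exporting the variables (or loading `.env` with a tool such as python-dotenv, an assumption not confirmed by this diff) is sufficient.

```python
# Sketch only: mirrors the variable names in .env.example above.
import os

API_KEY = os.getenv("API_KEY", "your_api_key_here")
MODEL = os.getenv("MODEL", "llama3.3-70b-instruct")
PRELOAD_MODELS = os.getenv("PRELOAD_MODELS", "true").lower() == "true"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional

if API_KEY == "your_api_key_here":
    raise SystemExit("Set API_KEY before starting the app")
```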
CHAT_GUIDE.md
ADDED
@@ -0,0 +1,95 @@
+# Chat Interface Usage Guide
+
+## 🎯 **New Chat Features**
+
+Your Policy Analysis application now has a conversational interface! Here's what you can do:
+
+### 💬 **How to Use the Chat**
+
+1. **Ask Your First Question**
+   ```
+   "What are Kenya's renewable energy policies?"
+   ```
+
+2. **Follow Up with Related Questions**
+   ```
+   "What penalties exist for non-compliance?"
+   "How does this relate to environmental protection?"
+   "Can you explain more about the licensing requirements?"
+   ```
+
+3. **Reference Previous Responses**
+   ```
+   "What does this mean in practice?"
+   "Can you elaborate on the point about penalties?"
+   "How do these regulations compare to what you mentioned earlier?"
+   ```
+
+### 🔄 **Conversation Flow Example**
+
+**You:** "What are Kenya's energy policies regarding renewable sources?"
+
+**Assistant:** *[Provides detailed information about renewable energy policies with quotes and sources]*
+
+**You:** "What are the penalties for not following these policies?"
+
+**Assistant:** *[Builds on the previous context and explains penalties specifically]*
+
+**You:** "How do I apply for a renewable energy license?"
+
+**Assistant:** *[Continues the conversation with licensing information]*
+
+### ⚙️ **Advanced Features**
+
+- **📊 Sentiment Analysis**: Toggle on/off to analyze the tone of policy documents
+- **🔍 Coherence Analysis**: Toggle on/off to check document relevance and consistency
+- **💾 Chat History**: The assistant remembers your conversation for better context
+- **📋 Copy Responses**: Click the copy button on any response
+- **🔗 Share Responses**: Share interesting responses using the share button
+
+### 🎨 **Interface Elements**
+
+- **Chat Bubbles**: User messages (👤) and assistant responses (🤖)
+- **Settings Panel**: Control sentiment and coherence analysis
+- **Clear Chat**: Start a fresh conversation
+- **Analysis Status**: See which features are currently enabled
+
+### 💡 **Tips for Better Conversations**
+
+1. **Be Specific**: Ask about particular aspects of policies
+2. **Build Context**: Ask follow-up questions that reference previous answers
+3. **Use Natural Language**: Talk as you would to a human expert
+4. **Reference Sources**: Ask for more details about quoted sources
+
+### 📝 **Example Conversation Starters**
+
+**Policy Research:**
+- "What are the main objectives of Kenya's water management policies?"
+- "Tell me about environmental compliance requirements"
+
+**Follow-up Questions:**
+- "What does this mean for small businesses?"
+- "Can you explain the implementation process?"
+- "What are the timelines mentioned?"
+
+**Comparative Questions:**
+- "How does this compare to energy policies?"
+- "Are there similar requirements in other sectors?"
+
+### 🚀 **Getting Started**
+
+1. Start the application: `python app.py`
+2. Open your browser to the provided URL
+3. Begin with a general question about Kenya policies
+4. Use follow-up questions to dive deeper
+5. Toggle analysis features as needed
+
+### 🔧 **Settings Explained**
+
+- **Sentiment Analysis ON**: Get insights into the tone and intent of policy text
+- **Coherence Analysis ON**: Verify that retrieved documents are relevant and consistent
+- **Both OFF**: Faster responses with just policy content and analysis
+
+---
+
+**Note**: The chat maintains context from your conversation, so each response builds on what was discussed earlier, making it feel more natural and helpful!
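The closing note says the chat maintains conversation context; mechanically, the apps in this commit do that by replaying only the most recent exchanges when building the prompt. A minimal standalone sketch of that trimming idea, with illustrative names:

```python
# Keep only the last N exchanges so the prompt stays inside the
# model's token budget. Names here are illustrative, not the app's API.
def build_context(history, new_message, max_exchanges=6):
    """history: list of (user, assistant) pairs; returns chat messages."""
    messages = []
    for user_turn, assistant_turn in history[-max_exchanges:]:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": new_message})
    return messages

# Only the last 6 exchanges survive; older turns are dropped.
msgs = build_context([("Hi", "Hello!")] * 10, "What about penalties?")
assert len(msgs) == 6 * 2 + 1
```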
Dockerfile
ADDED
@@ -0,0 +1,35 @@
+# Dockerfile for Policy Analysis Application
+FROM python:3.9-slim
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Create cache directory for models
+RUN mkdir -p /root/.cache/huggingface
+
+# Download models during build phase (this will cache them in the image)
+RUN echo "🚀 Pre-downloading models during image build..." && \
+    python download_models.py
+
+# Set environment variable to skip model preloading in app since they're already cached
+ENV PRELOAD_MODELS=false
+
+# Expose port
+EXPOSE 7860
+
+# Run the application
+CMD ["python", "app.py"]
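The build step above runs `download_models.py`, whose body is not shown in this commit view. A plausible minimal sketch of such a script, assuming the four models listed in MODEL_SETUP.md and the standard sentence-transformers / transformers loaders:

```python
# Hypothetical reconstruction; the real download_models.py may differ.
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import pipeline

EMBEDDING_MODELS = ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-m3"]

for name in EMBEDDING_MODELS:
    print(f"Downloading embedding model {name} ...")
    SentenceTransformer(name)  # caches into ~/.cache/huggingface

print("Downloading cross-encoder ...")
CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

print("Downloading zero-shot classifier ...")
pipeline("zero-shot-classification",
         model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0")

print("All models cached.")
```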
MODEL_SETUP.md
ADDED
@@ -0,0 +1,152 @@
+# Model Pre-loading Setup Guide
+
+This guide explains how to set up the Policy Analysis application with pre-downloaded models to reduce inference latency.
+
+## Overview
+
+The application uses several ML models:
+- **Embedding Models**: `sentence-transformers/all-MiniLM-L6-v2`, `BAAI/bge-m3`
+- **Cross-Encoder**: `cross-encoder/ms-marco-MiniLM-L-6-v2`
+- **Zero-shot Classification**: `MoritzLaurer/deberta-v3-base-zeroshot-v2.0`
+
+## Deployment Options
+
+### Option 1: Docker Deployment (Recommended)
+
+Models are automatically downloaded during the Docker image build process:
+
+```bash
+# Build and run with docker-compose
+docker-compose up --build
+
+# Or build and run manually
+docker build -t policy-analysis .
+docker run -p 7860:7860 policy-analysis
+```
+
+**Benefits:**
+- Models are cached in the Docker image
+- No download time during runtime
+- Consistent deployment across environments
+
+### Option 2: Manual Model Pre-loading
+
+If not using Docker, run the model downloader script before starting the application:
+
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# Download all models (one-time setup)
+python download_models.py
+
+# Start the application
+python app.py
+```
+
+### Option 3: Startup Script
+
+Use the startup script that automatically downloads models if needed:
+
+```bash
+python startup.py
+```
+
+## Environment Variables
+
+- `PRELOAD_MODELS=true` (default): Pre-load models in app.py
+- `PRELOAD_MODELS=false`: Skip model pre-loading (useful when models are already cached)
+
+## Model Storage
+
+Models are cached in:
+- **Linux/Mac**: `~/.cache/huggingface/`
+- **Windows**: `%USERPROFILE%\.cache\huggingface\`
+
+## Deployment Best Practices
+
+### 1. For Production Deployments
+
+```bash
+# Build Docker image with models pre-cached
+docker build -t policy-analysis:latest .
+
+# Deploy with persistent model cache
+docker-compose up -d
+```
+
+### 2. For Development
+
+```bash
+# Download models once
+python download_models.py
+
+# Start development server
+python app.py
+```
+
+### 3. For Cloud Deployments
+
+When deploying to cloud platforms (AWS, GCP, Azure):
+
+1. Use the Dockerfile to ensure models are cached in the image
+2. Consider using a persistent volume for model cache if rebuilding frequently
+3. Set appropriate resource limits (RAM: 4GB+, CPU: 2+ cores)
+
+## Model Download Sizes
+
+Approximate download sizes:
+- `sentence-transformers/all-MiniLM-L6-v2`: ~90MB
+- `BAAI/bge-m3`: ~2.3GB
+- `cross-encoder/ms-marco-MiniLM-L-6-v2`: ~130MB
+- `MoritzLaurer/deberta-v3-base-zeroshot-v2.0`: ~1.5GB
+
+**Total**: ~4GB
+
+## Troubleshooting
+
+### Model Download Fails
+```bash
+# Check internet connection
+# Ensure sufficient disk space (>5GB)
+# Verify HuggingFace Hub access
+
+# Manual download test
+python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
+```
+
+### Memory Issues
+- Ensure at least 8GB RAM available
+- Consider using CPU-only inference for smaller deployments
+- Use model quantization if needed
+
+### Slow First Request
+- Verify models are properly cached
+- Check if `PRELOAD_MODELS=true` is set
+- Monitor GPU/CPU utilization
+
+## Performance Optimization
+
+1. **Model Caching**: Models cached locally = faster loading
+2. **GPU Usage**: Set `device=0` in model configs for GPU acceleration
+3. **Batch Processing**: Process multiple requests together when possible
+4. **Model Quantization**: Use quantized models for edge deployments
+
+## Monitoring
+
+Monitor these metrics:
+- Model loading time
+- Inference latency
+- Memory usage
+- Disk space (for model cache)
+
+## Updates
+
+To update models:
+```bash
+# Clear cache
+rm -rf ~/.cache/huggingface/
+
+# Re-download
+python download_models.py
+```
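The `device=0` tip above maps directly onto the transformers pipeline API (`device=0` selects the first CUDA GPU, `device=-1` forces CPU). A small example using a model from this repo's list; the CUDA availability check is an added convenience, not something this guide prescribes:

```python
import torch
from transformers import pipeline

device = 0 if torch.cuda.is_available() else -1  # fall back to CPU
classifier = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
    device=device,
)
print(classifier("The policy imposes strict penalties.",
                 candidate_labels=["positive", "negative", "neutral"]))
```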
README_MODELS.md
ADDED
@@ -0,0 +1,184 @@
+# Policy Analysis Application - Model Pre-loading Setup
+
+This application has been enhanced with model pre-loading capabilities to significantly reduce inference time during deployment.
+
+## 🚀 Quick Start
+
+### Option 1: Docker Deployment (Recommended)
+```bash
+# Clone the repository
+git clone <your-repo-url>
+cd policy-analysis
+
+# Build and run with Docker
+docker-compose up --build
+```
+
+### Option 2: Manual Setup
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# Download all models (one-time setup)
+python download_models.py
+
+# Test models are working
+python test_models.py
+
+# Start the application
+python app.py
+```
+
+## 📦 What's New
+
+### Files Added:
+- **`download_models.py`** - Downloads all required ML models
+- **`test_models.py`** - Verifies all models are working correctly
+- **`startup.py`** - Startup script with automatic model downloading
+- **`Dockerfile`** - Docker configuration with model pre-caching
+- **`docker-compose.yml`** - Docker Compose setup
+- **`MODEL_SETUP.md`** - Detailed setup documentation
+
+### Files Modified:
+- **`app.py`** - Added model pre-loading functionality
+- **`requirements.txt`** - Added missing dependencies (numpy, requests)
+- **`utils/coherence_bbscore.py`** - Fixed default embedder parameter
+
+## 🤖 Models Used
+
+The application uses these ML models:
+
+| Model | Type | Size | Purpose |
+|-------|------|------|---------|
+| `sentence-transformers/all-MiniLM-L6-v2` | Embedding | ~90MB | Text encoding |
+| `BAAI/bge-m3` | Embedding | ~2.3GB | Advanced text encoding |
+| `cross-encoder/ms-marco-MiniLM-L-6-v2` | Cross-Encoder | ~130MB | Document reranking |
+| `MoritzLaurer/deberta-v3-base-zeroshot-v2.0` | Classification | ~1.5GB | Sentiment analysis |
+
+**Total download size**: ~4GB
+
+## ⚡ Performance Benefits
+
+### Before (without pre-loading):
+- First request: 30-60 seconds (model download + inference)
+- Subsequent requests: 2-5 seconds
+
+### After (with pre-loading):
+- First request: 2-5 seconds
+- Subsequent requests: 2-5 seconds
+
+## 🔧 Configuration
+
+### Environment Variables:
+- `PRELOAD_MODELS=true` (default) - Pre-load models on app startup
+- `PRELOAD_MODELS=false` - Skip pre-loading (useful when models are cached)
+
+### Model Cache Location:
+- **Linux/Mac**: `~/.cache/huggingface/`
+- **Windows**: `%USERPROFILE%\.cache\huggingface\`
+
+## 🐳 Docker Deployment
+
+The Dockerfile automatically downloads models during the build process:
+
+```dockerfile
+# Downloads models and caches them in the image
+RUN python download_models.py
+```
+
+This means:
+- ✅ No download time during container startup
+- ✅ Consistent performance across deployments
+- ✅ Offline inference capability
+
+## 🧪 Testing
+
+Verify everything is working:
+
+```bash
+# Test all models
+python test_models.py
+
+# Expected output:
+# 🧪 Model Verification Test Suite
+# ✅ All tests passed! The application is ready to deploy.
+```
+
+## 📊 Resource Requirements
+
+### Minimum:
+- **RAM**: 8GB
+- **Storage**: 6GB (models + dependencies)
+- **CPU**: 2+ cores
+
+### Recommended:
+- **RAM**: 16GB
+- **Storage**: 10GB
+- **CPU**: 4+ cores
+- **GPU**: Optional (NVIDIA with CUDA support)
+
+## 🚨 Troubleshooting
+
+### Model Download Issues:
+```bash
+# Check connectivity
+curl -I https://huggingface.co
+
+# Check disk space
+df -h
+
+# Manual model test
+python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
+```
+
+### Memory Issues:
+- Reduce model batch sizes
+- Use CPU-only inference: `device=-1`
+- Consider model quantization
+
+### Slow Performance:
+- Verify models are cached locally
+- Check if `PRELOAD_MODELS=true`
+- Monitor CPU/GPU usage
+
+## 📈 Monitoring
+
+Monitor these metrics in production:
+- Model loading time
+- Inference latency
+- Memory usage
+- Cache hit ratio
+
+## 🔄 Updates
+
+To update models:
+```bash
+# Clear cache
+rm -rf ~/.cache/huggingface/
+
+# Re-download
+python download_models.py
+
+# Test
+python test_models.py
+```
+
+## 💡 Tips for Production
+
+1. **Use Docker**: Models are cached in the image
+2. **Persistent Volumes**: Mount model cache for faster rebuilds
+3. **Health Checks**: Monitor model availability
+4. **Resource Limits**: Set appropriate memory/CPU limits
+5. **Load Balancing**: Use multiple instances for high traffic
+
+## 🤝 Contributing
+
+When adding new models:
+1. Add model name to `download_models.py`
+2. Add test case to `test_models.py`
+3. Update documentation
+4. Test thoroughly
+
+---
+
+For detailed setup instructions, see [`MODEL_SETUP.md`](MODEL_SETUP.md).
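`test_models.py` is referenced above but its body is not included in this commit view; a minimal sketch of what such a verification script could look like, using only the loaders the README names:

```python
# Hypothetical reconstruction; the real test_models.py may differ.
from sentence_transformers import SentenceTransformer, CrossEncoder

def test_embedding():
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    vec = model.encode("Kenya renewable energy policy")
    assert vec.shape[-1] == 384  # MiniLM-L6-v2 embedding dimension

def test_reranker():
    ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = ce.predict([("query", "relevant passage"),
                         ("query", "unrelated passage")])
    assert len(scores) == 2

if __name__ == "__main__":
    test_embedding()
    test_reranker()
    print("All tests passed! The application is ready to deploy.")
```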
app.py
CHANGED
@@ -1,36 +1,312 @@
 import gradio as gr
+import requests
+import numpy as np
+import time
+import json
+import os
 
+# Import the utilities with proper error handling
+try:
+    from utils.encoding_input import encode_text
+    from utils.retrieve_n_rerank import retrieve_and_rerank
+    from utils.sentiment_analysis import get_sentiment
+    from utils.coherence_bbscore import coherence_report
+    from utils.loading_embeddings import get_vectorstore
+    from utils.model_generation import build_messages
+    from utils.query_constraints import parse_query_constraints, page_matches, doc_matches
+    from utils.conversation_logging import load_history, log_exchange
+    from langchain.schema import Document
+except ImportError as e:
+    print(f"Import error: {e}")
+    print("Make sure you're running from the correct directory and all dependencies are installed.")
 
+API_KEY = os.getenv("API_KEY", "sk-do-8Hjf0liuGQCoPwglilL49xiqrthMECwjGP_kAjPM53OTOFQczPyfPK8xJc")
+MODEL = "llama3.3-70b-instruct"
+
+# Global settings for sentiment and coherence analysis
+ENABLE_SENTIMENT = True
+ENABLE_COHERENCE = True
+
+# Load persisted history (if any) for memory retention
+PERSISTED_HISTORY = load_history()
+
+def chat_response(message, history):
+    """
+    Generate response for chat interface.
+
+    Args:
+        message: Current user message
+        history: List of [user_message, bot_response] pairs
+    """
+
+    try:
+        # Initialize vectorstore when needed
+        vectorstore = get_vectorstore()
+
+        constraints = parse_query_constraints(message)
+        want_page = constraints.get("page")
+        doc_tokens = constraints.get("doc_tokens", [])
+
+        # Increase initial recall if a specific page is requested
+        base_k = 120 if want_page is not None else 50
+        reranked_results = retrieve_and_rerank(
+            query_text=message,
+            vectorstore=vectorstore,
+            k=base_k,
+            rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+            top_m=40 if want_page is not None else 20,
+            min_score=0.4 if want_page is not None else 0.5,  # relax threshold for page-constrained queries
+            only_docs=False
         )
+
+        if not reranked_results:
+            yield "I'm sorry, I couldn't find any relevant information in the policy documents to answer your question. Could you try rephrasing your question or asking about a different topic?"
+            return
+
+        # Enforce page constraint if present
+        # Document filtering (title tokens)
+        if doc_tokens:
+            reranked_results = [(d, s) for d, s in reranked_results if doc_matches(getattr(d, 'metadata', {}), doc_tokens)]
 
+        # Page filter
+        if want_page is not None:
+            page_filtered = [(d, s) for d, s in reranked_results if page_matches(getattr(d, 'metadata', {}), want_page)]
+            if not page_filtered:
+                # Fallback: exhaustive scan of vectorstore for that page & doc tokens
+                all_docs = []
+                try:
+                    for i in range(len(vectorstore.index_to_docstore_id)):
+                        doc = vectorstore.docstore.search(vectorstore.index_to_docstore_id[i])
+                        meta = getattr(doc, 'metadata', {})
+                        if doc_tokens and not doc_matches(meta, doc_tokens):
+                            continue
+                        if page_matches(meta, want_page):
+                            all_docs.append(doc)
+                except Exception:
+                    pass
+                if all_docs:
+                    # treat as retrieved with neutral score
+                    reranked_results = [(d, 0.0) for d in all_docs]
+                    page_filtered = reranked_results
+            else:
+                reranked_results = page_filtered
 
+        # If still nothing after fallback, return not found
+        if want_page is not None and (not reranked_results or (doc_tokens and not any(page_matches(getattr(d, 'metadata', {}), want_page) for d, _ in reranked_results))):
+            yield "Not found in sources."
+            return
 
+        top_docs = [doc for doc, score in reranked_results]
 
+        # Perform sentiment and coherence analysis if enabled
+        sentiment_rollup = get_sentiment(top_docs) if ENABLE_SENTIMENT else {}
+        coherence_report_ = coherence_report(reranked_results=top_docs, input_text=message) if ENABLE_COHERENCE else ""
+
+        # Build base messages from strict template
+        allow_meta = None
+        if want_page is not None and doc_tokens:
+            # simple doc_id alias from tokens joined
+            allow_meta = {"doc_id": "_".join(doc_tokens), "pages": [want_page]}
+        base_messages = build_messages(
+            query=message,
+            top_docs=top_docs,
+            task_mode="verbatim_sentiment",
+            sentiment_rollup=sentiment_rollup if ENABLE_SENTIMENT else {},
+            coherence_report=coherence_report_ if ENABLE_COHERENCE else "",
+            allowlist_meta=allow_meta
+        )
+
+        # Insert recent history (excluding system + final user already in base) after system message
+        messages = [base_messages[0]]  # system
+        # Combine persisted history (only at first call when provided history empty)
+        if not history and PERSISTED_HISTORY:
+            history.extend(PERSISTED_HISTORY[-6:])  # seed last 6 past exchanges
+        recent_history = history[-6:] if len(history) > 6 else history
+        for u, a in recent_history:
+            messages.append({"role": "user", "content": u})
+            messages.append({"role": "assistant", "content": a})
+        messages.append(base_messages[1])  # current user prompt (template)
+
+        # Stream response from the API
+        response = ""
+        for chunk in stream_llm_response(messages):
+            response += chunk
+            yield response
+        # After final response, log exchange persistently
+        try:
+            log_exchange(message, response, meta={"pages": [d.metadata.get('page_label') if hasattr(d, 'metadata') else None for d in top_docs]})
+        except Exception as log_err:
+            print(f"Logging error: {log_err}")
+
+    except Exception as e:
+        error_msg = f"I encountered an error while processing your request: {str(e)}"
+        yield error_msg
 
+## Removed custom prompt builder in favor of strict template usage
+
+def stream_llm_response(messages):
+    """Stream response from the LLM API."""
+    headers = {
+        "Authorization": f"Bearer {API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    data = {
+        "model": MODEL,
+        "messages": messages,
+        "temperature": 0.2,
+        "stream": True,
+        "max_tokens": 2000
+    }
+
+    try:
+        with requests.post("https://inference.do-ai.run/v1/chat/completions",
+                           headers=headers, json=data, stream=True, timeout=30) as r:
+            if r.status_code != 200:
+                yield f"[ERROR] API returned status {r.status_code}: {r.text}"
+                return
+
+            for line in r.iter_lines(decode_unicode=True):
+                if not line or line.strip() == "data: [DONE]":
+                    continue
+                if line.startswith("data: "):
+                    line = line[len("data: "):]
+
+                try:
+                    chunk = json.loads(line)
+                    delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
+                    if delta:
+                        yield delta
+                        time.sleep(0.01)  # Small delay for smooth streaming
+                except json.JSONDecodeError:
+                    continue
+                except Exception as e:
+                    print(f"Streaming error: {e}")
+                    continue
+
+    except requests.exceptions.RequestException as e:
+        yield f"[ERROR] Network error: {str(e)}"
+    except Exception as e:
+        yield f"[ERROR] Unexpected error: {str(e)}"
+
+def update_sentiment_setting(enable):
+    """Update global sentiment analysis setting."""
+    global ENABLE_SENTIMENT
+    ENABLE_SENTIMENT = enable
+    return f"✅ Sentiment analysis {'enabled' if enable else 'disabled'}"
+
+def update_coherence_setting(enable):
+    """Update global coherence analysis setting."""
+    global ENABLE_COHERENCE
+    ENABLE_COHERENCE = enable
+    return f"✅ Coherence analysis {'enabled' if enable else 'disabled'}"
+
+# Create the chat interface
+with gr.Blocks(title="Kenya Policy Assistant - Chat", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("""
+    # 🏛️ Kenya Policy Assistant - Interactive Chat
+    Ask questions about Kenya's policies and have a conversation! I can help you understand policy documents with sentiment and coherence analysis.
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            # Settings row at the top
+            with gr.Row():
+                sentiment_toggle = gr.Checkbox(
+                    label="📊 Sentiment Analysis",
+                    value=True,
+                    info="Analyze tone and sentiment of policy documents"
+                )
+                coherence_toggle = gr.Checkbox(
+                    label="🔍 Coherence Analysis",
+                    value=True,
+                    info="Check coherence and consistency of retrieved documents"
+                )
+
+            # Main chat interface
+            chatbot = gr.Chatbot(
+                height=500,
+                bubble_full_width=False,
+                show_copy_button=True,
+                show_share_button=True,
+                avatar_images=("👤", "🤖"),
+                value=PERSISTED_HISTORY  # seed prior memory
+            )
+
+            msg = gr.Textbox(
+                placeholder="Ask me about Kenya's policies... (e.g., 'What are the renewable energy regulations?')",
+                label="Your Question",
+                lines=2
+            )
+
+            with gr.Row():
+                submit_btn = gr.Button("📤 Send", variant="primary")
+                clear_btn = gr.Button("🗑️ Clear Chat")
+
+        with gr.Column(scale=1):
+            gr.Markdown("""
+            ### 💡 Chat Tips
+            - Ask specific questions about Kenya policies
+            - Ask follow-up questions based on responses
+            - Reference previous answers: *"What does this mean?"*
+            - Request elaboration: *"Can you explain more?"*
+
+            ### 📝 Example Questions
+            - *"What are Kenya's renewable energy policies?"*
+            - *"Tell me about water management regulations"*
+            - *"What penalties exist for environmental violations?"*
+            - *"How does this relate to what you mentioned earlier?"*
+
+            ### ⚙️ Analysis Features
+            **Sentiment Analysis**: Understands the tone and intent of policy text
+
+            **Coherence Analysis**: Checks if retrieved documents are relevant and consistent
+            """)
+
+            with gr.Accordion("📊 Analysis Status", open=False):
+                sentiment_status = gr.Textbox(
+                    value="✅ Sentiment analysis enabled",
+                    label="Sentiment Status",
+                    interactive=False
+                )
+                coherence_status = gr.Textbox(
+                    value="✅ Coherence analysis enabled",
+                    label="Coherence Status",
+                    interactive=False
+                )
+
+    # Chat functionality
+    def respond(message, history):
+        if message.strip():
+            bot_message = chat_response(message, history)
+            history.append([message, ""])
+
+            for partial_response in bot_message:
+                history[-1][1] = partial_response
+                yield history, ""
+        else:
+            yield history, ""
+
+    submit_btn.click(respond, [msg, chatbot], [chatbot, msg])
+    msg.submit(respond, [msg, chatbot], [chatbot, msg])
+    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
+
+    # Update settings when toggles change
+    sentiment_toggle.change(
+        fn=update_sentiment_setting,
+        inputs=[sentiment_toggle],
+        outputs=[sentiment_status]
+    )
+
+    coherence_toggle.change(
+        fn=update_coherence_setting,
+        inputs=[coherence_toggle],
+        outputs=[coherence_status]
+    )
 
 if __name__ == "__main__":
+    print("🚀 Starting Kenya Policy Assistant Chat...")
+    demo.queue(max_size=20).launch(
+        share=True,
+        debug=True,
+        server_name="0.0.0.0",
+        server_port=7860
+    )
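app.py imports `parse_query_constraints`, `page_matches`, and `doc_matches` from `utils/query_constraints.py`, which is not included in this commit view. An illustrative sketch of the contract app.py appears to rely on (a page hint plus document-title tokens); names and regexes here are guesses, not the repo's actual helpers:

```python
import re

def parse_query_constraints(query: str) -> dict:
    """Rough sketch: pull 'page N' and quoted doc names from a query."""
    m = re.search(r"\bpage\s+(\d+)\b", query, re.IGNORECASE)
    return {
        "page": int(m.group(1)) if m else None,
        "doc_tokens": re.findall(r'"([^"]+)"', query),
    }

def page_matches(metadata: dict, want_page: int) -> bool:
    """True when the chunk's page metadata equals the requested page."""
    return str(metadata.get("page_label", metadata.get("page"))) == str(want_page)

print(parse_query_constraints('What does "Energy Act" say on page 12?'))
# -> {'page': 12, 'doc_tokens': ['Energy Act']}
```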
app_chat.py
ADDED
@@ -0,0 +1,295 @@
+import gradio as gr
+import requests
+import numpy as np
+import time
+import json
+import os
+
+# Import the utilities with proper error handling
+try:
+    from utils.encoding_input import encode_text
+    from utils.retrieve_n_rerank import retrieve_and_rerank
+    from utils.sentiment_analysis import get_sentiment
+    from utils.coherence_bbscore import coherence_report
+    from utils.loading_embeddings import get_vectorstore
+    from utils.model_generation import build_messages
+except ImportError as e:
+    print(f"Import error: {e}")
+    print("Make sure you're running from the correct directory and all dependencies are installed.")
+
+API_KEY = os.getenv("API_KEY", "sk-do-8Hjf0liuGQCoPwglilL49xiqrthMECwjGP_kAjPM53OTOFQczPyfPK8xJc")
+MODEL = "llama3.3-70b-instruct"
+
+# Global settings for sentiment and coherence analysis
+ENABLE_SENTIMENT = True
+ENABLE_COHERENCE = True
+
+def chat_response(message, history):
+    """
+    Generate response for chat interface.
+
+    Args:
+        message: Current user message
+        history: List of [user_message, bot_response] pairs
+    """
+
+    try:
+        # Initialize vectorstore when needed
+        vectorstore = get_vectorstore()
+
+        # Retrieve and rerank documents
+        reranked_results = retrieve_and_rerank(
+            query_text=message,
+            vectorstore=vectorstore,
+            k=50,  # number of initial documents to retrieve
+            rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+            top_m=20,  # number of documents to return after reranking
+            min_score=0.5,  # minimum score for reranked documents
+            only_docs=False  # return both documents and scores
+        )
+
+        if not reranked_results:
+            yield "I'm sorry, I couldn't find any relevant information in the policy documents to answer your question. Could you try rephrasing your question or asking about a different topic?"
+            return
+
+        top_docs = [doc for doc, score in reranked_results]
+
+        # Perform sentiment and coherence analysis if enabled
+        sentiment_rollup = get_sentiment(top_docs) if ENABLE_SENTIMENT else {}
+        coherence_report_ = coherence_report(reranked_results=top_docs, input_text=message) if ENABLE_COHERENCE else ""
+
+        # Build messages for the LLM, including conversation history
+        messages = build_messages_with_history(
+            query=message,
+            history=history,
+            top_docs=top_docs,
+            task_mode="verbatim_sentiment",
+            sentiment_rollup=sentiment_rollup,
+            coherence_report=coherence_report_,
+        )
+
+        # Stream response from the API
+        response = ""
+        for chunk in stream_llm_response(messages):
+            response += chunk
+            yield response
+
+    except Exception as e:
+        error_msg = f"I encountered an error while processing your request: {str(e)}"
+        yield error_msg
+
+def build_messages_with_history(query, history, top_docs, task_mode, sentiment_rollup, coherence_report):
+    """Build messages including conversation history for better context."""
+
+    # System message
+    system_msg = (
+        "You are a compliance-grade policy analyst assistant specializing in Kenya policy documents. "
+        "Your job is to return precise, fact-grounded responses based on the provided policy documents. "
+        "Avoid hallucinations. Base everything strictly on the content provided. "
+        "Maintain conversation context from previous exchanges when relevant. "
+        "If sentiment or coherence analysis is not available, do not mention it in the response."
+    )
+
+    messages = [{"role": "system", "content": system_msg}]
+
+    # Add conversation history (keep last 4 exchanges to maintain context without exceeding limits)
+    recent_history = history[-4:] if len(history) > 4 else history
+    for user_msg, bot_msg in recent_history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": bot_msg})
+
+    # Build context from retrieved documents
+    context_block = "\n\n".join([
+        f"**Source: {getattr(doc, 'metadata', {}).get('source', 'Unknown')} "
+        f"(Page {getattr(doc, 'metadata', {}).get('page', 'Unknown')})**\n"
+        f"{doc.page_content}\n"
+        for doc in top_docs[:10]  # Limit to top 10 docs to avoid token limits
+    ])
+
+    # Current user query with context
+    current_query = f"""
+Query: {query}
+
+Based on the following policy documents, please provide:
+1) **Quoted Policy Excerpts**: Quote key policy content directly. Cite the source using filename and page.
+2) **Analysis**: Explain the policy implications in clear terms.
+"""
+
+    if sentiment_rollup:
+        current_query += f"\n3) **Sentiment Summary**: {sentiment_rollup}"
+
+    if coherence_report:
+        current_query += f"\n4) **Coherence Assessment**: {coherence_report}"
+
+    current_query += f"\n\nContext Sources:\n{context_block}"
+
+    messages.append({"role": "user", "content": current_query})
+
+    return messages
+
+def stream_llm_response(messages):
+    """Stream response from the LLM API."""
+    headers = {
+        "Authorization": f"Bearer {API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    data = {
+        "model": MODEL,
+        "messages": messages,
+        "temperature": 0.2,
+        "stream": True,
+        "max_tokens": 2000
+    }
+
+    try:
+        with requests.post("https://inference.do-ai.run/v1/chat/completions",
+                           headers=headers, json=data, stream=True, timeout=30) as r:
+            if r.status_code != 200:
+                yield f"[ERROR] API returned status {r.status_code}: {r.text}"
+                return
+
+            for line in r.iter_lines(decode_unicode=True):
+                if not line or line.strip() == "data: [DONE]":
+                    continue
+                if line.startswith("data: "):
+                    line = line[len("data: "):]
+
+                try:
+                    chunk = json.loads(line)
+                    delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
+                    if delta:
+                        yield delta
+                        time.sleep(0.01)  # Small delay for smooth streaming
+                except json.JSONDecodeError:
+                    continue
+                except Exception as e:
+                    print(f"Streaming error: {e}")
+                    continue
+
+    except requests.exceptions.RequestException as e:
+        yield f"[ERROR] Network error: {str(e)}"
+    except Exception as e:
+        yield f"[ERROR] Unexpected error: {str(e)}"
+
+def update_sentiment_setting(enable):
+    """Update global sentiment analysis setting."""
+    global ENABLE_SENTIMENT
+    ENABLE_SENTIMENT = enable
+    return f"✅ Sentiment analysis {'enabled' if enable else 'disabled'}"
+
+def update_coherence_setting(enable):
+    """Update global coherence analysis setting."""
+    global ENABLE_COHERENCE
+    ENABLE_COHERENCE = enable
+    return f"✅ Coherence analysis {'enabled' if enable else 'disabled'}"
+
+# Create the chat interface
+with gr.Blocks(title="Kenya Policy Assistant - Chat", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("""
+    # 🏛️ Kenya Policy Assistant - Interactive Chat
+    Ask questions about Kenya's policies and have a conversation! I can help you understand policy documents with sentiment and coherence analysis.
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            # Settings row at the top
+            with gr.Row():
+                sentiment_toggle = gr.Checkbox(
+                    label="📊 Sentiment Analysis",
+                    value=True,
+                    info="Analyze tone and sentiment of policy documents"
+                )
+                coherence_toggle = gr.Checkbox(
+                    label="🔍 Coherence Analysis",
+                    value=True,
+                    info="Check coherence and consistency of retrieved documents"
+                )
+
+            # Main chat interface
+            chatbot = gr.Chatbot(
+                height=500,
+                bubble_full_width=False,
+                show_copy_button=True,
+                show_share_button=True,
+                avatar_images=("👤", "🤖")
+            )
+
+            msg = gr.Textbox(
+                placeholder="Ask me about Kenya's policies... (e.g., 'What are the renewable energy regulations?')",
+                label="Your Question",
+                lines=2
+            )
+
+            with gr.Row():
+                submit_btn = gr.Button("📤 Send", variant="primary")
+                clear_btn = gr.Button("🗑️ Clear Chat")
+
+        with gr.Column(scale=1):
+            gr.Markdown("""
+            ### 💡 Chat Tips
+            - Ask specific questions about Kenya policies
+            - Ask follow-up questions based on responses
+            - Reference previous answers: *"What does this mean?"*
+            - Request elaboration: *"Can you explain more?"*
+
+            ### 📝 Example Questions
+            - *"What are Kenya's renewable energy policies?"*
+            - *"Tell me about water management regulations"*
+            - *"What penalties exist for environmental violations?"*
+            - *"How does this relate to what you mentioned earlier?"*
+
+            ### ⚙️ Analysis Features
+            **Sentiment Analysis**: Understands the tone and intent of policy text
+
+            **Coherence Analysis**: Checks if retrieved documents are relevant and consistent
+            """)
+
+            with gr.Accordion("📊 Analysis Status", open=False):
+                sentiment_status = gr.Textbox(
+                    value="✅ Sentiment analysis enabled",
+                    label="Sentiment Status",
+                    interactive=False
+                )
+                coherence_status = gr.Textbox(
+                    value="✅ Coherence analysis enabled",
+                    label="Coherence Status",
+                    interactive=False
+                )
+
+    # Chat functionality
+    def respond(message, history):
+        if message.strip():
+            bot_message = chat_response(message, history)
+            history.append([message, ""])
+
+            for partial_response in bot_message:
+                history[-1][1] = partial_response
+                yield history, ""
+        else:
+            yield history, ""
+
+    submit_btn.click(respond, [msg, chatbot], [chatbot, msg])
+    msg.submit(respond, [msg, chatbot], [chatbot, msg])
+    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
+
+    # Update settings when toggles change
+    sentiment_toggle.change(
+        fn=update_sentiment_setting,
+        inputs=[sentiment_toggle],
+        outputs=[sentiment_status]
+    )
+
+    coherence_toggle.change(
+        fn=update_coherence_setting,
+        inputs=[coherence_toggle],
+        outputs=[coherence_status]
+    )
+
+if __name__ == "__main__":
+    print("🚀 Starting Kenya Policy Assistant Chat...")
+    demo.queue(max_size=20).launch(
+        share=True,
+        debug=True,
+        server_name="0.0.0.0",
+        server_port=7860
+    )
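A quick illustration of `build_messages_with_history` as defined above, using a stand-in document class (the `Doc` stub below mimics the `page_content`/`metadata` fields the function reads; it is not the repo's Document type):

```python
from app_chat import build_messages_with_history  # note: importing app_chat also builds its Gradio Blocks

class Doc:
    def __init__(self, text, source, page):
        self.page_content = text
        self.metadata = {"source": source, "page": page}

docs = [Doc("Licensing requires regulator approval.", "energy_act.pdf", 12)]
history = [("What is EPRA?", "EPRA is Kenya's energy regulator.")]

msgs = build_messages_with_history(
    query="How do I get a license?",
    history=history,
    top_docs=docs,
    task_mode="verbatim_sentiment",
    sentiment_rollup={},
    coherence_report="",
)
# msgs[0] is the system prompt, the middle turns replay recent history,
# and the final user message embeds the retrieved context block.
print(len(msgs))  # 4
```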
app_original.py
ADDED
@@ -0,0 +1,36 @@
+import gradio as gr
+from utils.generation_streaming import generate_response_stream
+
+with gr.Blocks(title="Policy Assistant") as demo:
+    gr.Markdown("### ⚡ Kenya Policy QA – Verbatim, Sentiment, and Coherence")
+
+    with gr.Row():
+        input_box = gr.Textbox(
+            label="Enter your policy question",
+            lines=2,
+            placeholder="e.g., What are the objectives of Kenya's energy policies?"
+        )
+
+    with gr.Row():
+        sentiment_toggle = gr.Checkbox(label="Enable Sentiment Analysis", value=True)
+        coherence_toggle = gr.Checkbox(label="Enable Coherence Check", value=True)
+
+    output_box = gr.Textbox(label="LLM Response", lines=25, interactive=False)
+    # output_box = gr.Textbox(label="LLM Response", lines=25, interactive=False, stream=True)
+
+    run_btn = gr.Button("🚀 Generate")
+
+    run_btn.click(
+        fn=generate_response_stream,
+        inputs=[input_box, sentiment_toggle, coherence_toggle],
+        outputs=output_box
+    )
+
+    input_box.submit(
+        fn=generate_response_stream,
+        inputs=[input_box, sentiment_toggle, coherence_toggle],
+        outputs=output_box
+    )
+
+if __name__ == "__main__":
+    demo.queue().launch(share=True, debug=True)
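`utils/generation_streaming.py` is not part of this commit view. From the wiring above, `generate_response_stream` takes the question plus the two toggle values and feeds a Textbox output; a hedged sketch of what such a generator could look like, with placeholder logic only:

```python
def generate_response_stream(question, enable_sentiment, enable_coherence):
    """Illustrative stub: yield the response incrementally so Gradio can stream it."""
    text = ""
    for token in ["Retrieving", " policy", " excerpts", "..."]:  # placeholder tokens
        text += token
        yield text  # Gradio replaces the Textbox contents on each yield
```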
app_reserve.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import uuid
import time
import json
import requests
import gradio as gr
import utils.helpers as helpers
from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf

# ========= Config & Globals =========
with open("config.json") as f:
    config = json.load(f)

DO_API_KEY = config["do_token"]
token_ = config['token']
HF_TOKEN = 'hf_' + token_
session_id = f"{int(time.time())}-{uuid.uuid4().hex[:8]}"
helpers.session_id = session_id
BASE_URL = "https://inference.do-ai.run/v1"
UPLOAD_INTERVAL = 5

# ========= Inference Utilities =========
def _auth_headers():
    return {"Authorization": f"Bearer {DO_API_KEY}", "Content-Type": "application/json"}

def list_models():
    try:
        r = requests.get(f"{BASE_URL}/models", headers=_auth_headers(), timeout=15)
        r.raise_for_status()
        data = r.json().get("data", [])
        ids = [m["id"] for m in data]
        if ids:
            return ids
    except Exception as e:
        print(f"⚠️ list_models failed: {e}")
    return ["llama3.3-70b-instruct"]

def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    url = f"{BASE_URL}/chat/completions"
    if not model_id:
        model_id = list_models()[0]
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    for attempt in range(3):
        try:
            resp = requests.post(url, headers=_auth_headers(), json=payload, timeout=30)
            if resp.status_code == 404:
                # Stale model ID: refresh the list and retry with a live one
                ids = list_models()
                if model_id not in ids and ids:
                    payload["model"] = ids[0]
                    continue
            resp.raise_for_status()
            j = resp.json()
            return j["choices"][0]["message"]["content"].strip()
        except requests.HTTPError as e:
            msg = getattr(e.response, "text", str(e))
            raise RuntimeError(f"Inference error ({e.response.status_code}): {msg}") from e
        except requests.RequestException as e:
            if attempt == 2:
                raise
    raise RuntimeError("Exhausted retries")

def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    url = f"{BASE_URL}/chat/completions"
    if not model_id:
        model_id = list_models()[0]
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }

    # Create a generator that yields tokens
    try:
        with requests.post(url, headers=_auth_headers(), json=payload, stream=True, timeout=120) as r:
            if r.status_code != 200:
                try:
                    err_txt = r.text
                except Exception:
                    err_txt = "<no body>"
                raise RuntimeError(f"HTTP {r.status_code}: {err_txt}")

            buffer = ""
            for line in r.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data:'):
                        data = decoded_line[5:].strip()
                        if data == '[DONE]':
                            break
                        try:
                            json_data = json.loads(data)
                            if 'choices' in json_data:
                                for choice in json_data['choices']:
                                    if 'delta' in choice and 'content' in choice['delta']:
                                        content = choice['delta']['content']
                                        buffer += content
                                        yield content
                        except json.JSONDecodeError:
                            continue
            if not buffer:
                yield "No response received from the model."
    except Exception as e:
        raise RuntimeError(f"Streaming error: {str(e)}")

def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    url = f"{BASE_URL}/chat/completions"
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    r = requests.post(url, headers=_auth_headers(), json=payload, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"HTTP {r.status_code}: {r.text}")
    j = r.json()
    return j["choices"][0]["message"]["content"].strip()

# ========= Lightweight Intent Detection =========
def detect_intent(model_id, message: str) -> str:
    try:
        out = gradient_request(
            model_id,
            f"Classify as 'small_talk' or 'info_query': {message}",
            max_tokens=8,
            temperature=0.0,
            top_p=1.0,
        )
        return "small_talk" if "small_talk" in out.lower() else "info_query"
    except Exception as e:
        print(f"⚠️ detect_intent failed: {e}")
        return "info_query"

# ========= App Logic (Gradio Blocks) =========
with gr.Blocks(title="Gradient AI Chat") as demo:
    # Keep a reactive turn counter in session state
    turn_counter = gr.State(0)

    gr.Markdown("## Gradient AI Chat")
    gr.Markdown("Select a model and ask your question.")

    # Model dropdown will be populated at runtime with live IDs
    with gr.Row():
        model_drop = gr.Dropdown(choices=[], label="Select Model")
        system_msg = gr.Textbox(
            value="You are a faithful assistant. Use only the provided context.",
            label="System message"
        )

    with gr.Row():
        max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")
        temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")

    # Use tuples to silence deprecation warning in current Gradio
    chatbot = gr.Chatbot(height=500, type="tuples")
    msg = gr.Textbox(label="Your message")

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.ClearButton([msg, chatbot])

    examples = gr.Examples(
        examples=[
            ["What are the advantages of llama3.3-70b-instruct?"],
            ["Explain how DeepSeek R1 Distill Llama 70B handles reasoning tasks."],
            ["What is the difference between llama3.3-70b-instruct and qwen2.5-32b-instruct?"],
        ],
        inputs=[msg]
    )

    # --- Load models into dropdown at startup
    def load_models():
        ids = list_models()
        default = ids[0] if ids else None
        return gr.Dropdown(choices=ids, value=default)

    demo.load(load_models, outputs=[model_drop])

    # Optional warm-up so the first user doesn't pay the cold-start cost
    def warmup():
        try:
            _ = retrieve_context("warmup", p=1, threshold=0.0)
        except Exception as e:
            print(f"⚠️ warmup failed: {e}")

    demo.load(warmup, outputs=None)

    # --- Event handlers
    def user(user_message, chat_history):
        # Seed a new assistant message for streaming
        return "", (chat_history + [[user_message, ""]])

    def bot(chat_history, current_turn_count, model_id, system_message, max_tokens, temperature, top_p):
        user_message = chat_history[-1][0]

        # Build prompt
        intent = detect_intent(model_id, user_message)
        if intent == "small_talk":
            full_prompt = f"[System]: Friendly chat.\n[User]: {user_message}\n[Assistant]: "
        else:
            try:
                context = retrieve_context(user_message, p=5, threshold=0.5)
            except Exception as e:
                print(f"⚠️ retrieve_context failed: {e}")
                context = ""
            full_prompt = (
                f"[System]: {system_message}\n"
                "Use only the provided context. Quote verbatim; no inference.\n\n"
                f"Context:\n{context}\n\nQuestion: {user_message}\n"
            )

        # Initialize the assistant message to an empty string and update chat history
        chat_history[-1][1] = ""
        yield chat_history, current_turn_count

        # Attempt to stream the response
        try:
            received_any = False
            for token in gradient_stream(model_id, full_prompt, max_tokens, temperature, top_p):
                if token:  # Skip empty tokens
                    received_any = True
                    chat_history[-1][1] += token
                    yield chat_history, current_turn_count
            # If we didn't receive any tokens, fall back to non-streaming
            if not received_any:
                raise RuntimeError("Streaming returned no tokens; falling back.")
        except Exception as e:
            print(f"⚠️ Streaming failed: {e}")
            try:
                # Fall back to non-streaming
                response = gradient_complete(model_id, full_prompt, max_tokens, temperature, top_p)
                chat_history[-1][1] = response
                yield chat_history, current_turn_count
            except Exception as e2:
                chat_history[-1][1] = f"⚠️ Inference failed: {e2}"
                yield chat_history, current_turn_count
                return

        # After a successful response, log and update the turn counter
        try:
            log_interaction_hf(user_message, chat_history[-1][1])
        except Exception as e:
            print(f"⚠️ log_interaction_hf failed: {e}")

        new_turn_count = (current_turn_count or 0) + 1
        # Periodically upload logs
        if new_turn_count % UPLOAD_INTERVAL == 0:
            try:
                upload_log_to_hf(HF_TOKEN)
            except Exception as e:
                print(f"❌ Log upload failed: {e}")

        # Update the state with the new turn count
        yield chat_history, new_turn_count

    # Wiring (streaming generators supported)
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=True
    ).then(
        bot,
        [chatbot, turn_counter, model_drop, system_msg, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot, turn_counter],
        queue=True
    )

    submit_btn.click(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=True
    ).then(
        bot,
        [chatbot, turn_counter, model_drop, system_msg, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot, turn_counter],
        queue=True
    )

if __name__ == "__main__":
    # On HF Spaces, don't use share=True. Also disable the API page to avoid schema churn.
    demo.launch(show_api=False)
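For a quick sanity check, the helpers above can be exercised directly, mirroring the stream-then-fallback pattern `bot()` uses. This is a hypothetical sketch, not part of the commit: it assumes a valid `do_token` in `config.json` and that `list_models`, `gradient_stream`, and `gradient_complete` are in scope (for example, pasted into a REPL after the definitions).

```python
# Hypothetical smoke test -- mirrors bot()'s stream-then-fallback logic.
model = list_models()[0]  # first live model ID, e.g. "llama3.3-70b-instruct"
prompt = "Summarize Kenya's renewable energy goals in one sentence."

try:
    # Preferred path: print tokens as the server streams them
    for token in gradient_stream(model, prompt, max_tokens=128):
        print(token, end="", flush=True)
    print()
except RuntimeError as err:
    # Fallback path: one blocking completion, just as bot() does on stream failure
    print(f"Streaming failed ({err}); using gradient_complete instead")
    print(gradient_complete(model, prompt, max_tokens=128))
```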
chat_app.py
ADDED
@@ -0,0 +1,284 @@
import gradio as gr
import requests
import numpy as np
import time
import json
import os

# Import the utilities with proper error handling
try:
    from utils.encoding_input import encode_text
    from utils.retrieve_n_rerank import retrieve_and_rerank
    from utils.sentiment_analysis import get_sentiment
    from utils.coherence_bbscore import coherence_report
    from utils.loading_embeddings import get_vectorstore
    from utils.model_generation import build_messages
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure you're running from the correct directory and all dependencies are installed.")

API_KEY = os.getenv("API_KEY", "sk-do-8Hjf0liuGQCoPwglilL49xiqrthMECwjGP_kAjPM53OTOFQczPyfPK8xJc")
MODEL = "llama3.3-70b-instruct"

# Global settings for sentiment and coherence analysis
ENABLE_SENTIMENT = True
ENABLE_COHERENCE = True

def chat_response(message, history, enable_sentiment, enable_coherence):
    """
    Generate a response for the chat interface.

    Args:
        message: Current user message
        history: List of [user_message, bot_response] pairs
        enable_sentiment: Whether to enable sentiment analysis
        enable_coherence: Whether to enable coherence analysis
    """
    try:
        # Initialize vectorstore when needed
        vectorstore = get_vectorstore()

        # Retrieve and rerank documents
        reranked_results = retrieve_and_rerank(
            query_text=message,
            vectorstore=vectorstore,
            k=50,  # number of initial documents to retrieve
            rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
            top_m=20,  # number of documents to return after reranking
            min_score=0.5,  # minimum score for reranked documents
            only_docs=False  # return both documents and scores
        )

        if not reranked_results:
            # This function is a generator, so the apology must be yielded, not returned
            yield ("I'm sorry, I couldn't find any relevant information in the policy documents "
                   "to answer your question. Could you try rephrasing your question or asking "
                   "about a different topic?")
            return

        top_docs = [doc for doc, score in reranked_results]

        # Perform sentiment and coherence analysis if enabled
        sentiment_rollup = get_sentiment(top_docs) if enable_sentiment else {}
        coherence_report_ = coherence_report(reranked_results=top_docs, input_text=message) if enable_coherence else ""

        # Build messages for the LLM, including conversation history
        messages = build_messages_with_history(
            query=message,
            history=history,
            top_docs=top_docs,
            task_mode="verbatim_sentiment",
            sentiment_rollup=sentiment_rollup,
            coherence_report=coherence_report_,
        )

        # Stream response from the API
        response = ""
        for chunk in stream_llm_response(messages):
            response += chunk
            yield response

    except Exception as e:
        error_msg = f"I encountered an error while processing your request: {str(e)}"
        yield error_msg

def build_messages_with_history(query, history, top_docs, task_mode, sentiment_rollup, coherence_report):
    """Build messages including conversation history for better context."""

    # System message
    system_msg = (
        "You are a compliance-grade policy analyst assistant specializing in Kenya policy documents. "
        "Your job is to return precise, fact-grounded responses based on the provided policy documents. "
        "Avoid hallucinations. Base everything strictly on the content provided. "
        "Maintain conversation context from previous exchanges when relevant. "
        "If sentiment or coherence analysis is not available, do not mention it in the response."
    )

    messages = [{"role": "system", "content": system_msg}]

    # Add conversation history (keep the last 4 exchanges to maintain context without exceeding limits)
    recent_history = history[-4:] if len(history) > 4 else history
    for user_msg, bot_msg in recent_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})

    # Build context from retrieved documents
    context_block = "\n\n".join([
        f"**Source: {getattr(doc, 'metadata', {}).get('source', 'Unknown')} "
        f"(Page {getattr(doc, 'metadata', {}).get('page', 'Unknown')})**\n"
        f"{doc.page_content}\n"
        for doc in top_docs[:10]  # Limit to top 10 docs to avoid token limits
    ])

    # Current user query with context
    current_query = f"""
Query: {query}

Based on the following policy documents, please provide:
1) **Quoted Policy Excerpts**: Quote key policy content directly. Cite the source using filename and page.
2) **Analysis**: Explain the policy implications in clear terms.
"""

    if sentiment_rollup:
        current_query += f"\n3) **Sentiment Summary**: {sentiment_rollup}"

    if coherence_report:
        current_query += f"\n4) **Coherence Assessment**: {coherence_report}"

    current_query += f"\n\nContext Sources:\n{context_block}"

    messages.append({"role": "user", "content": current_query})

    return messages

def stream_llm_response(messages):
    """Stream response from the LLM API."""
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }

    data = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.2,
        "stream": True,
        "max_tokens": 2000
    }

    try:
        with requests.post("https://inference.do-ai.run/v1/chat/completions",
                           headers=headers, json=data, stream=True, timeout=30) as r:
            if r.status_code != 200:
                yield f"[ERROR] API returned status {r.status_code}: {r.text}"
                return

            for line in r.iter_lines(decode_unicode=True):
                if not line or line.strip() == "data: [DONE]":
                    continue
                if line.startswith("data: "):
                    line = line[len("data: "):]

                try:
                    chunk = json.loads(line)
                    delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                    if delta:
                        yield delta
                        time.sleep(0.01)  # Small delay for smooth streaming
                except json.JSONDecodeError:
                    continue
                except Exception as e:
                    print(f"Streaming error: {e}")
                    continue

    except requests.exceptions.RequestException as e:
        yield f"[ERROR] Network error: {str(e)}"
    except Exception as e:
        yield f"[ERROR] Unexpected error: {str(e)}"

def update_sentiment_setting(enable):
    """Update global sentiment analysis setting."""
    global ENABLE_SENTIMENT
    ENABLE_SENTIMENT = enable
    return f"Sentiment analysis {'enabled' if enable else 'disabled'}"

def update_coherence_setting(enable):
    """Update global coherence analysis setting."""
    global ENABLE_COHERENCE
    ENABLE_COHERENCE = enable
    return f"Coherence analysis {'enabled' if enable else 'disabled'}"

# Create the chat interface
with gr.Blocks(title="Kenya Policy Assistant - Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏛️ Kenya Policy Assistant - Interactive Chat
    Ask questions about Kenya's policies and have a conversation! I can help you understand policy documents with sentiment and coherence analysis.
    """)

    with gr.Row():
        with gr.Column(scale=3):
            # Main chat interface
            chatbot = gr.Chatbot(
                height=600,
                bubble_full_width=False,
                show_copy_button=True,
                show_share_button=True
            )

            with gr.Row():
                sentiment_toggle = gr.Checkbox(
                    label="Enable Sentiment Analysis",
                    value=True,
                    info="Analyze the tone and sentiment of policy documents"
                )
                coherence_toggle = gr.Checkbox(
                    label="Enable Coherence Analysis",
                    value=True,
                    info="Check coherence and consistency of retrieved documents"
                )

        with gr.Column(scale=1):
            gr.Markdown("""
            ### 💡 Tips for Better Results
            - Ask specific questions about Kenya policies
            - You can ask follow-up questions
            - Reference previous answers in your questions
            - Use phrases like "What does this mean?" or "Can you elaborate?"

            ### 📝 Example Questions
            - "What are Kenya's renewable energy policies?"
            - "Tell me about water management regulations"
            - "What penalties exist for environmental violations?"
            - "How does this relate to what you just mentioned?"
            """)

            with gr.Accordion("⚙️ Settings", open=False):
                gr.Markdown("Toggle analysis features on/off")
                sentiment_status = gr.Textbox(
                    value="Sentiment analysis enabled",
                    label="Sentiment Status",
                    interactive=False
                )
                coherence_status = gr.Textbox(
                    value="Coherence analysis enabled",
                    label="Coherence Status",
                    interactive=False
                )

    # Create the chat interface with a custom response function
    chat_interface = gr.ChatInterface(
        fn=lambda message, history: chat_response(message, history, ENABLE_SENTIMENT, ENABLE_COHERENCE),
        chatbot=chatbot,
        title="",  # We already have a title above
        description="",  # We already have a description above
        examples=[
            "What are the objectives of Kenya's energy policies?",
            "Tell me about environmental protection regulations",
            "What are the penalties for water pollution?",
            "How are renewable energy projects regulated?",
            "What does the constitution say about natural resources?"
        ],
        cache_examples=False,
        retry_btn="🔄 Retry",
        undo_btn="↩️ Undo",
        clear_btn="🗑️ Clear Chat"
    )

    # Update settings when toggles change
    sentiment_toggle.change(
        fn=update_sentiment_setting,
        inputs=[sentiment_toggle],
        outputs=[sentiment_status]
    )

    coherence_toggle.change(
        fn=update_coherence_setting,
        inputs=[coherence_toggle],
        outputs=[coherence_status]
    )

if __name__ == "__main__":
    print("🚀 Starting Kenya Policy Assistant Chat...")
    demo.queue(max_size=20).launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860
    )
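One detail worth calling out in `build_messages_with_history` is the sliding history window: only the last four `[user, bot]` pairs are replayed to the model, which bounds prompt growth across long conversations. A standalone sketch of that trimming follows (hypothetical turn data, same slicing as the function above):

```python
# Standalone illustration of the history window used by
# build_messages_with_history(): only the last 4 [user, bot] pairs
# are replayed to the model, to stay within token limits.
history = [[f"question {i}", f"answer {i}"] for i in range(6)]  # hypothetical turns

recent_history = history[-4:] if len(history) > 4 else history

messages = []
for user_msg, bot_msg in recent_history:
    messages.append({"role": "user", "content": user_msg})
    messages.append({"role": "assistant", "content": bot_msg})

print(len(messages))  # 8 -> turns 2..5 survive; turns 0 and 1 are dropped
```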
config.json
CHANGED
@@ -1,5 +1,26 @@
 {
 "token": "tzcuKyLTBCzYzgPZXypkfiGswkewHvDjMK",
 "hf": "hf_",
-"do_token": "sk-do-8Hjf0liuGQCoPwglilL49xiqrthMECwjGP_kAjPM53OTOFQczPyfPK8xJc"
-}
+"do_token": "sk-do-8Hjf0liuGQCoPwglilL49xiqrthMECwjGP_kAjPM53OTOFQczPyfPK8xJc",
+
+"apiKey": "SensorDxKenya",
+"apiSecret": "6GUXzKi#wvDvZ",
+"map_api_key": "AIzaSyC1mHwJ_f2Wi8o-zt5N69lW3tgQZPlJTWE",
+"weather_api_key": "AIzaSyAryn2T6hlQg7XmjTtGBfQkvTWQ8Ablkrs",
+"spaces_url" : "https://forecasting-data.ams3.digitaloceanspaces.com",
+"spaces_access_key": "DO801FGLVD99HMRHMMAF",
+"spaces_secret_key": "rKhzUx/C9+0cfm61f3mnCOY/O3ncf9OJq01O4N8hzjc",
+"spaces_bucket_endpoint" : "https://forecasting-data.ams3.digitaloceanspaces.com",
+
+"TAHMO_API_KEY": "SensorDxKenya",
+"TAHMO_API_SECRET": "6GUXzKi#wvDvZ",
+"WEATHER_API_KEY": "AIzaSyAryn2T6hlQg7XmjTtGBfQkvTWQ8Ablkrs",
+"MAPS_KEY": "AIzaSyC1mHwJ_f2Wi8o-zt5N69lW3tgQZPlJTWE",
+"SPACES_KEY": "DO00EXAMPLEACCESSKEY",
+"SPACES_SECRET": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLESECRET",
+"SPACES_BUCKET": "forecasting-data",
+"SPACES_REGION": "ams3",
+"OBJECT_PREFIX": "time-forecasts/",
+"WORKERS": 4,
+"access_token_deploy" : "dop_v1_44a3c7084fc02f7af8b215c18d6b2145924d37df36eb63b4b199039031bcad5c"
+}
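This change lands several live credentials (DigitalOcean Spaces, TAHMO, and Google Maps/Weather keys) in the same `config.json` the apps parse at startup, while `chat_app.py` separately checks `os.getenv("API_KEY", ...)`. A hypothetical helper, not in the commit, could unify the two lookup paths with an environment-first policy:

```python
import json
import os

def get_secret(name, config_path="config.json", default=None):
    """Prefer an environment variable; fall back to config.json.

    Hypothetical helper -- the committed code reads config.json and
    os.getenv() separately; this just sketches one way to unify them.
    """
    value = os.getenv(name.upper())  # e.g. "do_token" -> env var "DO_TOKEN"
    if value:
        return value
    try:
        with open(config_path) as f:
            return json.load(f).get(name, default)
    except (OSError, json.JSONDecodeError):
        return default

do_token = get_secret("do_token")  # falls back to the "do_token" key in config.json
```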
docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
version: '3.8'

services:
  policy-analysis:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "7860:7860"
    environment:
      - PRELOAD_MODELS=false  # Models are already cached in the image
    volumes:
      - model_cache:/root/.cache/huggingface  # Optional: persist model cache
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

volumes:
  model_cache:
download_models.py
ADDED
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
"""
Model Downloader Script for Policy Analysis Application

This script pre-downloads all the ML models used in the application
to reduce inference time during runtime.
"""

import os
import sys
from pathlib import Path

def download_huggingface_models():
    """Download all HuggingFace models used in the application."""

    # List of all models used in the application, mapped to the loader they need
    models_to_download = {
        # Sentence Transformers / Embedding Models
        "sentence-transformers/all-MiniLM-L6-v2": "sentence_transformers",
        "BAAI/bge-m3": "sentence_transformers",

        # Cross-Encoder Models
        "cross-encoder/ms-marco-MiniLM-L-6-v2": "cross_encoder",

        # Zero-shot Classification Models
        "MoritzLaurer/deberta-v3-base-zeroshot-v2.0": "transformers",
    }

    print("🚀 Starting model download process...")
    print(f"📁 Models will be cached in: {os.path.expanduser('~/.cache/huggingface')}")
    print("=" * 60)

    for model_name, library in models_to_download.items():
        print(f"\n📦 Downloading {model_name}...")
        try:
            if library == "sentence_transformers":
                download_sentence_transformer(model_name)
            elif library == "cross_encoder":
                download_cross_encoder(model_name)
            elif library == "transformers":
                download_transformers_model(model_name)
            print(f"✅ Successfully downloaded {model_name}")
        except Exception as e:
            print(f"❌ Failed to download {model_name}: {e}")
            continue

    print("\n" + "=" * 60)
    print("🎉 Model download process completed!")
    print("💡 All models are now cached locally for faster inference.")

def download_sentence_transformer(model_name):
    """Download a sentence transformer model."""
    try:
        from sentence_transformers import SentenceTransformer
        print(f"   Loading {model_name}...")
        model = SentenceTransformer(model_name)
        # Test encode to ensure the model works
        _ = model.encode(["test sentence"], show_progress_bar=False)
        print(f"   ✓ Model loaded and tested successfully")
    except ImportError:
        print(f"   ⚠️ sentence-transformers not installed, skipping {model_name}")
        raise
    except Exception as e:
        print(f"   ❌ Error downloading {model_name}: {e}")
        raise

def download_transformers_model(model_name):
    """Download a transformers model using pipeline."""
    try:
        from transformers import pipeline
        print(f"   Loading {model_name}...")

        # Load the model based on its intended use
        if "zeroshot" in model_name.lower() or "deberta" in model_name.lower():
            pipe = pipeline("zero-shot-classification", model=model_name, device=-1)
            # Test the pipeline
            _ = pipe("test text", ["test label"])
        else:
            # Generic text classification pipeline
            pipe = pipeline("text-classification", model=model_name, device=-1)

        print(f"   ✓ Model loaded and tested successfully")
    except ImportError:
        print(f"   ⚠️ transformers not installed, skipping {model_name}")
        raise
    except Exception as e:
        print(f"   ❌ Error downloading {model_name}: {e}")
        raise

def download_cross_encoder(model_name):
    """Download a cross-encoder model."""
    try:
        from sentence_transformers import CrossEncoder
        print(f"   Loading {model_name}...")
        model = CrossEncoder(model_name)
        # Test prediction to ensure the model works
        _ = model.predict([("test query", "test document")])
        print(f"   ✓ Model loaded and tested successfully")
    except ImportError:
        print(f"   ⚠️ sentence-transformers not installed, skipping {model_name}")
        raise
    except Exception as e:
        print(f"   ❌ Error downloading {model_name}: {e}")
        raise

def check_dependencies():
    """Check if required packages are installed."""
    required_packages = [
        ("sentence_transformers", "sentence-transformers"),
        ("transformers", "transformers"),
        ("torch", "torch"),
        ("numpy", "numpy"),
        ("requests", "requests")
    ]

    missing_packages = []
    for package, pip_name in required_packages:
        try:
            __import__(package)
        except ImportError:
            missing_packages.append(pip_name)

    if missing_packages:
        print("❌ Missing required packages:")
        for package in missing_packages:
            print(f"   - {package}")
        print("\n💡 Install missing packages with:")
        print(f"   pip install {' '.join(missing_packages)}")
        return False

    return True

if __name__ == "__main__":
    print("🤖 Policy Analysis Model Downloader")
    print("=" * 60)

    # Check dependencies first
    if not check_dependencies():
        sys.exit(1)

    # Download all models
    download_huggingface_models()

    print("\n🔥 Ready to deploy! All models are cached locally.")
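`docker-compose.yml` above sets `PRELOAD_MODELS=false` because the image already caches the models, yet the committed scripts never read that flag. A hypothetical gate (sketch only; the environment-variable wiring is assumed, not part of the commit) shows how startup code could honor it:

```python
import os
import subprocess
import sys

# Hypothetical gate -- not in the committed scripts. docker-compose.yml
# sets PRELOAD_MODELS=false because the image already caches the models;
# this sketches how startup code could honor that flag.
if os.getenv("PRELOAD_MODELS", "true").lower() == "true":
    subprocess.run([sys.executable, "download_models.py"], check=False)
else:
    print("PRELOAD_MODELS=false -- skipping model download, using the baked-in cache")
```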
requirements.txt
CHANGED
@@ -10,4 +10,9 @@ langchain-community>=0.0.30
 pydantic==2.10.6
 numpy
 requests
+boto3
+rank-bm25
+pypdf
+Pillow
+pytesseract
 
startup.py
ADDED
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
"""
Startup Script for Policy Analysis Application

This script ensures all required models are downloaded before starting the application.
"""

import os
import sys
import subprocess
from pathlib import Path

def run_model_downloader():
    """Run the model downloader script."""
    script_dir = Path(__file__).parent
    downloader_script = script_dir / "download_models.py"

    if not downloader_script.exists():
        print("❌ Model downloader script not found!")
        return False

    print("🔍 Ensuring all models are downloaded...")
    try:
        result = subprocess.run([sys.executable, str(downloader_script)],
                                capture_output=True, text=True, check=True)
        print(result.stdout)
        return True
    except subprocess.CalledProcessError as e:
        print("❌ Error running model downloader:")
        print(e.stdout)
        print(e.stderr)
        return False

def start_application():
    """Start the main application."""
    print("🚀 Starting Policy Analysis Application...")

    # Import and run the main app
    try:
        from app import demo
        demo.queue().launch(share=True, debug=True)
    except ImportError as e:
        print(f"❌ Failed to import app: {e}")
        sys.exit(1)

if __name__ == "__main__":
    print("🤖 Policy Analysis Application Startup")
    print("=" * 50)

    # Download models first
    if not run_model_downloader():
        print("⚠️ Model download failed, but continuing anyway...")

    # Start the application
    start_application()
test_imports.py
ADDED
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
"""
Quick test script to verify the import fixes work
"""

def test_imports():
    """Test that all utils modules can be imported correctly."""
    try:
        print("Testing utils imports...")

        print("  - Importing utils.encoding_input...")
        from utils.encoding_input import encode_text

        print("  - Importing utils.loading_embeddings...")
        from utils.loading_embeddings import get_vectorstore

        print("  - Importing utils.retrieve_n_rerank...")
        from utils.retrieve_n_rerank import retrieve_and_rerank

        print("  - Importing utils.sentiment_analysis...")
        from utils.sentiment_analysis import get_sentiment

        print("  - Importing utils.coherence_bbscore...")
        from utils.coherence_bbscore import coherence_report

        print("  - Importing utils.model_generation...")
        from utils.model_generation import build_messages

        print("  - Importing utils.generation_streaming...")
        from utils.generation_streaming import generate_response_stream

        print("✅ All imports successful!")
        return True

    except Exception as e:
        print(f"❌ Import failed: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    print("🔍 Testing import fixes...")
    if test_imports():
        print("🎉 Ready to run the application!")
    else:
        print("💥 Still have import issues to fix.")
test_models.py
ADDED
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
Model Verification Script

This script tests all models used in the application to ensure they're working correctly.
"""

import sys
from typing import Dict, Any

def test_sentence_transformers():
    """Test sentence transformer models."""
    results = {}

    models_to_test = [
        "sentence-transformers/all-MiniLM-L6-v2",
        "BAAI/bge-m3"
    ]

    try:
        from sentence_transformers import SentenceTransformer

        for model_name in models_to_test:
            try:
                print(f"Testing {model_name}...")
                model = SentenceTransformer(model_name)
                embeddings = model.encode(["This is a test sentence."], show_progress_bar=False)

                if embeddings is not None and len(embeddings) > 0:
                    results[model_name] = "✅ PASS"
                    print(f"  ✅ {model_name} working correctly")
                else:
                    results[model_name] = "❌ FAIL - No embeddings generated"
                    print(f"  ❌ {model_name} failed to generate embeddings")

            except Exception as e:
                results[model_name] = f"❌ FAIL - {str(e)}"
                print(f"  ❌ {model_name} failed: {e}")

    except ImportError:
        results["sentence-transformers"] = "❌ FAIL - Package not installed"
        print("❌ sentence-transformers package not installed")

    return results

def test_cross_encoder():
    """Test cross-encoder model."""
    results = {}
    model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"

    try:
        from sentence_transformers import CrossEncoder

        print(f"Testing {model_name}...")
        model = CrossEncoder(model_name)
        scores = model.predict([("test query", "test document")])

        if scores is not None and len(scores) > 0:
            results[model_name] = "✅ PASS"
            print(f"  ✅ {model_name} working correctly")
        else:
            results[model_name] = "❌ FAIL - No scores generated"
            print(f"  ❌ {model_name} failed to generate scores")

    except ImportError:
        results["cross-encoder"] = "❌ FAIL - sentence-transformers not installed"
        print("❌ sentence-transformers package not installed")
    except Exception as e:
        results[model_name] = f"❌ FAIL - {str(e)}"
        print(f"  ❌ {model_name} failed: {e}")

    return results

def test_transformers_pipeline():
    """Test transformers pipeline."""
    results = {}
    model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"

    try:
        from transformers import pipeline

        print(f"Testing {model_name}...")
        classifier = pipeline(
            "zero-shot-classification",
            model=model_name,
            device=-1  # CPU
        )

        result = classifier(
            "This is a test sentence about policy.",
            ["policy", "technology", "sports"]
        )

        if result and 'labels' in result and len(result['labels']) > 0:
            results[model_name] = "✅ PASS"
            print(f"  ✅ {model_name} working correctly")
        else:
            results[model_name] = "❌ FAIL - No classification result"
            print(f"  ❌ {model_name} failed to classify")

    except ImportError:
        results["transformers"] = "❌ FAIL - transformers package not installed"
        print("❌ transformers package not installed")
    except Exception as e:
        results[model_name] = f"❌ FAIL - {str(e)}"
        print(f"  ❌ {model_name} failed: {e}")

    return results

def test_application_modules():
    """Test that application modules can be imported."""
    results = {}

    modules_to_test = [
        "utils.encoding_input",
        "utils.loading_embeddings",
        "utils.retrieve_n_rerank",
        "utils.sentiment_analysis",
        "utils.coherence_bbscore",
        "utils.model_generation",
        "utils.generation_streaming"
    ]

    for module_name in modules_to_test:
        try:
            __import__(module_name)
            results[module_name] = "✅ PASS"
            print(f"✅ {module_name} imported successfully")
        except ImportError as e:
            results[module_name] = f"❌ FAIL - {str(e)}"
            print(f"❌ {module_name} import failed: {e}")
        except Exception as e:
            results[module_name] = f"❌ FAIL - {str(e)}"
            print(f"❌ {module_name} error: {e}")

    return results

def main():
    """Run all tests."""
    print("🧪 Model Verification Test Suite")
    print("=" * 50)

    all_results = {}

    print("\n📦 Testing Sentence Transformers...")
    all_results.update(test_sentence_transformers())

    print("\n🔁 Testing Cross Encoder...")
    all_results.update(test_cross_encoder())

    print("\n🤖 Testing Transformers Pipeline...")
    all_results.update(test_transformers_pipeline())

    print("\n📚 Testing Application Modules...")
    all_results.update(test_application_modules())

    # Summary
    print("\n" + "=" * 50)
    print("📊 TEST SUMMARY")
    print("=" * 50)

    passed = 0
    failed = 0

    for name, result in all_results.items():
        print(f"{result} {name}")
        if "✅ PASS" in result:
            passed += 1
        else:
            failed += 1

    print(f"\n📈 Results: {passed} passed, {failed} failed")

    if failed == 0:
        print("🎉 All tests passed! The application is ready to deploy.")
        return 0
    else:
        print("⚠️ Some tests failed. Please check the errors above.")
        return 1

if __name__ == "__main__":
    sys.exit(main())