# czech-correction / api_client.py
# asdfasdfdsafdsa's picture
# Upload 2 files
# 508f678 verified
"""
Client for Czech text correction API with local server auto-start
"""
import requests
import time
import subprocess
import os
import sys
from pathlib import Path
from typing import Optional, Dict, List, Any
import logging
# Module-wide logger used by the client below.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO level on import (per the logging docs,
# basicConfig() is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
class CzechCorrectionClient:
    """Client for Czech text correction with automatic local server startup.

    Every request goes to the single local HTTP API described by
    ``LOCAL_ENDPOINT``. When that API fails its health check, the client
    launches ``api.py`` (looked up in the same directory as this file) as a
    background process and waits for it to become healthy.
    """

    # Local API endpoint only
    LOCAL_ENDPOINT = {
        "name": "Local",
        "base_url": "http://localhost:8042",
        "timeout": 3600,  # 1 hour for local (grammar correction can be slow)
    }

    # Largest number of texts accepted by correct_batch().
    MAX_BATCH_SIZE = 10

    # Seconds to wait for a freshly started server to pass its health check.
    SERVER_STARTUP_TIMEOUT = 120

    # Seconds between consecutive health checks while waiting for the server.
    HEALTH_POLL_INTERVAL = 2

    def __init__(self, prefer_local: bool = True):
        """Initialize the client.

        Args:
            prefer_local: Deprecated and ignored; the local API is always used.
        """
        self.endpoint = self.LOCAL_ENDPOINT
        self._working_endpoint = None
        self._last_health_check = 0
        self.health_check_interval = 3600  # Cache endpoint for 1 hour
        self._server_process = None

    def _check_endpoint_health(self, endpoint: Dict) -> bool:
        """Return True if ``GET {base_url}/api/health`` answers 200 with status "healthy".

        Any failure (connection error, timeout, unexpected body) counts as
        unhealthy and is logged at DEBUG level.
        """
        try:
            response = requests.get(
                f"{endpoint['base_url']}/api/health",
                timeout=10,  # Increased timeout for health check
            )
            if response.status_code == 200:
                data = response.json()
                return data.get('status') == 'healthy'
        except Exception as e:
            logger.debug(f"Health check failed for {endpoint['name']}: {e}")
        return False

    @staticmethod
    def _port_from_url(base_url: str) -> int:
        """Return the TCP port of *base_url* (scheme default if none is given)."""
        from urllib.parse import urlsplit

        parts = urlsplit(base_url)
        if parts.port is not None:
            return parts.port
        return 443 if parts.scheme == "https" else 80

    def _is_port_in_use(self, port: int) -> bool:
        """Return True if *port* cannot be bound on localhost (i.e. it is taken)."""
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return False
            except OSError:
                return True

    def _spawn_server(self, api_script: Path, port: int) -> subprocess.Popen:
        """Launch *api_script* in its own session, serving on *port*.

        The server's stdout/stderr are appended to ``api_server.log`` next to
        the script, or discarded if that file cannot be opened. They must not
        be unread PIPEs: once the OS pipe buffer fills, the server would block
        on its next write and stop serving requests.
        """
        env = os.environ.copy()
        env['PORT'] = str(port)
        log_path = api_script.parent / "api_server.log"
        try:
            log_file = open(log_path, "ab")
        except OSError as e:
            logger.warning(f"Cannot open {log_path} ({e}); discarding server output")
            log_file = None
        try:
            return subprocess.Popen(
                [sys.executable, str(api_script)],
                cwd=str(api_script.parent),
                env=env,
                stdout=log_file if log_file is not None else subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
                start_new_session=True,
            )
        finally:
            # The child holds its own copy of the descriptor; close ours.
            if log_file is not None:
                log_file.close()

    def _wait_until_healthy(self, process: subprocess.Popen) -> bool:
        """Poll the health endpoint until it passes, *process* exits, or the timeout elapses."""
        deadline = time.monotonic() + self.SERVER_STARTUP_TIMEOUT
        while time.monotonic() < deadline:
            if self._check_endpoint_health(self.endpoint):
                logger.info("✅ Local API server started successfully")
                return True
            exit_code = process.poll()
            if exit_code is not None:
                # No point waiting out the timeout for a server that already died.
                logger.error(f"Server process exited with code {exit_code} before becoming healthy")
                return False
            time.sleep(self.HEALTH_POLL_INTERVAL)
        logger.error("Server failed to start within timeout")
        return False

    def _start_local_server(self) -> bool:
        """Start the local API server if it is not already running.

        Returns:
            True once the endpoint passes its health check, False otherwise.
            Errors are logged, not raised.
        """
        try:
            port = self._port_from_url(self.endpoint['base_url'])

            # A server spawned by an earlier call may still be loading its
            # models; wait for it rather than launching a duplicate.
            if self._server_process is not None and self._server_process.poll() is None:
                logger.info("Previously started API server is still running, waiting for it...")
                return self._wait_until_healthy(self._server_process)

            if self._is_port_in_use(port):
                logger.warning(f"Port {port} is already in use - server may already be running")
                # Wait a bit and check health again
                time.sleep(self.HEALTH_POLL_INTERVAL)
                if self._check_endpoint_health(self.endpoint):
                    logger.info(f"✅ Server is already running on port {port}")
                    return True
                logger.error(f"Port {port} is in use but server is not responding to health checks")
                return False

            # api.py is expected to live next to this file.
            api_script = Path(__file__).resolve().parent / "api.py"
            if not api_script.exists():
                logger.error(f"API script not found at {api_script}")
                return False

            logger.info("Starting local API server...")
            logger.info("This may take 1-2 minutes to load models...")
            self._server_process = self._spawn_server(api_script, port)
            return self._wait_until_healthy(self._server_process)
        except Exception as e:
            logger.error(f"Failed to start local server: {e}")
            return False

    def _invalidate_endpoint_cache(self) -> None:
        """Forget the cached endpoint so the next request re-checks (and restarts) the server."""
        self._working_endpoint = None
        self._last_health_check = 0

    def _get_working_endpoint(self) -> Optional[Dict]:
        """Return a healthy endpoint, starting the local server if needed.

        A successful result is cached for ``health_check_interval`` seconds.
        Returns None if the server is not healthy and could not be started.
        """
        current_time = time.time()
        # Use cached endpoint if still valid
        if self._working_endpoint and (current_time - self._last_health_check < self.health_check_interval):
            return self._working_endpoint

        # Check if local server is running
        if self._check_endpoint_health(self.endpoint):
            logger.info(f"Using {self.endpoint['name']} API endpoint")
            self._working_endpoint = self.endpoint
            self._last_health_check = current_time
            return self.endpoint

        # Try to start the server
        logger.info("Local API server not running, attempting to start...")
        if self._start_local_server():
            self._working_endpoint = self.endpoint
            self._last_health_check = time.time()
            return self.endpoint

        logger.error("Could not start or connect to local API server")
        return None

    def correct_text(self, text: str, include_timing: bool = False) -> Dict[str, Any]:
        """Correct Czech text (grammar and punctuation).

        Args:
            text: Text to correct. Empty or whitespace-only text is returned
                as-is without contacting the server.
            include_timing: Whether to ask the server to include processing time.

        Returns:
            On HTTP 200, the server's JSON response unchanged (expected to hold
            'success', 'corrected_text' and optionally 'processing_time_ms').
            Otherwise a dict with 'success': False, 'corrected_text' set to the
            input text, and an 'error' message.
        """
        if not text or not text.strip():
            return {"success": True, "corrected_text": text, "error": None}

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_text": text,
                "error": "Could not start or connect to local API server",
            }

        try:
            payload = {"text": text, "options": {"include_timing": include_timing}}
            response = requests.post(
                f"{endpoint['base_url']}/api/correct",
                json=payload,
                timeout=endpoint['timeout'],
            )
            if response.status_code == 200:
                return response.json()
            return {
                "success": False,
                "corrected_text": text,
                "error": f"API error: {response.status_code}",
            }
        except requests.exceptions.Timeout:
            logger.warning(f"Timeout on {endpoint['name']} API")
            return {"success": False, "corrected_text": text, "error": "Request timeout"}
        except requests.exceptions.ConnectionError as e:
            # The cached server may have died. Drop the cache so the next call
            # re-checks health (and restarts the server) instead of failing
            # for the rest of health_check_interval.
            self._invalidate_endpoint_cache()
            logger.error(f"Error calling API: {e}")
            return {"success": False, "corrected_text": text, "error": str(e)}
        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return {"success": False, "corrected_text": text, "error": str(e)}

    def _correct_individually(self, texts: List[str]) -> Dict[str, Any]:
        """Fallback for correct_batch(): correct each text with its own request.

        Texts whose correction fails are returned unchanged. 'success' is
        False only when every text failed (e.g. the server is unreachable);
        partial failures keep 'success' True and are summarised in 'error'.
        """
        corrected_texts = []
        failures = []
        for index, text in enumerate(texts):
            result = self.correct_text(text, include_timing=False)
            corrected_texts.append(result.get('corrected_text', text))
            if not result.get('success', True):
                failures.append(f"text {index}: {result.get('error') or 'unknown error'}")
        return {
            "success": len(failures) < len(texts),
            "corrected_texts": corrected_texts,
            "error": "; ".join(failures) if failures else None,
        }

    def correct_batch(self, texts: List[str], include_timing: bool = False) -> Dict[str, Any]:
        """Correct multiple Czech texts in one batch request.

        If the batch endpoint fails (non-200 status or any exception), falls
        back to one request per text.

        Args:
            texts: Texts to correct (at most MAX_BATCH_SIZE).
            include_timing: Whether to ask the server to include processing time.

        Returns:
            Dict with 'success', 'corrected_texts', 'error' and optionally
            'processing_time_ms' (when the server's batch response provides it).
        """
        if not texts:
            return {"success": True, "corrected_texts": [], "error": None}

        if len(texts) > self.MAX_BATCH_SIZE:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": f"Batch size exceeds limit ({self.MAX_BATCH_SIZE})",
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server",
            }

        try:
            payload = {"texts": texts, "options": {"include_timing": include_timing}}
            response = requests.post(
                f"{endpoint['base_url']}/api/correct/batch",
                json=payload,
                timeout=endpoint['timeout'] * 2,  # Longer timeout for batch
            )
            if response.status_code == 200:
                return response.json()
            logger.warning(
                f"Batch API failed with status {response.status_code}, "
                "falling back to individual corrections"
            )
        except Exception as e:
            logger.error(f"Error calling batch API: {e}")
        return self._correct_individually(texts)
# Convenience functions for backward compatibility
_default_client = None


def get_client(prefer_local: bool = True) -> CzechCorrectionClient:
    """Return the shared module-level client, creating it on first call.

    Args:
        prefer_local: Ignored; kept only for backward compatibility (the
            local API is always used).
    """
    global _default_client
    if _default_client is not None:
        return _default_client
    _default_client = CzechCorrectionClient(prefer_local=True)
    return _default_client
def correct_text(text: str, prefer_local: bool = True) -> str:
    """Correct *text* with the shared client and return the corrected string.

    If the client reports failure, or the server's response does not carry
    the expected 'success' / 'corrected_text' fields, the input text is
    returned unchanged instead of raising KeyError.

    Args:
        text: Czech text to correct.
        prefer_local: Ignored; kept only for backward compatibility.
    """
    client = get_client(prefer_local=True)
    result = client.correct_text(text)
    # The client passes the server's JSON through unchanged on HTTP 200, so
    # the keys are not guaranteed; use .get() rather than indexing.
    if result.get('success'):
        return result.get('corrected_text', text)
    return text
def correct_batch(texts: List[str], prefer_local: bool = True) -> List[str]:
    """Correct several texts with the shared client and return the results.

    If the client reports failure, or the response lacks a 'success' field,
    the input list is returned unchanged instead of raising KeyError.

    Args:
        texts: Czech texts to correct.
        prefer_local: Ignored; kept only for backward compatibility.
    """
    client = get_client(prefer_local=True)
    result = client.correct_batch(texts)
    # The server's batch JSON is passed through unchanged on HTTP 200, so
    # 'success' may be missing; .get() avoids a KeyError.
    if result.get('success'):
        return result.get('corrected_texts', texts)
    return texts