|
|
""" |
|
|
Client for Czech text correction API with local server auto-start |
|
|
""" |
|
|
|
|
|
import requests |
|
|
import time |
|
|
import subprocess |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, List, Any |
|
|
import logging |
|
|
|
|
|
|
|
|
# Module-level logger used by the client below. basicConfig only installs a
# default INFO-level handler when the root logger has none configured yet.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
class CzechCorrectionClient:
    """Client for Czech text correction with automatic local server startup.

    Every request goes to ``LOCAL_ENDPOINT``. When that endpoint is not
    healthy, the client tries to launch ``api.py`` (looked up in the same
    directory as this file) as a subprocess in a new session and waits for it
    to report healthy.
    """

    # Connection settings of the local correction server. ``timeout`` is the
    # per-request ``requests`` timeout in seconds (batch requests use twice it).
    LOCAL_ENDPOINT = {
        "name": "Local",
        "base_url": "http://localhost:8042",
        "timeout": 3600
    }

    # Port assumed when ``base_url`` does not spell one out explicitly.
    DEFAULT_PORT = 8042

    # Seconds to wait for a freshly launched server to become healthy
    # (loading models can take a couple of minutes).
    SERVER_START_TIMEOUT = 120

    def __init__(self, prefer_local: bool = True):
        """
        Initialize the client.

        Args:
            prefer_local: Deprecated and ignored; the local API is always used.
        """
        self.endpoint = self.LOCAL_ENDPOINT
        # Endpoint that most recently passed a health check (None = unknown).
        self._working_endpoint = None
        # time.time() at which _working_endpoint was last confirmed healthy.
        self._last_health_check = 0
        # Seconds a successful health check is trusted before re-checking.
        self.health_check_interval = 3600
        # Popen handle of a server started by this client, if any.
        self._server_process = None

    def _check_endpoint_health(self, endpoint: Dict) -> bool:
        """Return True if ``GET {base_url}/api/health`` answers 200 with status "healthy".

        Any failure (connection error, timeout, bad JSON) is logged at debug
        level and reported as unhealthy.
        """
        try:
            response = requests.get(
                f"{endpoint['base_url']}/api/health",
                timeout=10
            )
            if response.status_code == 200:
                data = response.json()
                return data.get('status') == 'healthy'
        except Exception as e:
            logger.debug(f"Health check failed for {endpoint['name']}: {e}")
        return False

    def _local_port(self) -> int:
        """Return the TCP port of ``self.endpoint['base_url']`` (``DEFAULT_PORT`` if absent)."""
        from urllib.parse import urlparse
        return urlparse(self.endpoint['base_url']).port or self.DEFAULT_PORT

    def _is_port_in_use(self, port: int) -> bool:
        """Return True if ``port`` cannot be bound on localhost (something already holds it)."""
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return False
            except OSError:
                return True

    def _start_local_server(self) -> bool:
        """Start the local API server if it is not already running.

        Returns:
            True if a healthy server is reachable afterwards, False otherwise.
            Failures are logged rather than raised.
        """
        try:
            port = self._local_port()

            if self._is_port_in_use(port):
                logger.warning(f"Port {port} is already in use - server may already be running")
                # Give a server that is still booting a moment before probing it.
                time.sleep(2)
                if self._check_endpoint_health(self.endpoint):
                    logger.info(f"✅ Server is already running on port {port}")
                    return True
                logger.error(f"Port {port} is in use but server is not responding to health checks")
                return False

            api_service_dir = Path(__file__).resolve().parent
            api_script = api_service_dir / "api.py"
            if not api_script.exists():
                logger.error(f"API script not found at {api_script}")
                return False

            logger.info("Starting local API server...")
            logger.info("This may take 1-2 minutes to load models...")

            env = os.environ.copy()
            env['PORT'] = str(port)

            self._server_process = subprocess.Popen(
                [sys.executable, str(api_script)],
                cwd=str(api_service_dir),
                env=env,
                # The server's output is never read by this client. With PIPE the
                # OS pipe buffer would eventually fill and block the server, so
                # discard the output instead.
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True
            )

            start_time = time.time()
            while time.time() - start_time < self.SERVER_START_TIMEOUT:
                if self._check_endpoint_health(self.endpoint):
                    logger.info("✅ Local API server started successfully")
                    return True
                # Stop waiting as soon as the child dies instead of polling a dead server.
                exit_code = self._server_process.poll()
                if exit_code is not None:
                    logger.error(f"Local API server exited during startup (exit code {exit_code})")
                    return False
                time.sleep(2)

            # The process is left running: it may still finish loading later.
            logger.error("Server failed to start within timeout")
            return False

        except Exception as e:
            logger.error(f"Failed to start local server: {e}")
            return False

    def _invalidate_endpoint(self) -> None:
        """Forget the cached healthy endpoint so the next call re-checks (and may restart) the server."""
        self._working_endpoint = None
        self._last_health_check = 0

    def _get_working_endpoint(self) -> Optional[Dict]:
        """Return a healthy endpoint, starting the local server if needed.

        A successful check is cached for ``health_check_interval`` seconds.
        Returns None if the server is neither reachable nor startable.
        """
        current_time = time.time()

        if self._working_endpoint and (current_time - self._last_health_check < self.health_check_interval):
            return self._working_endpoint

        if self._check_endpoint_health(self.endpoint):
            logger.info(f"Using {self.endpoint['name']} API endpoint")
            self._working_endpoint = self.endpoint
            self._last_health_check = current_time
            return self.endpoint

        logger.info("Local API server not running, attempting to start...")
        if self._start_local_server():
            self._working_endpoint = self.endpoint
            # Startup can take minutes; stamp the moment the server was confirmed healthy.
            self._last_health_check = time.time()
            return self.endpoint

        logger.error("Could not start or connect to local API server")
        return None

    def correct_text(self, text: str, include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct Czech text (grammar and punctuation).

        Args:
            text: Text to correct.
            include_timing: Whether to include processing time in the response.

        Returns:
            On HTTP 200, the server's JSON response as-is. Otherwise a dict
            with 'success' False, 'corrected_text' set to the original text,
            and a string 'error'. Empty/whitespace-only text is returned
            unchanged with 'success' True, without contacting the server.
        """
        if not text or not text.strip():
            return {
                "success": True,
                "corrected_text": text,
                "error": None
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_text": text,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "text": text,
                "options": {"include_timing": include_timing}
            }

            response = requests.post(
                f"{endpoint['base_url']}/api/correct",
                json=payload,
                timeout=endpoint['timeout']
            )

            if response.status_code == 200:
                return response.json()
            return {
                "success": False,
                "corrected_text": text,
                "error": f"API error: {response.status_code}"
            }

        except requests.exceptions.Timeout:
            logger.warning(f"Timeout on {endpoint['name']} API")
            return {
                "success": False,
                "corrected_text": text,
                "error": "Request timeout"
            }

        except requests.exceptions.ConnectionError as e:
            # The server went away after the last health check. Drop the cache
            # so the next call re-checks/restarts it instead of trusting it for
            # up to health_check_interval seconds.
            self._invalidate_endpoint()
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

    def _correct_individually(self, texts: List[str]) -> Dict[str, Any]:
        """Correct each text with its own request (fallback when the batch call fails).

        Texts whose correction fails are returned unchanged. 'success' is False
        only when no text could be corrected. 'error' reports how many failed
        (None when all succeeded).
        """
        results = [self.correct_text(text, include_timing=False) for text in texts]
        corrected_texts = [result.get('corrected_text', text) for result, text in zip(results, texts)]
        failures = sum(1 for result in results if not result.get('success'))
        return {
            "success": failures < len(texts),
            "corrected_texts": corrected_texts,
            "error": f"{failures} of {len(texts)} texts could not be corrected" if failures else None
        }

    def correct_batch(self, texts: List[str], include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct multiple Czech texts in one batch request.

        Args:
            texts: List of texts to correct (max 10).
            include_timing: Whether to include processing time.

        Returns:
            On HTTP 200, the server's JSON response as-is. If the batch call
            fails, each text is corrected individually: see
            ``_correct_individually`` for the result shape. If the batch is
            too large or the server is unavailable, 'success' is False and
            'corrected_texts' holds the original texts.
        """
        if not texts:
            return {
                "success": True,
                "corrected_texts": [],
                "error": None
            }

        if len(texts) > 10:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Batch size exceeds limit (10)"
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "texts": texts,
                "options": {"include_timing": include_timing}
            }

            response = requests.post(
                f"{endpoint['base_url']}/api/correct/batch",
                json=payload,
                timeout=endpoint['timeout'] * 2
            )

            if response.status_code == 200:
                return response.json()
            logger.warning(
                f"Batch API failed (HTTP {response.status_code}), falling back to individual corrections"
            )

        except Exception as e:
            logger.error(f"Error calling batch API: {e}")
            if isinstance(e, requests.exceptions.ConnectionError):
                # Force the fallback requests to re-check (and possibly restart) the server.
                self._invalidate_endpoint()

        return self._correct_individually(texts)
|
|
|
|
|
|
|
|
|
|
|
# Lazily created client shared by the module-level convenience functions.
_default_client = None


def get_client(prefer_local: bool = True) -> CzechCorrectionClient:
    """Return the shared client, creating it on first use.

    Args:
        prefer_local: Ignored; kept for backward compatibility (the local API
            is always used).
    """
    global _default_client
    if _default_client is not None:
        return _default_client
    _default_client = CzechCorrectionClient(prefer_local=True)
    return _default_client
|
|
|
|
|
def correct_text(text: str, prefer_local: bool = True) -> str:
    """Correct ``text`` with the shared client, returning it unchanged on failure.

    Args:
        text: Czech text to correct.
        prefer_local: Ignored; kept for backward compatibility (the local API
            is always used).
    """
    outcome = get_client(prefer_local=True).correct_text(text)
    return outcome['corrected_text'] if outcome['success'] else text
|
|
|
|
|
def correct_batch(texts: List[str], prefer_local: bool = True) -> List[str]:
    """Correct several texts with the shared client.

    Returns the input list itself when the client reports failure.

    Args:
        texts: Czech texts to correct.
        prefer_local: Ignored; kept for backward compatibility (the local API
            is always used).
    """
    outcome = get_client(prefer_local=True).correct_batch(texts)
    if not outcome['success']:
        return texts
    return outcome.get('corrected_texts', texts)