File size: 10,660 Bytes
508f678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""
Client for Czech text correction API with local server auto-start
"""

import requests
import time
import subprocess
import os
import sys
from pathlib import Path
from typing import Optional, Dict, List, Any
import logging

# Module-level logger. basicConfig below attaches an INFO-level handler to the
# root logger as a side effect of importing this module (it is a no-op when the
# root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class CzechCorrectionClient:
    """Client for Czech text correction with automatic local server startup.

    All requests go to the single local endpoint in ``LOCAL_ENDPOINT``. A
    successful health check is cached for ``health_check_interval`` seconds.
    When no healthy server is found, ``api.py`` from this file's directory is
    launched in the background and polled until ``/api/health`` reports
    ``{"status": "healthy"}``.
    """

    # Local API endpoint only
    LOCAL_ENDPOINT = {
        "name": "Local",
        "base_url": "http://localhost:8042",
        "timeout": 3600  # 1 hour for local (grammar correction can be slow)
    }
    # Port handed to the auto-started server via $PORT; must match base_url above.
    LOCAL_PORT = 8042
    # Seconds to wait for a freshly launched server to report healthy.
    SERVER_START_TIMEOUT = 120
    # Maximum number of texts accepted by correct_batch() in one request.
    MAX_BATCH_SIZE = 10

    def __init__(self, prefer_local: bool = True):
        """
        Initialize the client

        Args:
            prefer_local: Deprecated and ignored; the local API is always used.
        """
        self.endpoint = self.LOCAL_ENDPOINT
        self._working_endpoint = None      # endpoint last confirmed healthy, or None
        self._last_health_check = 0        # time.time() of that confirmation
        self.health_check_interval = 3600  # Cache endpoint for 1 hour
        self._server_process = None        # Popen handle if this client launched the server

    def _check_endpoint_health(self, endpoint: Dict) -> bool:
        """Return True if ``GET <base_url>/api/health`` answers 200 with status "healthy".

        Any failure (connection error, timeout, undecodable JSON) is logged at
        debug level and reported as unhealthy rather than raised.
        """
        try:
            response = requests.get(
                f"{endpoint['base_url']}/api/health",
                timeout=10  # Increased timeout for health check
            )
            if response.status_code == 200:
                data = response.json()
                return isinstance(data, dict) and data.get('status') == 'healthy'
        except Exception as e:
            logger.debug(f"Health check failed for {endpoint['name']}: {e}")
        return False

    def _is_port_in_use(self, port: int) -> bool:
        """Return True if binding a TCP socket to ``localhost:port`` fails."""
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return False
            except OSError:
                return True

    def _start_local_server(self) -> bool:
        """Start the local API server unless something already holds the port.

        Returns:
            True once the server reports healthy. False if the port is held by
            an unresponsive process, ``api.py`` is missing, the launched
            process exits during startup, or it is not healthy within
            ``SERVER_START_TIMEOUT`` seconds. Unexpected errors are logged and
            reported as False.
        """
        port = self.LOCAL_PORT
        try:
            if self._is_port_in_use(port):
                logger.warning(f"Port {port} is already in use - server may already be running")
                # Give a server that is still binding/initialising a moment to answer.
                time.sleep(2)
                if self._check_endpoint_health(self.endpoint):
                    logger.info(f"✅ Server is already running on port {port}")
                    return True
                logger.error(f"Port {port} is in use but server is not responding to health checks")
                return False

            # api.py is expected to live next to this client module.
            api_service_dir = Path(__file__).resolve().parent
            api_script = api_service_dir / "api.py"
            if not api_script.exists():
                logger.error(f"API script not found at {api_script}")
                return False

            logger.info("Starting local API server...")
            logger.info("This may take 1-2 minutes to load models...")

            env = os.environ.copy()
            env['PORT'] = str(port)

            # Output is discarded rather than piped: nothing ever reads these
            # pipes, so a chatty server (e.g. model-loading logs) would fill
            # the OS pipe buffer and block forever on write. Run api.py by
            # hand to see its output when debugging startup problems.
            self._server_process = subprocess.Popen(
                [sys.executable, str(api_script)],
                cwd=str(api_service_dir),
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True,
            )

            deadline = time.time() + self.SERVER_START_TIMEOUT
            while time.time() < deadline:
                if self._check_endpoint_health(self.endpoint):
                    logger.info("✅ Local API server started successfully")
                    return True
                # Fail fast if the server crashed instead of waiting out the timeout.
                exit_code = self._server_process.poll()
                if exit_code is not None:
                    logger.error(f"Local API server exited during startup with code {exit_code}")
                    return False
                time.sleep(2)

            logger.error("Server failed to start within timeout")
            return False

        except Exception as e:
            logger.error(f"Failed to start local server: {e}")
            return False

    def _invalidate_endpoint_cache(self) -> None:
        """Forget the cached healthy endpoint so the next call re-probes the server."""
        self._working_endpoint = None
        self._last_health_check = 0

    def _get_working_endpoint(self) -> Optional[Dict]:
        """Return the local endpoint if it is (or can be made) healthy, else None.

        A previous successful check is reused for ``health_check_interval``
        seconds. Otherwise the server is probed and, if it is down, started.
        """
        if self._working_endpoint and (time.time() - self._last_health_check < self.health_check_interval):
            return self._working_endpoint

        if self._check_endpoint_health(self.endpoint):
            logger.info(f"Using {self.endpoint['name']} API endpoint")
        else:
            logger.info("Local API server not running, attempting to start...")
            if not self._start_local_server():
                self._invalidate_endpoint_cache()
                logger.error("Could not start or connect to local API server")
                return None

        # Stamp the cache *after* the check/startup: startup can take minutes,
        # and a timestamp taken before it would silently shorten the cache lifetime.
        self._working_endpoint = self.endpoint
        self._last_health_check = time.time()
        return self.endpoint

    def correct_text(self, text: str, include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct Czech text (grammar and punctuation)

        Args:
            text: Text to correct
            include_timing: Whether to include processing time in response

        Returns:
            On HTTP 200, the server's decoded JSON body (expected to carry
            'success', 'corrected_text' and optionally 'processing_time_ms').
            On any failure, a dict with 'success': False, the original text as
            'corrected_text', and an 'error' message. Empty or whitespace-only
            input is returned unchanged with 'success': True without contacting
            the server.
        """
        if not text or not text.strip():
            return {
                "success": True,
                "corrected_text": text,
                "error": None
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_text": text,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "text": text,
                "options": {"include_timing": include_timing}
            }
            response = requests.post(
                f"{endpoint['base_url']}/api/correct",
                json=payload,
                timeout=endpoint['timeout']
            )
            if response.status_code == 200:
                return response.json()
            return {
                "success": False,
                "corrected_text": text,
                "error": f"API error: {response.status_code}"
            }

        except requests.exceptions.Timeout:
            logger.warning(f"Timeout on {endpoint['name']} API")
            return {
                "success": False,
                "corrected_text": text,
                "error": "Request timeout"
            }

        except requests.exceptions.ConnectionError as e:
            # The server went away after the last health check. Drop the cached
            # endpoint so the next call re-probes (and restarts the server if
            # needed) instead of failing until the cache expires.
            self._invalidate_endpoint_cache()
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

    def _correct_individually(self, texts: List[str]) -> Dict[str, Any]:
        """Fallback for a failed batch call: correct each text with correct_text().

        Texts whose correction fails are returned unchanged. 'success' is False
        only when every text failed. 'error' summarises any failures (None if
        there were none).
        """
        results = [self.correct_text(text, include_timing=False) for text in texts]
        corrected_texts = [result.get('corrected_text', text) for result, text in zip(results, texts)]
        failed = sum(1 for result in results if not result.get('success'))
        return {
            "success": failed < len(texts),
            "corrected_texts": corrected_texts,
            "error": f"{failed} of {len(texts)} texts could not be corrected" if failed else None
        }

    def correct_batch(self, texts: List[str], include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct multiple Czech texts in batch

        Args:
            texts: List of texts to correct (at most MAX_BATCH_SIZE, i.e. 10)
            include_timing: Whether to include processing time

        Returns:
            On HTTP 200 from the batch endpoint, the server's decoded JSON body
            (expected to carry 'success', 'corrected_texts' and optionally
            'processing_time_ms'). If the batch call fails, each text is
            corrected individually and the result reports how many failed.
            Oversized batches are rejected without contacting the server.
        """
        if not texts:
            return {
                "success": True,
                "corrected_texts": [],
                "error": None
            }

        if len(texts) > self.MAX_BATCH_SIZE:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": f"Batch size exceeds limit ({self.MAX_BATCH_SIZE})"
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "texts": texts,
                "options": {"include_timing": include_timing}
            }
            response = requests.post(
                f"{endpoint['base_url']}/api/correct/batch",
                json=payload,
                timeout=endpoint['timeout'] * 2  # Longer timeout for batch
            )
            if response.status_code == 200:
                return response.json()
            logger.warning(
                f"Batch API failed with status {response.status_code}, "
                f"falling back to individual corrections"
            )
        except requests.exceptions.ConnectionError as e:
            logger.error(f"Error calling batch API: {e}")
            # Re-probe (and possibly restart) the server before the fallback calls.
            self._invalidate_endpoint_cache()
        except Exception as e:
            logger.error(f"Error calling batch API: {e}")

        return self._correct_individually(texts)


# Convenience functions for backward compatibility
# Shared client created lazily by get_client(); stays None until first use.
_default_client: Optional["CzechCorrectionClient"] = None

def get_client(prefer_local: bool = True) -> CzechCorrectionClient:
    """Return the shared module-level client, creating it on first call.

    ``prefer_local`` is accepted only for backward compatibility and is
    ignored; the client always targets the local API.
    """
    global _default_client
    if _default_client is not None:
        return _default_client
    _default_client = CzechCorrectionClient(prefer_local=True)
    return _default_client

def correct_text(text: str, prefer_local: bool = True) -> str:
    """Correct ``text`` with the shared client and return just the string.

    Falls back to the input text unchanged when the client reports failure.
    ``prefer_local`` is ignored (kept for backward compatibility).
    """
    outcome = get_client(prefer_local=True).correct_text(text)
    return outcome['corrected_text'] if outcome['success'] else text

def correct_batch(texts: List[str], prefer_local: bool = True) -> List[str]:
    """Correct ``texts`` with the shared client and return the list of strings.

    Returns the input list unchanged when the client reports failure (or when
    the response lacks 'corrected_texts'). ``prefer_local`` is ignored (kept
    for backward compatibility).
    """
    outcome = get_client(prefer_local=True).correct_batch(texts)
    if not outcome['success']:
        return texts
    return outcome.get('corrected_texts', texts)