asdfasdfdsafdsa committed on
Commit
1601325
·
verified ·
1 Parent(s): 2dd94b7

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +30 -0
  2. api.py +394 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Hugging Face Spaces GPU-enabled Dockerfile
FROM python:3.10

# Install system dependencies.
# --no-install-recommends keeps the layer limited to the packages named here;
# the apt lists are removed in the same layer so they never reach the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy requirements first so the dependency layer is reused when only api.py changes
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY api.py .

# Create non-root user (HF Spaces requirement)
RUN useradd -m -u 1000 user
USER user

# HF Spaces expects port 7860
EXPOSE 7860

# Run the application on HF Spaces default port
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
api.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI service for Czech text correction pipeline
3
+ Combines grammar error correction and punctuation restoration
4
+ """
5
+
6
+ from fastapi import FastAPI, HTTPException, Request
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel, Field
9
+ from typing import Optional, List, Dict
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification, pipeline
12
+ import time
13
+ import re
14
+ import logging
15
+ from contextlib import asynccontextmanager
16
+
17
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model handles shared by the request handlers. They stay None until the
# lifespan hook assigns them at application startup.
gec_model = gec_tokenizer = punct_pipeline = device = None

# Keyword arguments forwarded verbatim to ``gec_model.generate``.
GEC_CONFIG = dict(
    num_beams=8,
    do_sample=False,
    repetition_penalty=1.0,
    length_penalty=1.0,
    no_repeat_ngram_size=0,
    early_stopping=True,
    max_new_tokens=1500,
)
37
+
38
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill the module-level model handles at startup; log at shutdown.

    Code before the ``yield`` runs when the application starts, code after it
    when the application stops. Nothing is released explicitly on shutdown.
    """
    global gec_model, gec_tokenizer, punct_pipeline, device

    gec_checkpoint = "ufal/byt5-large-geccc-mate"
    punct_checkpoint = "kredor/punctuate-all"
    cuda_ready = torch.cuda.is_available()

    logger.info("Loading models...")

    # Pick the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda" if cuda_ready else "cpu")
    logger.info(f"Using device: {device}")

    # Grammar-correction seq2seq model, moved onto the selected device.
    logger.info("Loading Czech GEC model...")
    gec_tokenizer = AutoTokenizer.from_pretrained(gec_checkpoint)
    gec_model = AutoModelForSeq2SeqLM.from_pretrained(gec_checkpoint).to(device)
    logger.info("GEC model loaded successfully")

    # Punctuation token classifier wrapped in a transformers pipeline; the
    # pipeline takes a device index (0 = first GPU, -1 = CPU).
    logger.info("Loading punctuation model...")
    punct_pipeline = pipeline(
        "token-classification",
        model=AutoModelForTokenClassification.from_pretrained(punct_checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(punct_checkpoint),
        device=0 if cuda_ready else -1,
    )
    logger.info("Punctuation model loaded successfully")

    logger.info("All models loaded and ready")

    yield

    logger.info("Shutting down...")
74
+
75
# Create FastAPI app with lifespan so the models are loaded once at startup
app = FastAPI(
    title="Czech Text Correction API",
    description="API for Czech grammar error correction and punctuation restoration",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS for any origin.
# Credentials are disabled on purpose: a wildcard origin together with
# allow_credentials=True would let any website issue credentialed
# cross-origin requests. No endpoint in this module reads cookies or
# Authorization headers, so credentialed CORS is not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
91
+
92
+ # Request/Response models
93
class CorrectionRequest(BaseModel):
    # Request body of the single-text correction endpoints.

    # Required; validation rejects anything longer than 5000 characters.
    text: str = Field(
        ...,
        max_length=5000,
        description="Czech text to correct",
    )
    # Free-form options; the handlers read the "include_timing" key.
    options: Optional[Dict] = Field(
        default={},
        description="Optional parameters",
    )
96
+
97
class CorrectionResponse(BaseModel):
    # Result envelope returned by the single-text correction endpoints.

    success: bool = Field(...)
    corrected_text: str = Field(...)
    # Elapsed milliseconds; left as None unless a handler fills it in.
    processing_time_ms: Optional[float] = Field(default=None)
    # Error message; left as None on success.
    error: Optional[str] = Field(default=None)
102
+
103
class BatchCorrectionRequest(BaseModel):
    # Request body of the batch correction endpoint.

    # At most 10 texts per request. Pydantic v2 expresses the collection size
    # limit as ``max_length``; ``max_items`` is the deprecated v1 spelling.
    texts: List[str] = Field(..., max_length=10, description="List of texts to correct")
    # Free-form options; the handler reads the "include_timing" key.
    options: Optional[Dict] = Field(default={}, description="Optional parameters")
106
+
107
class BatchCorrectionResponse(BaseModel):
    # Result envelope returned by the batch correction endpoint.

    success: bool = Field(...)
    # One output per input text, in request order.
    corrected_texts: List[str] = Field(...)
    # Elapsed milliseconds for the whole batch; None unless a handler fills it in.
    processing_time_ms: Optional[float] = Field(default=None)
    # Error message; left as None on success.
    error: Optional[str] = Field(default=None)
112
+
113
class HealthResponse(BaseModel):
    # Payload of the health endpoint; every field is required.

    status: str = Field(...)
    models_loaded: bool = Field(...)
    gpu_available: bool = Field(...)
    device: str = Field(...)
118
+
119
class InfoResponse(BaseModel):
    # Static service description returned by the info endpoint; all fields required.

    name: str = Field(...)
    version: str = Field(...)
    # Maps a pipeline stage name to the model identifier used for it.
    models: Dict[str, str] = Field(...)
    capabilities: List[str] = Field(...)
    max_input_length: int = Field(...)
125
+
126
def apply_gec_correction(text: str) -> str:
    """Run the grammar-correction model over ``text`` and return its output.

    Blank or whitespace-only input is returned unchanged without touching the
    model. Otherwise the text is tokenized, sent through beam-search
    generation configured by ``GEC_CONFIG``, and the first returned sequence
    is decoded with special tokens stripped.
    """
    if text.strip() == "":
        return text

    # NOTE(review): the tokenizer truncates to 1024 tokens, so longer inputs
    # lose their tail silently - confirm this is acceptable for 5000-character
    # requests.
    encoded = gec_tokenizer(
        text, return_tensors="pt", max_length=1024, truncation=True
    )
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated = gec_model.generate(**model_inputs, **GEC_CONFIG)

    best_sequence = generated[0]
    return gec_tokenizer.decode(best_sequence, skip_special_tokens=True)
150
+
151
# Maps a classifier entity label to the mark appended after a word. Both the
# generic ``LABEL_n`` ids and literal-mark labels are accepted, so the lookup
# works whichever naming the checkpoint's config uses.
# NOTE(review): confirm the label names actually emitted by the loaded model.
_PUNCT_MARKS = {
    'LABEL_0': '',
    'LABEL_1': '.',
    'LABEL_2': ',',
    'LABEL_3': '?',
    'LABEL_4': '-',
    'LABEL_5': ':',
    '0': '',
    '.': '.',
    ',': ',',
    '?': '?',
    '-': '-',
    ':': ':',
}


def _merge_subword_predictions(results):
    """Collapse per-token predictions into ordered ``(word, mark)`` pairs."""
    merged = []
    for index, token in enumerate(results):
        raw = token['word']
        piece = raw.replace('▁', '').strip()
        mark = _PUNCT_MARKS.get(token['entity'], '')

        if index > 0 and not raw.startswith('▁'):
            # Continuation piece: extend the current word. The mark follows
            # the word, so the prediction on its last piece is the one kept.
            # NOTE(review): assumes the model labels the final sub-token of a
            # word - verify against the model card.
            word, _ = merged[-1]
            merged[-1] = (word + piece, mark)
        else:
            merged.append((piece, mark))

    # Tokens that were only a bare word-boundary marker carry no text.
    return [(word, mark) for word, mark in merged if word]


def apply_punctuation(text: str) -> str:
    """Restore punctuation and sentence-initial capitalization in ``text``.

    The text is lower-cased, tagged by ``punct_pipeline``, and rebuilt word by
    word with the predicted marks appended. The first letter of the result and
    of every chunk following ``.``, ``?`` or ``!`` is upper-cased. All other
    casing from the input is lost because of the initial lower-casing.

    Args:
        text: Text to punctuate.

    Returns:
        The punctuated text. Blank or whitespace-only input is returned
        unchanged without calling the model.
    """
    if not text.strip():
        return text

    clean_text = text.lower()
    tagged = _merge_subword_predictions(punct_pipeline(clean_text))

    # Keep every prediction per word in order of appearance, so a word that
    # occurs several times gets the mark predicted for that occurrence rather
    # than whichever prediction came last.
    marks_by_word: Dict[str, List[str]] = {}
    for word, mark in tagged:
        marks_by_word.setdefault(word, []).append(mark)

    occurrences: Dict[str, int] = {}
    punctuated = []
    for word in clean_text.split():
        nth = occurrences.get(word, 0)
        occurrences[word] = nth + 1
        candidates = marks_by_word.get(word, ())
        mark = candidates[nth] if nth < len(candidates) else ''
        # Do not double a mark the input already carries ("ahoj," + ",").
        if mark and not word.endswith(mark):
            word += mark
        punctuated.append(word)

    result = ' '.join(punctuated)

    # Capitalize the first letter of the text and of each sentence.
    sentences = re.split(r'(?<=[.?!])\s+', result)
    capitalized = ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)

    # Remove any space left in front of a punctuation mark.
    for p in [',', '.', '?', ':', '!', ';']:
        capitalized = capitalized.replace(f' {p}', p)

    return capitalized
215
+
216
def process_text(text: str) -> str:
    """Run the full pipeline on ``text``: grammar correction, then punctuation.

    The output of the grammar step is the input of the punctuation and
    capitalization step, whose result is returned.
    """
    return apply_punctuation(apply_gec_correction(text))
225
+
226
@app.post("/api/correct", response_model=CorrectionResponse)
async def correct_text(request: CorrectionRequest):
    """
    Correct Czech text (grammar + punctuation).

    Raises:
        HTTPException: 400 when the text is blank or longer than 5000
            characters. Validation happens outside the error-capturing block
            so these reach the client as real 400 responses instead of being
            folded into a ``success=False`` body.

    Returns:
        A ``CorrectionResponse``. Processing failures are reported as
        ``success=False`` with the exception text in ``error``.
        ``processing_time_ms`` is filled only when the request options set
        ``include_timing``.
    """
    # Validate input
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(request.text) > 5000:
        raise HTTPException(status_code=400, detail="Text too long (max 5000 characters)")

    # ``options`` may be sent explicitly as null.
    options = request.options or {}

    try:
        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()
        corrected = process_text(request.text)
        processing_time = (time.perf_counter() - start_time) * 1000

        response = CorrectionResponse(
            success=True,
            corrected_text=corrected
        )

        # Include timing if requested
        if options.get("include_timing", False):
            response.processing_time_ms = processing_time

        return response

    except HTTPException:
        # Never convert a deliberate HTTP error into a 200 response.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Error processing text: %s", e)
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )
265
+
266
@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
async def correct_batch(request: BatchCorrectionRequest):
    """
    Correct multiple Czech texts.

    Raises:
        HTTPException: 400 when the list of texts is empty. Raised outside the
            error-capturing block so it reaches the client as a real 400.

    Returns:
        A ``BatchCorrectionResponse`` with one entry per input text, in order.
        A text longer than 5000 characters is replaced by the placeholder
        ``"[Error: Text too long]"`` rather than failing the batch. Processing
        failures are reported as ``success=False`` with an empty result list.
    """
    # Validate
    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")

    # ``options`` may be sent explicitly as null.
    options = request.options or {}

    try:
        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()

        # Process each text
        corrected_texts = [
            "[Error: Text too long]" if len(text) > 5000 else process_text(text)
            for text in request.texts
        ]

        processing_time = (time.perf_counter() - start_time) * 1000

        response = BatchCorrectionResponse(
            success=True,
            corrected_texts=corrected_texts
        )

        if options.get("include_timing", False):
            response.processing_time_ms = processing_time

        return response

    except HTTPException:
        # Never convert a deliberate HTTP error into a 200 response.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Error processing batch: %s", e)
        return BatchCorrectionResponse(
            success=False,
            corrected_texts=[],
            error=str(e)
        )
307
+
308
@app.post("/api/correct/gec-only", response_model=CorrectionResponse)
async def correct_gec_only(request: CorrectionRequest):
    """
    Apply only grammar error correction (no punctuation).

    Returns:
        A ``CorrectionResponse``. Failures are logged and reported as
        ``success=False`` with the exception text in ``error``.
    """
    try:
        corrected = apply_gec_correction(request.text)
    except Exception as e:
        # Log before answering so failures are visible server-side, as the
        # full-pipeline endpoint does.
        logger.exception("Error applying grammar correction: %s", e)
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )

    return CorrectionResponse(
        success=True,
        corrected_text=corrected
    )
325
+
326
@app.post("/api/correct/punct-only", response_model=CorrectionResponse)
async def correct_punct_only(request: CorrectionRequest):
    """
    Apply only punctuation restoration (no grammar correction).

    Returns:
        A ``CorrectionResponse``. Failures are logged and reported as
        ``success=False`` with the exception text in ``error``.
    """
    try:
        corrected = apply_punctuation(request.text)
    except Exception as e:
        # Log before answering so failures are visible server-side, as the
        # full-pipeline endpoint does.
        logger.exception("Error applying punctuation: %s", e)
        return CorrectionResponse(
            success=False,
            corrected_text="",
            error=str(e)
        )

    return CorrectionResponse(
        success=True,
        corrected_text=corrected
    )
343
+
344
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """
    Report whether both models are loaded and which device is in use.
    """
    # Ready only once both module-level handles have been assigned.
    ready = not (gec_model is None or punct_pipeline is None)

    if ready:
        state = "healthy"
    else:
        state = "loading"

    if device:
        device_name = str(device)
    else:
        device_name = "not initialized"

    return HealthResponse(
        status=state,
        models_loaded=ready,
        gpu_available=torch.cuda.is_available(),
        device=device_name,
    )
357
+
358
@app.get("/api/info", response_model=InfoResponse)
async def get_info():
    """
    Return the static description of this service: name, version, model
    identifiers, capability list and maximum input length.
    """
    model_ids = {
        "gec": "ufal/byt5-large-geccc-mate",
        "punctuation": "kredor/punctuate-all",
    }
    features = [
        "Grammar error correction",
        "Punctuation restoration",
        "Capitalization",
        "Batch processing",
        "Czech language focus",
    ]
    return InfoResponse(
        name="Czech Text Correction API",
        version="1.0.0",
        models=model_ids,
        capabilities=features,
        max_input_length=5000,
    )
379
+
380
@app.get("/")
async def root():
    """Return the service name and the paths of the docs, health and info endpoints."""
    links = dict(docs="/docs", health="/api/health", info="/api/info")
    return {"message": "Czech Text Correction API", **links}
389
+
390
if __name__ == "__main__":
    # Direct execution: serve on all interfaces, on the port named by the
    # PORT environment variable (7860 when it is unset).
    import os
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi>=0.104.0
2
+ uvicorn[standard]>=0.24.0
3
+ torch>=2.0.0
4
+ transformers>=4.30.0
5
+ python-multipart>=0.0.6
6
+ pydantic>=2.0.0
7
+ python-dotenv>=1.0.0