imomayiz committed
Commit 58c54a2 · verified · 1 parent: 00a9e34

fix file not found error

Files changed (1)
  1. app.py +264 -45
app.py CHANGED
@@ -1,55 +1,274 @@
import gradio as gr
- import spaces
+ import torch
from PIL import Image
+ import logging
+ from typing import Optional, Union
+ import os

+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ class AtlasOCR:
+     def __init__(self, model_name: str = "atlasia/AtlasOCR-v0", max_tokens: int = 2000):
+         """Initialize the AtlasOCR model with proper error handling."""
+         try:
+             from unsloth import FastVisionModel
+
+             logger.info(f"Loading model: {model_name}")
+             self.model, self.processor = FastVisionModel.from_pretrained(
+                 model_name,
+                 device_map="auto",
+                 load_in_4bit=True,
+                 use_gradient_checkpointing="unsloth"
+             )
+             self.max_tokens = max_tokens
+             self.prompt = ""
+             logger.info("Model loaded successfully")
+
+         except ImportError:
+             logger.error("unsloth not found. Please install it: pip install unsloth")
+             raise
+         except Exception as e:
+             logger.error(f"Error loading model: {e}")
+             raise
+
+     def prepare_inputs(self, image: Image.Image) -> dict:
+         """Prepare inputs for the model with proper error handling."""
+         try:
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "image",
+                         },
+                         {"type": "text", "text": self.prompt},
+                     ],
+                 }
+             ]
+
+             text = self.processor.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
+             )
+
+             inputs = self.processor(
+                 image,
+                 text,
+                 add_special_tokens=False,
+                 return_tensors="pt",
+             )
+             return inputs
+
+         except Exception as e:
+             logger.error(f"Error preparing inputs: {e}")
+             raise
+
+     def predict(self, image: Image.Image) -> str:
+         """Predict text from image with comprehensive error handling."""
+         try:
+             if image is None:
+                 return "Please upload an image."
+
+             # Convert numpy array to PIL Image if needed
+             if hasattr(image, 'shape'):  # numpy array
+                 image = Image.fromarray(image)
+
+             inputs = self.prepare_inputs(image)
+
+             # Move inputs to GPU if available
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
+
+             # Ensure attention_mask is float32
+             if 'attention_mask' in inputs:
+                 inputs['attention_mask'] = inputs['attention_mask'].to(torch.float32)
+
+             logger.info(f"Generating text with max_tokens={self.max_tokens}")
+             with torch.no_grad():
+                 generated_ids = self.model.generate(
+                     **inputs,
+                     max_new_tokens=self.max_tokens,
+                     use_cache=True,
+                     do_sample=False,
+                     temperature=0.1
+                 )
+
+             generated_ids_trimmed = [
+                 out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
+             ]
+
+             output_text = self.processor.batch_decode(
+                 generated_ids_trimmed,
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=False
+             )
+
+             result = output_text[0].strip()
+             logger.info(f"Generated text: {result[:100]}...")
+             return result
+
+         except Exception as e:
+             logger.error(f"Error during prediction: {e}")
+             return f"Error processing image: {str(e)}"
+
+     def __call__(self, image: Union[Image.Image, str]) -> str:
+         """Callable interface for the model."""
+         if isinstance(image, str):
+             return "Please upload an image file."
+         return self.predict(image)
+
+
+ # Global model instance
+ atlas_ocr = None
+
+ def load_model():
+     """Load the model globally to avoid reloading."""
+     global atlas_ocr
+     if atlas_ocr is None:
+         try:
+             atlas_ocr = AtlasOCR()
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             return False
+     return True

- @spaces.GPU
def perform_ocr(image):
-     from atlasocr_model import AtlasOCR
-     atlas_ocr=AtlasOCR()
-     output_text = atlas_ocr(image)
-     return output_text
-
- # Create Gradio interface
- with gr.Blocks(title="AtlasOCR") as demo:
-     gr.Markdown("# AtlasOCR")
-     gr.Markdown("Upload an image to extract Darija text in real-time. This model is specialized for Darija document OCR.")
+     """Main OCR function with proper error handling."""
+     try:
+         if not load_model():
+             return "Error: Failed to load model. Please check the logs."
+
+         if image is None:
+             return "Please upload an image to extract text."
+
+         result = atlas_ocr(image)
+         return result
+
+     except Exception as e:
+         logger.error(f"Error in perform_ocr: {e}")
+         return f"An error occurred: {str(e)}"
+
+ def create_interface():
+     """Create the Gradio interface with proper configuration."""

-     with gr.Row():
-         with gr.Column(scale=1):
-             # Input image
-             image_input = gr.Image(type="numpy", label="Upload Image")
-
-             # Example gallery
-             gr.Examples(
-                 examples=[
-                     ["i3.jpg"],
-                     ["i6.jpg"]
-                 ],
-                 inputs=image_input,
-                 label="Example Images",
-                 examples_per_page=4
-             )
+     # Example images from assets
+     example_images = []
+     assets_dir = "assets"
+     if os.path.exists(assets_dir):
+         for file in os.listdir(assets_dir):
+             if file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                 example_images.append([os.path.join(assets_dir, file)])
+
+     # If no example images found, use empty list
+     if not example_images:
+         example_images = []
+
+     with gr.Blocks(
+         title="AtlasOCR - Darija Document OCR",
+         theme=gr.themes.Soft(),
+         css="""
+         .gradio-container {
+             max-width: 1200px !important;
+         }
+         """
+     ) as demo:
+
+         gr.Markdown("""
+         # AtlasOCR - Darija Document OCR
+         Upload an image to extract Darija text in real-time. This model is specialized for Darija document OCR.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 # Input image
+                 image_input = gr.Image(
+                     type="pil",
+                     label="Upload Image",
+                     height=400
+                 )
+
+                 # Example gallery
+                 if example_images:
+                     gr.Examples(
+                         examples=example_images,
+                         inputs=image_input,
+                         label="Example Images",
+                         examples_per_page=4
+                     )
+
+                 # Submit button
+                 submit_btn = gr.Button(
+                     "Extract Text",
+                     variant="primary",
+                     size="lg"
+                 )
+
+                 # Clear button
+                 clear_btn = gr.Button("Clear", variant="secondary")

-             # Submit button
-             submit_btn = gr.Button("Extract Text")
-
-         with gr.Column(scale=1):
-             # Output text
-             output = gr.Textbox(label="Extracted Text", lines=20, show_copy_button=True)
-
-     # Model details
-     with gr.Accordion("Model Information", open=False):
-         gr.Markdown("""
-         **Model:** AtlasOCR-v0
-         **Description:** Darija OCR model
-         **Size:** 3B parameters
-         **Context window:** Supports up to 2000 output tokens
-         """)
+             with gr.Column(scale=1):
+                 # Output text
+                 output = gr.Textbox(
+                     label="Extracted Text",
+                     lines=20,
+                     show_copy_button=True,
+                     placeholder="Extracted text will appear here..."
+                 )
+
+                 # Status indicator
+                 status = gr.Textbox(
+                     label="Status",
+                     value="Ready to process images",
+                     interactive=False
+                 )
+
+         # Model details
+         with gr.Accordion("Model Information", open=False):
+             gr.Markdown("""
+             **Model:** AtlasOCR-v0
+             **Description:** Specialized Darija OCR model for Arabic dialect text extraction
+             **Size:** 3B parameters
+             **Context window:** Supports up to 2000 output tokens
+             **Optimization:** 4-bit quantization for efficient inference
+             """)
+
+         # Set up processing flow
+         def process_with_status(image):
+             if image is None:
+                 return "Please upload an image.", "No image provided"
+
+             try:
+                 result = perform_ocr(image)
+                 return result, "Processing completed successfully"
+             except Exception as e:
+                 return f"Error: {str(e)}", f"Error occurred: {str(e)}"
+
+         submit_btn.click(
+             fn=process_with_status,
+             inputs=image_input,
+             outputs=[output, status]
+         )
+
+         image_input.change(
+             fn=process_with_status,
+             inputs=image_input,
+             outputs=[output, status]
+         )
+
+         clear_btn.click(
+             fn=lambda: (None, "", "Ready to process images"),
+             outputs=[image_input, output, status]
+         )

-     # Set up processing flow
-     submit_btn.click(fn=perform_ocr, inputs=image_input, outputs=output)
-     image_input.change(fn=perform_ocr, inputs=image_input, outputs=output)
+     return demo

- demo.launch()
+ # Create and launch the interface
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         debug=True
+     )
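
For quick local verification, a minimal usage sketch (not part of the commit) that exercises the updated module outside the Gradio UI; it assumes torch, unsloth, and the atlasia/AtlasOCR-v0 weights are available, and "sample.jpg" is a placeholder path:

from PIL import Image

from app import create_interface, perform_ocr  # safe to import: demo.launch() is guarded by __main__

if __name__ == "__main__":
    # One-off OCR call: the model is loaded lazily inside perform_ocr on first use.
    image = Image.open("sample.jpg")  # placeholder path to any Darija document image
    print(perform_ocr(image))

    # Or build and serve the full Gradio UI, as app.py does when run directly:
    # create_interface().launch(server_name="0.0.0.0", server_port=7860)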