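"""AtlasOCR Gradio Space: Darija document OCR demo.

Assumed dependencies (not pinned in the source; verify against requirements.txt):
gradio, torch, unsloth, bitsandbytes, pillow, python-dotenv, spaces.
"""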
import spaces  # on ZeroGPU Spaces this must be imported before torch
import torch
import gradio as gr
from PIL import Image
import logging
from typing import Union
import os
from dotenv import load_dotenv

load_dotenv()

# Disable torch compilation to avoid dynamo issues
torch._dynamo.config.disable = True
torch.backends.cudnn.allow_tf32 = True

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AtlasOCR:
    def __init__(self, model_name: str = "atlasia/AtlasOCR", max_tokens: int = 2000):
        """Initialize the AtlasOCR model with proper error handling."""
        try:
            from unsloth import FastVisionModel

            logger.info(f"Loading model: {model_name}")
            # Disable compilation for the model
            with torch._dynamo.config.patch(disable=True):
                self.model, self.processor = FastVisionModel.from_pretrained(
                    model_name,
                    device_map="auto",
                    load_in_4bit=True,
                    use_gradient_checkpointing="unsloth",
                    # .get() falls back to cached credentials instead of
                    # raising KeyError when HF_API_KEY is unset
                    token=os.environ.get("HF_API_KEY"),
                )

            # Ensure model is not compiled
            if hasattr(self.model, '_dynamo_compile'):
                self.model._dynamo_compile = False

            self.max_tokens = max_tokens
            self.prompt = ""
            self.device = next(self.model.parameters()).device
            logger.info(f"Model loaded successfully on device: {self.device}")
        except ImportError:
            logger.error("unsloth not found. Please install it: pip install unsloth")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise
    def prepare_inputs(self, image: Image.Image) -> dict:
        """Prepare inputs for the model with proper error handling."""
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": self.prompt},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = self.processor(
                image,
                text,
                add_special_tokens=False,
                return_tensors="pt",
            )
            return inputs
        except Exception as e:
            logger.error(f"Error preparing inputs: {e}")
            raise
    def predict(self, image: Image.Image) -> str:
        """Predict text from an image with comprehensive error handling."""
        try:
            if image is None:
                return "Please upload an image."

            # Convert numpy array to PIL Image if needed
            if hasattr(image, 'shape'):  # numpy array
                image = Image.fromarray(image)

            inputs = self.prepare_inputs(image)

            # Move inputs to the same device as the model
            device = self.device
            logger.info(f"Moving inputs to device: {device}")
            for key in inputs:
                if hasattr(inputs[key], 'to'):
                    inputs[key] = inputs[key].to(device)

            # Ensure attention_mask is float32 and on the correct device
            if 'attention_mask' in inputs:
                inputs['attention_mask'] = inputs['attention_mask'].to(dtype=torch.float32, device=device)

            logger.info(f"Generating text with max_tokens={self.max_tokens}")
            # Disable compilation during generation; greedy decoding, so no
            # temperature is passed (it is ignored when do_sample=False)
            with torch.no_grad(), torch._dynamo.config.patch(disable=True):
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_tokens,
                    use_cache=True,
                    do_sample=False,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Strip the prompt tokens so only newly generated text is decoded
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            result = output_text[0].strip()
            logger.info(f"Generated text: {result[:100]}...")
            return result
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            return f"Error processing image: {str(e)}"
    def __call__(self, image: Union[Image.Image, str]) -> str:
        """Callable interface for the model."""
        if isinstance(image, str):
            return "Please upload an image file."
        return self.predict(image)


# Global model instance
atlas_ocr = None


def load_model():
    """Load the model globally to avoid reloading."""
    global atlas_ocr
    if atlas_ocr is None:
        try:
            atlas_ocr = AtlasOCR()
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False
    return True
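# Assumption: this Space runs on ZeroGPU, which is why `spaces` is imported;
# there, functions that touch the GPU must carry the @spaces.GPU decorator.
# Drop the decorator if the Space runs on dedicated GPU hardware instead.
@spaces.GPU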
def perform_ocr(image):
    """Main OCR function with proper error handling."""
    try:
        if not load_model():
            return "Error: Failed to load model. Please check the logs."
        if image is None:
            return "Please upload an image to extract text."
        result = atlas_ocr(image)
        return result
    except Exception as e:
        logger.error(f"Error in perform_ocr: {e}")
        return f"An error occurred: {str(e)}"


def process_with_status(image):
    """Process an image and return (result, status); defined at module level to avoid pickling issues."""
    if image is None:
        return "Please upload an image.", "No image provided"
    try:
        result = perform_ocr(image)
        return result, "Processing completed successfully"
    except Exception as e:
        return f"Error: {str(e)}", f"Error occurred: {str(e)}"
def create_interface():
    """Create the Gradio interface with proper configuration."""
    with gr.Blocks(
        title="AtlasOCR - Darija Document OCR",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        """
    ) as demo:
        gr.Markdown("""
        # AtlasOCR - Darija Document OCR

        Upload an image to extract Darija text in real time. This model is specialized for Darija document OCR.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input image
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    height=400
                )
                # Submit button
                submit_btn = gr.Button(
                    "Extract Text",
                    variant="primary",
                    size="lg"
                )
                # Clear button
                clear_btn = gr.Button("Clear", variant="secondary")

            with gr.Column(scale=1):
                # Output text
                output = gr.Textbox(
                    label="Extracted Text",
                    lines=20,
                    show_copy_button=True,
                    placeholder="Extracted text will appear here..."
                )
                # Status indicator
                status = gr.Textbox(
                    label="Status",
                    value="Ready to process images",
                    interactive=False
                )

        # Model details
        with gr.Accordion("Model Information", open=False):
            gr.Markdown("""
            **Model:** AtlasOCR-v0
            **Description:** Specialized Darija OCR model for Arabic dialect text extraction
            **Size:** 3B parameters
            **Max output:** up to 2000 generated tokens
            **Optimization:** 4-bit quantization for efficient inference
            """)

        gr.Examples(
            examples=[
                ["i3.png"],
                ["i6.png"]
            ],
            inputs=image_input,
            outputs=[output, status],  # required when cache_examples=True
            fn=process_with_status,    # required when cache_examples=True
            label="Example Images",
            examples_per_page=4,
            cache_examples=True
        )

        # Set up processing flow
        submit_btn.click(
            fn=process_with_status,
            inputs=image_input,
            outputs=[output, status]
        )
        image_input.change(
            fn=process_with_status,
            inputs=image_input,
            outputs=[output, status]
        )
        clear_btn.click(
            fn=lambda: (None, "", "Ready to process images"),
            outputs=[image_input, output, status]
        )

    return demo
# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )
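# To run locally (assuming the dependencies listed at the top are installed):
#   python app.py
# then open http://localhost:7860 in a browser.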