import spaces
import gradio as gr
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer
from transformers import Qwen2VLProcessor, Qwen2VLImageProcessor
import traceback
import json
import os

# ========================================
# AIN VLM MODEL FOR OCR
# ========================================

# Model configuration
MODEL_ID = "MBZUAI/AIN"

# Image resolution settings for the processor
# The default range for the number of visual tokens per image in the model is 4-16384
# These settings balance speed and memory usage
MIN_PIXELS = 256 * 28 * 28   # Minimum resolution
MAX_PIXELS = 1280 * 28 * 28  # Maximum resolution

# Global model and processor
model = None
processor = None

# Strict OCR-focused prompt
OCR_PROMPT = """Extract all text from this image exactly as it appears.

Requirements:
1. Extract ONLY the text content - do not describe, analyze, or interpret the image
2. Maintain the original text structure, layout, and formatting
3. Preserve line breaks, paragraphs, and spacing as they appear
4. Do not translate the text - keep it in its original language
5. Do not add any explanations, descriptions, or additional commentary
6. If there are tables, maintain their structure
7. If there are headers, titles, or sections, preserve their hierarchy

Output only the extracted text, nothing else."""


def ensure_model_loaded():
    """Lazily load the AIN VLM model and processor."""
    global model, processor

    if model is not None and processor is not None:
        return

    print("🔄 Loading AIN VLM model...")

    try:
        # Determine device and dtype
        if torch.cuda.is_available():
            device_map = "auto"
            torch_dtype = "auto"
            print("✅ Using GPU (CUDA)")
        else:
            device_map = "cpu"
            torch_dtype = torch.float32
            print("✅ Using CPU")

        # Load model
        loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=True,
        )

        # Load processor with proper configuration
        # Manual construction to avoid size parameter issues
        try:
            # First, try the standard way
            loaded_processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
            )
            print("✅ Processor loaded successfully (standard method)")
        except ValueError as e:
            if "size must contain 'shortest_edge' and 'longest_edge' keys" in str(e):
                print("⚠️ Standard processor loading failed, trying manual construction...")
                # Manually construct processor with correct size format
                try:
                    # Load tokenizer separately
                    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

                    # Create image processor with correct size format
                    image_processor = Qwen2VLImageProcessor(
                        size={"shortest_edge": 224, "longest_edge": 1120},  # Valid format
                        do_resize=True,
                        do_rescale=True,
                        do_normalize=True,
                    )

                    # Create processor from components
                    loaded_processor = Qwen2VLProcessor(
                        image_processor=image_processor,
                        tokenizer=tokenizer,
                    )
                    print("✅ Processor loaded successfully (manual construction)")
                except Exception as manual_error:
                    print(f"❌ Manual construction also failed: {manual_error}")
                    raise
            else:
                raise

        model = loaded_model
        processor = loaded_processor
        print("✅ Model loaded successfully!")

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        traceback.print_exc()
        raise


@spaces.GPU(duration=100)
def extract_text_from_image(
    image: Image.Image,
    custom_prompt: str = None,
    max_new_tokens: int = 2048,
    min_pixels: int = None,
    max_pixels: int = None
) -> str:
    """
    Extract text from image using AIN VLM model.

    Args:
        image: PIL Image to process
        custom_prompt: Optional custom prompt (uses default OCR prompt if None)
        max_new_tokens: Maximum tokens to generate
        min_pixels: Minimum image resolution (optional)
        max_pixels: Maximum image resolution (optional)

    Returns:
        Extracted text as string
    """
    try:
        # Ensure model is loaded
        ensure_model_loaded()

        if model is None or processor is None:
            return "❌ Error: Model not loaded. Please refresh and try again."

        # Use custom prompt or default OCR prompt
        prompt_to_use = custom_prompt if custom_prompt and custom_prompt.strip() else OCR_PROMPT

        # Use custom resolution settings if provided, otherwise use defaults
        min_pix = min_pixels if min_pixels else MIN_PIXELS
        max_pix = max_pixels if max_pixels else MAX_PIXELS

        # Prepare messages in the format expected by the model
        # Include min_pixels and max_pixels in the image content for proper resizing
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                        "min_pixels": min_pix,
                        "max_pixels": max_pix,
                    },
                    {
                        "type": "text",
                        "text": prompt_to_use
                    },
                ],
            }
        ]

        # Apply chat template
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision information
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move to device
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        # Generate output
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # Greedy decoding for consistency
            )

        # Decode output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

        result = output_text[0] if output_text else ""
        return result.strip() if result else "No text extracted"

    except Exception as e:
        error_msg = f"❌ Error during text extraction: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg


def create_gradio_interface():
    """Create the Gradio interface for AIN OCR."""

    # Custom CSS for better UI
    css = """
    .main-container {
        max-width: 1400px;
        margin: 0 auto;
        padding: 20px;
    }
    .header-text {
        text-align: center;
        color: #2c3e50;
        margin-bottom: 30px;
    }
    .process-button {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        font-size: 1.1em !important;
        padding: 12px 24px !important;
        width: 100% !important;
        margin-top: 10px !important;
    }
    .process-button:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 6px 12px rgba(0,0,0,0.2) !important;
    }
    /* Larger font for extracted text */
    .output-textbox textarea {
        font-size: 20px !important;
        line-height: 2.0 !important;
        font-family: 'Segoe UI', 'Tahoma', 'Traditional Arabic', 'Arabic Typesetting', sans-serif !important;
        padding: 24px !important;
        direction: auto !important;
        text-align: start !important;
    }
    .output-textbox {
        background: #ffffff;
        border: 2px solid #e0e0e0;
        border-radius: 8px;
        box-shadow: 0 2px 8px rgba(0,0,0,0.1);
    }
    /* Better Arabic text support */
    .output-textbox textarea[dir="rtl"] {
        text-align: right !important;
        direction: rtl !important;
    }
    .info-box {
        background: #e3f2fd;
        border-left: 4px solid #2196f3;
        padding: 15px;
        margin: 10px 0;
        border-radius: 4px;
    }
    /* Status box styling */
    .status-box {
        background: #f0f4f8;
        border: 1px solid #d0dae6;
        border-radius: 6px;
        padding: 12px;
        margin-top: 10px;
        text-align: center;
        font-size: 14px;
    }
    /* Better spacing for rows and columns */
    .gradio-container {
        gap: 20px !important;
    }
    .contain {
        gap: 15px !important;
    }
    /* Image preview styling */
    .image-preview {
        border: 2px solid #e0e0e0;
        border-radius: 8px;
        box-shadow: 0 2px 8px rgba(0,0,0,0.1);
    }
    /* Accordion styling */
    .accordion {
        background: #f8f9fa;
        border-radius: 8px;
        margin-top: 15px;
        padding: 5px;
    }
    /* Clear button */
    button[variant="secondary"] {
        width: 100% !important;
        margin-top: 10px !important;
    }
    /* Label styling */
    label {
        font-weight: 600 !important;
        margin-bottom: 8px !important;
    }
    /* Better component spacing */
    .gr-form {
        gap: 12px !important;
    }
    /* Example images styling */
    .gr-examples {
        margin-top: 15px;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title="AIN VLM OCR") as demo:
        # Header
        gr.HTML("""
Advanced OCR using Vision Language Model (VLM) for accurate text extraction
Powered by MBZUAI/AIN - Specialized for understanding and extracting text from images