Spaces:
Running
Running
| """ | |
| Author : Janarddan Sarkar | |
| file_name : mistral_ocr_st.py | |
| date : 10-03-2025 | |
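| description : Streamlit app that runs Mistral OCR on an uploaded PDF (rendered as markdown) or image (returned as structured JSON). | |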
| """ | |
| import os | |
| import json | |
| import base64 | |
| import streamlit as st | |
| from mistralai import Mistral | |
| from dotenv import find_dotenv, load_dotenv | |
| from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk | |
| from mistralai.models import OCRResponse | |
| from enum import Enum | |
| from pydantic import BaseModel | |
| import pycountry | |
# Read settings from a .env file (when one is found), then build the shared
# Mistral client from the MISTRAL_API_KEY environment variable.
load_dotenv(find_dotenv())
api_key = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=api_key)
# Language names keyed by their `alpha_2` code, built from every pycountry
# language entry that exposes an `alpha_2` attribute.
languages = {
    entry.alpha_2: entry.name
    for entry in pycountry.languages
    if hasattr(entry, "alpha_2")
}
class LanguageMeta(Enum.__class__):
    """Enum metaclass that adds one member per entry of ``languages``."""

    def __new__(metacls, cls, bases, classdict):
        """Populate the class namespace with language members, then build the class.

        Each member is named after the language name upper-cased with spaces
        replaced by underscores; its value is the unmodified language name.
        """
        for display_name in languages.values():
            member_key = display_name.upper().replace(' ', '_')
            classdict[member_key] = display_name
        return super().__new__(metacls, cls, bases, classdict)
# Enum of language names. The body is intentionally empty: the members are
# injected by LanguageMeta while the class is being created.
class Language(Enum, metaclass=LanguageMeta):
    ...
# Shape of the structured result requested from the chat model in
# process_image, where this class is passed as `response_format`.
class StructuredOCR(BaseModel):
    file_name: str  # presumably the name of the processed file -- confirm
    topics: list[str]  # presumably topics found in the document -- confirm
    languages: list[Language]  # values restricted to members of Language
    ocr_contents: dict  # the OCR contents as a dictionary (see the prompt in process_image)
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """Inline OCR images into markdown by swapping image ids for their data.

    Every image reference of the form ``![<id>](<id>)`` is rewritten to
    ``![<id>](<payload>)``, where ``<payload>`` is the value stored for that
    id in ``images_dict``.

    Args:
        markdown_str: Markdown text that may contain image references.
        images_dict: Mapping of image id to its base64 payload (presumably a
            data URI as returned by the OCR API -- confirm). Entries whose
            payload is empty or ``None`` are left untouched.

    Returns:
        The markdown with each known image reference pointing at its payload.
    """
    for img_name, base64_str in images_dict.items():
        if not base64_str:
            # Nothing to inline for this image; keep the original reference
            # rather than writing "None" into the link target.
            continue
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
        )
    return markdown_str
def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """Return the markdown of all pages joined by blank lines.

    For each page, a mapping of image id to base64 data is built from
    ``page.images`` and handed to ``replace_images_in_markdown`` together
    with the page's markdown.
    """
    return "\n\n".join(
        replace_images_in_markdown(
            page.markdown,
            {image.id: image.image_base64 for image in page.images},
        )
        for page in ocr_response.pages
    )
def process_pdf(pdf_bytes, file_name):
    """Upload a PDF through the Mistral client and return its OCR response.

    Args:
        pdf_bytes: Raw content of the PDF.
        file_name: Name to upload the file under.

    Returns:
        The OCR response, converted to ``OCRResponse`` if the client
        returned a plain dict.
    """
    upload = client.files.upload(
        file={"file_name": file_name, "content": pdf_bytes},
        purpose="ocr",
    )
    # The OCR call receives a signed URL for the uploaded file.
    url_info = client.files.get_signed_url(file_id=upload.id, expiry=1)
    document = DocumentURLChunk(document_url=url_info.url)
    ocr_result = client.ocr.process(
        document=document,
        model="mistral-ocr-latest",
        include_image_base64=True,
    )
    # Normalise a plain-dict response into the OCRResponse model.
    if isinstance(ocr_result, dict):
        return OCRResponse(**ocr_result)
    return ocr_result
def process_image(image_bytes, file_name):
    """Run OCR on an image and return the result as a structured dict.

    The image is sent as a base64 data URL to the OCR model; the resulting
    markdown and the image are then passed to a chat model that is asked to
    answer in the ``StructuredOCR`` format.

    Args:
        image_bytes: Raw content of the image.
        file_name: Original file name; its extension selects the MIME type of
            the data URL (falls back to ``image/jpeg`` when it is missing or
            not recognised as an image).

    Returns:
        The parsed ``StructuredOCR`` answer as a JSON-compatible dict.
    """
    import mimetypes  # local import keeps this block self-contained

    # Label the payload with its real type (e.g. image/png for PNG uploads)
    # instead of always claiming JPEG.
    mime_type, _ = mimetypes.guess_type(file_name or "")
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    encoded_image = base64.b64encode(image_bytes).decode()
    base64_data_url = f"data:{mime_type};base64,{encoded_image}"

    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=base64_data_url), model="mistral-ocr-latest"
    )
    # Only the first page's markdown is used.
    image_ocr_markdown = image_response.pages[0].markdown

    chat_response = client.chat.parse(
        model="pixtral-12b-latest",
        messages=[
            {
                "role": "user",
                "content": [
                    ImageURLChunk(image_url=base64_data_url),
                    TextChunk(
                        text=(
                            "This is the image's OCR in markdown:\n"
                            f"<BEGIN_IMAGE_OCR>\n{image_ocr_markdown}\n<END_IMAGE_OCR>.\n"
                            "Convert this into a structured JSON response with the OCR contents in a dictionary."
                        )
                    ),
                ],
            },
        ],
        response_format=StructuredOCR,
        temperature=0,
    )
    # Round-trip through JSON so the caller gets plain JSON-compatible values.
    return json.loads(chat_response.choices[0].message.parsed.model_dump_json())
# Streamlit UI: upload a PDF or an image, then run OCR once "Submit" is pressed.
st.title("Mistral OCR")
uploaded_file = st.file_uploader("Upload a PDF or Image", type=["pdf", "png", "jpg", "jpeg"])
if uploaded_file is not None:
    file_type = uploaded_file.type
    # getvalue() returns the full content regardless of the stream position;
    # read() returns only what is left after any earlier read.
    file_bytes = uploaded_file.getvalue()
    file_name = uploaded_file.name
    if st.button("Submit"):
        st.write(f"**Processing file:** {file_name}")
        if "pdf" in file_type:
            # PDFs are shown as the combined markdown of all pages.
            pdf_response = process_pdf(file_bytes, file_name)
            st.markdown(get_combined_markdown(pdf_response))
        else:
            # Images are shown as the structured JSON result.
            result = process_image(file_bytes, file_name)
            st.json(result)