# NOTE: Hugging Face Spaces page chrome ("Spaces: Paused") was captured in
# this paste; commented out so the file remains valid Python source.
| from __future__ import annotations | |
| import os, io, re, json, time, mimetypes, tempfile, string | |
| from typing import List, Union, Tuple, Any, Iterable | |
| from PIL import Image | |
| import pandas as pd | |
| import gradio as gr | |
| import google.generativeai as genai | |
| import requests | |
| import pdfplumber | |
# ================== CONFIG ==================
# SECURITY FIX: a real Google API key was previously hard-coded here and
# committed to source. Never ship credentials in code — the key must now be
# provided via the GOOGLE_API_KEY environment variable (and the old key
# should be revoked in the Google Cloud console).
DEFAULT_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
# Maps the UI display name to the Gemini model identifier.
INTERNAL_MODEL_MAP = {
    "Gemini 2.5 Flash": "gemini-2.5-flash",
    "Gemini 2.5 Pro": "gemini-2.5-pro",
}
EXTERNAL_MODEL_NAME = "prithivMLmods/Camel-Doc-OCR-062825 (External)"
# Pillow >= 9.1 moved resampling filters into Image.Resampling; fall back
# to the legacy constant on older Pillow versions.
try:
    RESAMPLE = Image.Resampling.LANCZOS
except AttributeError:
    RESAMPLE = Image.LANCZOS
# Prompt sent to Gemini: instructs extraction of normalized air-freight rate
# data as a strict JSON array with one object per airline. Downstream parsing
# assumes this exact schema — edit the field names with care.
PROMPT_FREIGHT_JSON = """
You are an expert in air freight rate extraction and normalization.
The document contains rate information for multiple airlines.
Please analyze all content (tables, headers, notes) and return **a list of JSON objects**, each representing a separate airline.
Each airline should follow this schema:
{
"shipping_line": "...",
"shipping_line_code": "...",
"shipping_line_reason": "Why this carrier is chosen?",
"fee_type": "Air Freight",
"valid_from": "...",
"valid_to": "...",
"charges": [ ... ], # List of charge objects (see below)
"local_charges": [ ... ] # Optional local charges if available
}
Each `charges` object must follow this schema:
{
"frequency": "...",
"package_type": "...", # e.g. Carton, Pallet, Skid
"aircraft_type": "...",
"direction": "Export / Import / null",
"origin": "...",
"destination": "...",
"charge_name": "...",
"charge_code": "GCR / PER / DGR / etc.",
"charge_code_reason": "...",
"cargo_type": "...",
"currency": "...",
"transit": "...",
"transit_time": "...",
"weight_breaks": {
"M": ...,
"N": ...,
"+45kg": ...,
"+100kg": ...,
"+300kg": ...,
"+500kg": ...,
"+1000kg": ...,
"other": { key: value },
"weight_breaks_reason": "Why chosen weight_breaks?"
},
"remark": "..."
}
Each `local_charges` object:
{
"charge_name": "...",
"charge_code": "...",
"unit": "...",
"amount": ...,
"remark": "..."
}
---
### ✈️ Airline Separation Logic:
- If multiple airlines are detected in the document, separate each section and return a distinct JSON object per airline.
- Infer `shipping_line` and `shipping_line_code` from the header (e.g. "AIR CHINA CARGO (CA)" → name = "AIR CHINA CARGO", code = "CA").
- Each JSON object must include only data relevant to that airline.
---
### 💡 Date rules:
- valid_from:
- `DD/MM/YYYY` if exact
- `01/MM/YYYY` if only month/year
- `01/01/YYYY` if only year
- `UFN` if missing
- valid_to:
- exact `DD/MM/YYYY` if present
- else `UFN`
---
### 📦 Package and Surcharge Logic:
Apply these when the remark or note indicates such rules:
1. **Default case**: If no package mentioned → `"Carton"` is the default.
2. **“Carton = Pallet”**: Duplicate rates with `package_type="Pallet"`.
3. **“SKID shipment: add 10 cents (GEN & PER)”**: Add new charges with `+0.10 USD/kg` for GEN/PER, with `package_type="Pallet"` or `"Skid"`.
4. **EU vs Non-EU surcharges**: If different pallet surcharges by region → split charges accordingly.
5. **“All-in” or “inclusive of MY and SC”**: Record `FSC` and `WSC` as `local_charges` with `"NIL"` amount.
6. **Flight number is not a charge code**. Always use standard cargo code (GCR, PER, etc.).
---
### ⚙️ Other Business Rules:
- RQ / Request → "RQST"
- Combine same-rate destinations using `/`
- Always use **IATA code** for origin/destination
- Direction = Export if origin is in Vietnam (SGN, HAN, DAD), else Import
- Frequency:
- D[1-7] = day of week
- "Daily" = D1234567
- Remarks: Replace `,` with `;`
- Add meaningful `"shipping_line_reason"` and `"charge_code_reason"`
---
### 🚨 STRICT OUTPUT:
- Return **a JSON array**, where each item is a full airline object
- Do NOT return markdown or explanation
- All fields must be valid
- All numbers = numeric types
- Use `null` if value missing
"""
| # ================== HELPERS ================== | |
| import fitz # PyMuPDF | |
def pdf_to_images(pdf_bytes: bytes) -> list[Image.Image]:
    """Rasterize every page of a PDF (raw bytes) into RGB PIL images.

    Pages are rendered at 200 DPI — a reasonable balance between OCR
    accuracy and the size of the images later uploaded to the model.

    Args:
        pdf_bytes: Complete PDF file content.

    Returns:
        One ``PIL.Image`` per page, in document order.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        pages: list[Image.Image] = []
        for page in doc:
            # get_pixmap() defaults to RGB without alpha, matching the
            # "RGB" mode passed to Image.frombytes below.
            pix = page.get_pixmap(dpi=200)
            pages.append(Image.frombytes("RGB", (pix.width, pix.height), pix.samples))
        return pages
    finally:
        # Fix: the document handle was previously never closed (leak).
        doc.close()
def ensure_rgb(im: Image.Image) -> Image.Image:
    """Return *im* unchanged when it is already RGB, otherwise an RGB copy."""
    if im.mode == "RGB":
        return im
    return im.convert("RGB")
| def _read_file_bytes(upload: Union[str, os.PathLike, dict, object] | None) -> bytes: | |
| if upload is None: | |
| raise ValueError("No file uploaded.") | |
| if isinstance(upload, (str, os.PathLike)): | |
| with open(upload, "rb") as f: | |
| return f.read() | |
| if isinstance(upload, dict) and "path" in upload: | |
| with open(upload["path"], "rb") as f: | |
| return f.read() | |
| if hasattr(upload, "read"): | |
| return upload.read() | |
| raise TypeError(f"Unsupported file object: {type(upload)}") | |
| def _guess_name_and_mime(file, file_bytes: bytes) -> Tuple[str, str]: | |
| if isinstance(file, (str, os.PathLike)): | |
| filename = os.path.basename(str(file)) | |
| elif isinstance(file, dict) and "name" in file: | |
| filename = os.path.basename(file["name"]) | |
| elif isinstance(file, dict) and "path" in file: | |
| filename = os.path.basename(file["path"]) | |
| else: | |
| filename = "upload.bin" | |
| mime, _ = mimetypes.guess_type(filename) | |
| if not mime: | |
| if len(file_bytes) >= 4 and file_bytes[:4] == b"%PDF": | |
| mime = "application/pdf" | |
| if not filename.lower().endswith(".pdf"): | |
| filename += ".pdf" | |
| else: | |
| mime = "image/png" | |
| return filename, mime | |
| # ================== PDF CHECK STEP ================== | |
def check_pdf_structure(file_bytes: bytes) -> str:
    """Quickly probe whether a PDF looks like a multi-page, multi-column table.

    Returns the Vietnamese flags consumed downstream: "có" (yes) /
    "không" (no). Any parsing failure is treated as "no".
    """
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            # One or two pages: simple document, OCR it directly.
            if len(pdf.pages) <= 2:
                return "không"
            # Any detected table within the first three pages counts.
            for page in pdf.pages[:3]:
                if page.find_tables():
                    return "có"
            # Heuristic fallback: many digits spread over many lines.
            sample = "\n".join((p.extract_text() or "") for p in pdf.pages[:2])
            digit_count = sum(1 for ch in sample if ch.isdigit())
            if digit_count > 100 and len(sample.splitlines()) > 20:
                return "có"
            return "không"
    except Exception as e:
        print("PDF check error:", e)
        return "không"
| # ================== OCR CORE (Gemini) ================== | |
def run_process_internal_base_v2(file_bytes, filename, mime, question, model_choice,
                                 temperature, top_p, batch_size=3):
    """OCR a PDF/image with a Gemini model, sending pages in batches.

    Args:
        file_bytes: Raw upload content; PDFs are rasterized page by page.
        filename, mime: Unused here; kept for interface parity with the
            external route so the router can call both uniformly.
        question: Optional user prompt; blank falls back to PROMPT_FREIGHT_JSON.
        model_choice: UI display name, mapped through INTERNAL_MODEL_MAP.
        temperature, top_p: Sampling parameters forwarded to the model.
        batch_size: Number of page images per generate_content call.

    Returns:
        ``(combined_text, None)`` — page-batch responses joined with blank
        lines — or ``("ERROR: ...", None)`` when no API key is available.
    """
    api_key = os.environ.get("GOOGLE_API_KEY", DEFAULT_API_KEY)
    if not api_key:
        return "ERROR: Missing GOOGLE_API_KEY.", None
    genai.configure(api_key=api_key)

    model_name = INTERNAL_MODEL_MAP.get(model_choice, "gemini-2.5-flash")
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config={"temperature": float(temperature), "top_p": float(top_p)},
    )

    if file_bytes[:4] == b"%PDF":
        pages = pdf_to_images(file_bytes)
    else:
        pages = [Image.open(io.BytesIO(file_bytes))]

    user_prompt = (question or "").strip() or PROMPT_FREIGHT_JSON
    all_text_results = []

    def _safe_text(resp):
        # resp.text raises when the response was blocked or is empty.
        try:
            return resp.text
        except Exception:  # fix: was a bare except (swallowed KeyboardInterrupt too)
            return ""

    for start in range(0, len(pages), batch_size):
        batch = pages[start:start + batch_size]
        uploaded, tmp_paths = [], []
        try:
            for im in batch:
                # delete=False so the path survives the `with`; removed in
                # the finally below (fix: these PNGs were previously leaked).
                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                    tmp_paths.append(tmp.name)
                    im.save(tmp.name)
                up = genai.upload_file(path=tmp_paths[-1], mime_type="image/png")
                up = genai.get_file(up.name)  # refresh file state after upload
                uploaded.append(up)
            resp = model.generate_content([user_prompt] + uploaded)
            all_text_results.append(_safe_text(resp))
        finally:
            # Fix: cleanup now runs even if upload/generation raises.
            for path in tmp_paths:
                try:
                    os.remove(path)
                except OSError:
                    pass
            for up in uploaded:
                try:
                    genai.delete_file(up.name)  # best-effort remote cleanup
                except Exception:
                    pass
    return "\n\n".join(all_text_results), None
| # ================== EXTERNAL API (nếu có) ================== | |
def run_process_external(file_bytes, filename, mime, question, api_url, temperature, top_p):
    """POST the upload to a user-supplied OCR endpoint and return its raw text.

    Returns ``(text, None)`` on success, or an ``"ERROR: ..."`` string when
    the endpoint is missing or answers with an HTTP error status.
    """
    if not api_url:
        return "ERROR: Missing external API endpoint.", None

    payload = {
        "prompt": question or "",
        "temperature": str(temperature),
        "top_p": str(top_p),
    }
    attachment = {"file": (filename, file_bytes, mime)}
    response = requests.post(api_url, files=attachment, data=payload, timeout=60)

    if response.status_code >= 400:
        return f"ERROR: External API HTTP {response.status_code}: {response.text[:200]}", None
    return response.text, None
| # ================== MAIN ROUTER (đã thêm STEP CHECK) ================== | |
def run_process(file, question, model_choice, temperature, top_p, external_api_url):
    """Route an upload to the right OCR backend.

    Flow (with a PDF/table pre-check step):
      - If the PDF looks multi-page / table-heavy -> extract tables first
        with pdfplumber, then feed the CSV text to Gemini
      - Otherwise -> send the file straight to Gemini OCR

    Returns a ``(text, None)`` tuple; any exception is converted into an
    ``("ERROR: ...", None)`` result instead of propagating.
    """
    try:
        if file is None:
            return "ERROR: No file uploaded.", None
        file_bytes = _read_file_bytes(file)
        filename, mime = _guess_name_and_mime(file, file_bytes)
        # STEP 1: check PDF structure (mime guess OR magic-byte sniff)
        if mime == "application/pdf" or file_bytes[:4] == b"%PDF":
            check_result = check_pdf_structure(file_bytes)
            print(f"[PDF Check] (unknown): {check_result}")
            # NOTE(review): `and 1==2` deliberately disables this branch
            # while the multi-airline prompt is being tested — remove the
            # `1==2` to re-enable the pdfplumber pre-extraction path.
            if check_result == "có" and 1==2:
                try:
                    print("➡️ PDF có nhiều cột/nhiều trang → dùng pdfplumber extract trước rồi Gemini.")
                    all_dfs = []
                    saved_header = None
                    with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
                        for page_idx, page in enumerate(pdf.pages, start=1):
                            print(f"📄 Đang xử lý trang {page_idx}...")
                            table = page.extract_table({
                                "vertical_strategy": "lines",
                                "horizontal_strategy": "text",
                                "snap_tolerance": 3,
                                "intersection_tolerance": 5,
                            })
                            if not table or len(table) < 2:
                                print(f"⚠️ Trang {page_idx}: Không phát hiện bảng hợp lệ.")
                                continue
                            header = table[0]
                            rows = table[1:]
                            # Remember the first header encountered
                            if saved_header is None:
                                saved_header = header
                                print(f"✅ Trang {page_idx}: Lưu header đầu tiên: {saved_header}")
                            # Later pages without a clear header reuse the
                            # previous one (and keep ALL rows, since row 0 is data)
                            if len(header) < len(saved_header) or "REGION" not in header[0]:
                                print(f"↩️ Trang {page_idx}: Không có header rõ ràng, dùng lại header trước.")
                                header = saved_header
                                rows = table
                            else:
                                saved_header = header  # update with the latest valid header
                            if len(rows) == 0:
                                print(f"⚠️ Trang {page_idx}: Không có dữ liệu dưới header.")
                                continue
                            try:
                                df = pd.DataFrame(rows, columns=header)
                                all_dfs.append(df)
                                print(f"✅ Trang {page_idx}: {len(df)} dòng được thêm.")
                            except Exception as e:
                                print(f"❌ Lỗi tạo DataFrame ở trang {page_idx}: {e}")
                    if all_dfs:
                        final_df = pd.concat(all_dfs, ignore_index=True).dropna(how="all").reset_index(drop=True)
                        print(f"✅ Tổng cộng {len(final_df)} dòng được trích xuất từ PDF.")
                        # Export to temp files (Excel + JSON) — currently disabled
                        base_name = os.path.splitext(filename)[0]
                        tmp_dir = tempfile.gettempdir()
                        # json_path = os.path.join(tmp_dir, f"{base_name}.json")
                        # excel_path = os.path.join(tmp_dir, f"{base_name}.xlsx")
                        # final_df.to_json(json_path, orient="records", force_ascii=False, indent=2)
                        # final_df.to_excel(excel_path, index=False)
                        # print(f"✅ Xuất JSON: {json_path}")
                        # print(f"✅ Xuất Excel: {excel_path}")
                        # Convert the table to CSV text for Gemini to read next
                        table_text = final_df.to_csv(index=False)
                        print(f"✅ Đang Gen text từ file CSV")
                        question = (
                            f"{PROMPT_FREIGHT_JSON}\n"
                            "Below is the table text extracted from the PDF (CSV format):\n"
                            f"{table_text}\n\n"
                            "Please convert this into valid JSON as per the schema."
                        )
                    else:
                        print("⚠️ Không có bảng hợp lệ để extract bằng pdfplumber.")
                except Exception as e:
                    print("❌ pdfplumber extract failed:", e)
        # STEP 2: route to the chosen model backend
        if model_choice == EXTERNAL_MODEL_NAME:
            return run_process_external(
                file_bytes=file_bytes, filename=filename, mime=mime,
                question=question, api_url=external_api_url,
                temperature=temperature, top_p=top_p
            )
        return run_process_internal_base_v2(
            file_bytes=file_bytes, filename=filename, mime=mime,
            question=question, model_choice=model_choice,
            temperature=temperature, top_p=top_p
        )
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {str(e)}", None
| # ================== UI ================== | |
def main():
    """Build the Gradio UI and wire the Process button to `run_process`."""
    with gr.Blocks(title="OCR Multi-Agent System") as demo:
        upload = gr.File(label="Upload PDF/Image")
        prompt_box = gr.Textbox(label="Prompt", lines=2)
        model_dd = gr.Dropdown(
            choices=[*INTERNAL_MODEL_MAP.keys(), EXTERNAL_MODEL_NAME],
            value="Gemini 2.5 Flash",
            label="Model",
        )
        temp_slider = gr.Slider(0.0, 2.0, value=0.2, step=0.05)
        top_p_slider = gr.Slider(0.0, 1.0, value=0.95, step=0.01)
        api_url_box = gr.Textbox(label="External API URL", visible=False)
        result_code = gr.Code(label="Output", language="json")
        process_btn = gr.Button("🚀 Process")
        # Input order must match run_process's positional signature.
        process_btn.click(
            run_process,
            inputs=[upload, prompt_box, model_dd, temp_slider, top_p_slider, api_url_box],
            outputs=[result_code, gr.State()],
        )
    return demo
# Build the app at import time so hosting platforms (e.g. Hugging Face
# Spaces) can pick up `demo`; launch the server only when run as a script.
demo = main()
if __name__ == "__main__":
    demo.launch()