|
|
import gradio as gr |
|
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
# Pick the compute device: the GPU when PyTorch reports CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Ask PyTorch's CUDA caching allocator to release unused cached blocks.
# No model has been loaded yet at this point, so this presumably frees
# little or nothing.
torch.cuda.empty_cache()
|
|
|
|
|
# Hugging Face Hub repository of the OCR checkpoint.
model_id = "prithivMLmods/Camel-Doc-OCR-062825"

# Half precision on the GPU to save memory; full precision on the CPU.
_model_dtype = torch.float16 if device == "cuda" else torch.float32

# trust_remote_code=True allows transformers to run custom code shipped in the
# model repository — only enable it for repositories you trust.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=_model_dtype,
    trust_remote_code=True,
)
model = model.to(device)
|
|
|
|
|
# Instruction sent to the model when the user leaves the hint box empty.
_DEFAULT_OCR_PROMPT = "Extract all text from this document image."


def predict(image, prompt=None):
    """Run the OCR model on a document image and return the generated text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; it is
            converted to RGB before preprocessing.
        prompt: Optional instruction/hint for the model. ``None``, an empty
            string or whitespace falls back to ``_DEFAULT_OCR_PROMPT``.

    Returns:
        The decoded model output (first sequence of the batch), without the
        echoed prompt tokens.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Show a readable message in the Gradio UI instead of an AttributeError.
        raise gr.Error("Vui lòng tải ảnh tài liệu lên.")
    image = image.convert("RGB")

    instruction = (prompt or "").strip() or _DEFAULT_OCR_PROMPT

    # Build the text through the processor's chat template so it contains the
    # image placeholder and the user's instruction. The previous text="" gave
    # the image no slot and silently dropped the prompt.
    # NOTE(review): assumes a chat-template processor (e.g. Qwen2-VL style);
    # confirm against the checkpoint's processor config.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[chat_text], images=[image], return_tensors="pt"
    ).to(device)

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
        # NOTE(review): disabling the KV cache makes every decoding step
        # re-process the whole sequence (slow); kept from the original —
        # remove unless it works around a specific issue.
        use_cache=False,
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

    # Decoder-only models return prompt + continuation; keep only the new
    # tokens so the chat template / instruction is not echoed to the user.
    # Encoder-decoder models return only the continuation, so skip trimming.
    if not getattr(model.config, "is_encoder_decoder", False):
        prompt_len = inputs["input_ids"].shape[1]
        generated_ids = generated_ids[:, prompt_len:]

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
|
# Mock image-processing handler — for UI testing only
|
|
def mock_predict(image, prompt=None):
    """Return a canned OCR string so the UI can be exercised without the model.

    ``image`` is accepted to match ``predict``'s signature but is never read.
    A falsy ``prompt`` (``None`` or ``""``) is reported as ``N/A``.
    """
    shown_prompt = prompt if prompt else "N/A"
    return "Fake OCR result for image. Prompt: {}".format(shown_prompt)
|
|
|
|
|
# Input widgets: a document image (delivered to the handler as a PIL image)
# and an optional free-text hint.
_ocr_inputs = [
    gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
    gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn"),
]

# NOTE: the UI is wired to ``mock_predict``, so the model loaded at import time
# is not used here; switch ``fn`` to ``predict`` to run real OCR.
demo = gr.Interface(
    fn=mock_predict,
    inputs=_ocr_inputs,
    outputs="text",
    title="Camel-Doc OCR - Trích xuất văn bản từ ảnh",
)
|
|
|
|
|
|
|
|
|