vithacocf commited on
Commit
c786b95
·
verified ·
1 Parent(s): 25db7d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -27
app.py CHANGED
@@ -49,59 +49,92 @@
49
 
50
  # Code fix
51
  import gradio as gr
52
- from transformers import AutoProcessor, AutoModelForVision2Seq
53
  from PIL import Image, UnidentifiedImageError
 
 
54
  import torch
55
- import os
 
56
 
57
  # Cấu hình thiết bị
58
- device = "cuda" if torch.cuda.is_available() else "cpu"
59
  torch.cuda.empty_cache()
60
 
61
- # Load mô hình
62
  model_id = "prithivMLmods/Camel-Doc-OCR-062825"
 
 
 
 
 
 
 
 
63
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
64
- model = AutoModelForVision2Seq.from_pretrained(
65
  model_id,
66
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
 
67
  trust_remote_code=True
68
- ).to(device)
69
 
70
- # Hàm xử lý ảnh (nếu có kênh alpha)
71
  def convert_png_to_jpg(image):
72
  if image.mode in ["RGBA", "LA"]:
73
  converted = Image.new("RGB", image.size, (255, 255, 255))
74
- converted.paste(image, mask=image.split()[-1]) # Dùng alpha làm mask
75
  return converted
76
  return image.convert("RGB")
77
 
78
- # Hàm chính
79
- def predict(image, prompt=None):
80
  if image is None:
81
  return "=Vui lòng tải lên ảnh hợp lệ."
82
 
83
- if prompt is None or prompt.strip() == "":
84
- return "=Vui lòng nhập prompt để trích xuất dữ liệu."
85
-
86
  try:
87
  image = convert_png_to_jpg(image)
88
-
89
- inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
90
- generated_ids = model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  **inputs,
92
- max_new_tokens=512,
93
- do_sample=False,
94
- use_cache=False,
95
- eos_token_id=processor.tokenizer.eos_token_id,
96
- pad_token_id=processor.tokenizer.pad_token_id
97
- )
98
- result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
99
- return result
 
 
 
 
 
 
 
100
 
101
  except UnidentifiedImageError:
102
- return "=Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
103
  except Exception as e:
104
- return f"=Lỗi khi xử lý ảnh: {str(e)}"
105
 
106
  demo = gr.Interface(
107
  fn=predict,
 
49
 
50
  # Code fix
51
  import gradio as gr
 
52
  from PIL import Image, UnidentifiedImageError
53
+ from transformers import AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer
54
+ from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
55
  import torch
56
+ from threading import Thread
57
+ import time
58
 
59
  # Cấu hình thiết bị
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
  torch.cuda.empty_cache()
62
 
63
+ # Load mô hình Qwen2.5-VL với quantization 4-bit
64
  model_id = "prithivMLmods/Camel-Doc-OCR-062825"
65
+
66
+ bnb_config = BitsAndBytesConfig(
67
+ load_in_4bit=True,
68
+ bnb_4bit_use_double_quant=True,
69
+ bnb_4bit_quant_type="nf4",
70
+ bnb_4bit_compute_dtype=torch.float16
71
+ )
72
+
73
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
74
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
75
  model_id,
76
+ # quantization_config=bnb_config, Quantization
77
+ device_map="auto",
78
  trust_remote_code=True
79
+ ).eval()
80
 
 
81
  def convert_png_to_jpg(image):
82
  if image.mode in ["RGBA", "LA"]:
83
  converted = Image.new("RGB", image.size, (255, 255, 255))
84
+ converted.paste(image, mask=image.split()[-1])
85
  return converted
86
  return image.convert("RGB")
87
 
88
+ # Hàm dự đoán
89
+ def predict(image, prompt=""):
90
  if image is None:
91
  return "=Vui lòng tải lên ảnh hợp lệ."
92
 
 
 
 
93
  try:
94
  image = convert_png_to_jpg(image)
95
+ prompt = prompt.strip() if prompt else "Please describe the document."
96
+
97
+ # Xây dựng prompt theo định dạng Qwen2.5-VL
98
+ messages = [{
99
+ "role": "user",
100
+ "content": [
101
+ {"type": "image", "image": image},
102
+ {"type": "text", "text": prompt}
103
+ ]
104
+ }]
105
+ text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
106
+
107
+ inputs = processor(
108
+ text=[text_prompt],
109
+ images=[image],
110
+ return_tensors="pt",
111
+ padding=True
112
+ ).to(model.device)
113
+
114
+ # Dùng streamer để sinh kết quả mượt hơn
115
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)
116
+ generation_kwargs = {
117
  **inputs,
118
+ "streamer": streamer,
119
+ "max_new_tokens": 512,
120
+ "do_sample": False,
121
+ "use_cache": True
122
+ }
123
+
124
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
125
+ thread.start()
126
+
127
+ buffer = ""
128
+ for new_text in streamer:
129
+ buffer += new_text
130
+ time.sleep(0.01)
131
+
132
+ return buffer
133
 
134
  except UnidentifiedImageError:
135
+ return "Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
136
  except Exception as e:
137
+ return f"Lỗi khi xử lý ảnh: {str(e)}"
138
 
139
  demo = gr.Interface(
140
  fn=predict,