abdeljalilELmajjodi committed on
Commit
11c251e
·
verified ·
1 Parent(s): 186fd60

Create atlasocr_model.py

Browse files
Files changed (1) hide show
  1. atlasocr_model.py +59 -0
atlasocr_model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ from unsloth import FastVisionModel
4
+ import torch
5
+
6
class AtlasOCR:
    """Image-to-text inference wrapper around an Unsloth vision model.

    The constructor loads a model and its processor through
    ``FastVisionModel.from_pretrained``; ``predict`` (or calling the
    instance) runs generation for a single image and returns the decoded
    text.
    """

    def __init__(self, model_name: str = "atlasia/AtlasOCR-v0", max_tokens: int = 2000):
        """Load the model and processor.

        Args:
            model_name: Identifier handed to ``FastVisionModel.from_pretrained``.
            max_tokens: Passed to ``generate`` as ``max_new_tokens`` on every
                prediction.
        """
        self.model, self.processor = FastVisionModel.from_pretrained(
            model_name,
            device_map="auto",
            load_in_4bit=True,
            use_gradient_checkpointing="unsloth",
        )
        self.max_tokens = max_tokens
        # The text part of the user turn is intentionally empty; only the
        # image is supplied to the chat template.
        self.prompt = ""

    def prepare_inputs(self, image: "Image.Image"):
        """Build the processor inputs for a single image.

        Args:
            image: The image to run through the model.

        Returns:
            Whatever ``self.processor(...)`` returns for the image and the
            rendered chat-template text (requested with ``return_tensors="pt"``).
        """
        messages = [
            {
                "role": "user",
                "content": [
                    # Placeholder entry; the actual image is passed to the
                    # processor call below.
                    {"type": "image"},
                    {"type": "text", "text": self.prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        return self.processor(
            image,
            text,
            add_special_tokens=False,
            return_tensors="pt",
        )

    def predict(self, image: "Image.Image") -> str:
        """Generate text for ``image`` and return the decoded string.

        Args:
            image: The image to run through the model.

        Returns:
            The decoded continuation for the (single) input sequence, with
            the prompt tokens stripped and special tokens skipped.
        """
        # Follow the device the loaded model reports instead of assuming a
        # fixed one; fall back to "cuda" when the model exposes no `device`.
        device = getattr(self.model, "device", "cuda")
        inputs = self.prepare_inputs(image).to(device)

        # NOTE(review): the mask is cast to float32 before generation;
        # presumably a workaround for a dtype issue in the model's attention
        # path -- confirm whether it is still required.
        inputs["attention_mask"] = inputs["attention_mask"].to(torch.float32)

        generated_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_tokens, use_cache=True
        )
        # `generate` output starts with the prompt tokens; keep only what
        # follows them for each sequence.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]

    def __call__(self, _: str, image: "Image.Image") -> str:
        """Run ``predict`` on ``image``; the first positional argument is ignored."""
        return self.predict(image)