File size: 1,877 Bytes
11c251e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from PIL import Image
from unsloth import FastVisionModel
import torch

class AtlasOCR:
    """OCR wrapper around an unsloth vision-language model.

    The model is loaded with ``FastVisionModel.from_pretrained`` (4-bit,
    ``device_map="auto"``). Each call transcribes a single image by sending
    it, together with ``self.prompt``, through the processor's chat
    template and decoding the generated continuation.

    Args:
        model_name: Model id or path forwarded to
            ``FastVisionModel.from_pretrained``.
        max_tokens: Maximum number of newly generated tokens per prediction
            (passed to ``generate`` as ``max_new_tokens``).
        prompt: Text placed after the image in the user message. Defaults to
            the empty string, matching the previously hard-coded value.
    """

    def __init__(
        self,
        model_name: str = "atlasia/AtlasOCR-v0",
        max_tokens: int = 2000,
        prompt: str = "",
    ):
        self.model, self.processor = FastVisionModel.from_pretrained(
            model_name,
            device_map="auto",
            load_in_4bit=True,
            use_gradient_checkpointing="unsloth",
        )
        # This class only ever generates, so put the model into unsloth's
        # inference mode (as recommended by the unsloth docs).
        FastVisionModel.for_inference(self.model)
        self.max_tokens = max_tokens
        self.prompt = prompt

    def prepare_inputs(self, image: Image.Image):
        """Build processor inputs for one image followed by ``self.prompt``.

        A single user message (an image placeholder plus the prompt text) is
        rendered to a string by the processor's chat template, with the
        generation prompt appended. The processor then combines that string
        with the image.

        Args:
            image: The PIL image to transcribe.

        Returns:
            The processor output with PyTorch tensors (``return_tensors="pt"``).
            It is not moved to any device here; ``predict`` does that.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # The rendered template presumably already carries the model's special
        # tokens, hence add_special_tokens=False to avoid adding them twice.
        return self.processor(
            images=image,
            text=text,
            add_special_tokens=False,
            return_tensors="pt",
        )

    def predict(self, image: Image.Image) -> str:
        """Transcribe ``image`` and return the generated text.

        Args:
            image: The PIL image to transcribe.

        Returns:
            The decoded continuation only: prompt tokens are stripped and
            special tokens are skipped during decoding.
        """
        inputs = self.prepare_inputs(image)
        # Send the inputs to wherever the model was placed instead of
        # hard-coding "cuda" (device_map="auto" decides the placement).
        inputs = inputs.to(self.model.device)

        # NOTE(review): the mask is forced to float32 without an explanation in
        # the original code. This is presumably a workaround for a dtype
        # mismatch inside the model's attention. Kept as-is; confirm before
        # removing.
        inputs["attention_mask"] = inputs["attention_mask"].to(torch.float32)

        generated_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_tokens, use_cache=True
        )
        # Each generated sequence starts with its prompt; keep only new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        # Exactly one image was submitted, so the batch has a single entry.
        return output_text[0]

    def __call__(self, _: str, image: Image.Image) -> str:
        """OCR ``image``; the first positional argument is ignored.

        The unused first argument is presumably kept so instances fit a
        two-argument calling convention used by callers elsewhere. Its value
        has no effect; ``self.prompt`` is what gets sent to the model.
        """
        return self.predict(image)