abdeljalilELmajjodi committed on
Commit 343dd2c · verified
1 Parent(s): 10f15b5

Update app.py

Files changed (1)
  1. app.py +36 -18
app.py CHANGED
@@ -5,31 +5,41 @@ import logging
 from typing import Optional, Union
 import os
 import spaces
-from dotenv import load_dotenv
 
-load_dotenv()
+# Disable torch compilation to avoid dynamo issues
+torch._dynamo.config.disable = True
+torch.backends.cudnn.allow_tf32 = True
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class AtlasOCR:
-    def __init__(self, model_name: str = "atlasia/AtlasOCR", max_tokens: int = 2000):
+    def __init__(self, model_name: str = "atlasia/AtlasOCR-v0", max_tokens: int = 2000):
         """Initialize the AtlasOCR model with proper error handling."""
         try:
             from unsloth import FastVisionModel
 
             logger.info(f"Loading model: {model_name}")
-            self.model, self.processor = FastVisionModel.from_pretrained(
-                model_name,
-                device_map="cuda",
-                load_in_4bit=True,
-                use_gradient_checkpointing="unsloth",
-                token=os.environ["HF_API_KEY"]
-            )
+
+            # Disable compilation for the model
+            with torch._dynamo.config.patch(disable=True):
+                self.model, self.processor = FastVisionModel.from_pretrained(
+                    model_name,
+                    device_map="auto",
+                    load_in_4bit=True,
+                    use_gradient_checkpointing="unsloth",
+                    torch_dtype=torch.float16
+                )
+
+            # Ensure model is not compiled
+            if hasattr(self.model, '_dynamo_compile'):
+                self.model._dynamo_compile = False
+
             self.max_tokens = max_tokens
             self.prompt = ""
-            logger.info("Model loaded successfully")
+            self.device = next(self.model.parameters()).device
+            logger.info(f"Model loaded successfully on device: {self.device}")
 
         except ImportError:
             logger.error("unsloth not found. Please install it: pip install unsloth")
@@ -81,22 +91,30 @@ class AtlasOCR:
 
         inputs = self.prepare_inputs(image)
 
-        # Move inputs to GPU if available
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
+        # Move inputs to the same device as model with explicit device handling
+        device = self.device
+        logger.info(f"Moving inputs to device: {device}")
+
+        # Manually move each tensor to device
+        for key in inputs:
+            if hasattr(inputs[key], 'to'):
+                inputs[key] = inputs[key].to(device)
 
-        # Ensure attention_mask is float32
+        # Ensure attention_mask is float32 and on correct device
        if 'attention_mask' in inputs:
-            inputs['attention_mask'] = inputs['attention_mask'].to(torch.float32)
+            inputs['attention_mask'] = inputs['attention_mask'].to(dtype=torch.float32, device=device)
 
         logger.info(f"Generating text with max_tokens={self.max_tokens}")
-        with torch.no_grad():
+
+        # Disable compilation during generation
+        with torch.no_grad(), torch._dynamo.config.patch(disable=True):
             generated_ids = self.model.generate(
                 **inputs,
                 max_new_tokens=self.max_tokens,
                 use_cache=True,
                 do_sample=False,
-                temperature=0.1
+                temperature=0.1,
+                pad_token_id=self.processor.tokenizer.eos_token_id
             )
 
             generated_ids_trimmed = [
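
For reference, below is a minimal, self-contained sketch of the inference pattern the updated app.py relies on: TorchDynamo disabled globally, the 4-bit model loaded through unsloth's FastVisionModel, inputs kept on the model's own device, and greedy generation with an explicit pad_token_id. It is not the Space's exact code; the chat-template/prompt construction and the image path are assumptions, since prepare_inputs is not part of this diff.

import torch
from PIL import Image
from unsloth import FastVisionModel

# Globally disable TorchDynamo compilation, as the updated app.py does
torch._dynamo.config.disable = True
torch.backends.cudnn.allow_tf32 = True

model, processor = FastVisionModel.from_pretrained(
    "atlasia/AtlasOCR-v0",
    device_map="auto",
    load_in_4bit=True,
)
FastVisionModel.for_inference(model)
device = next(model.parameters()).device

image = Image.open("sample_page.png").convert("RGB")  # hypothetical input image

# Assumed prompt construction (the Space builds its inputs in prepare_inputs, not shown in this diff)
messages = [{"role": "user",
             "content": [{"type": "image"}, {"type": "text", "text": ""}]}]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, text, add_special_tokens=False, return_tensors="pt").to(device)

with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=2000,
        use_cache=True,
        do_sample=False,
        pad_token_id=processor.tokenizer.eos_token_id,
    )

# Strip the prompt tokens before decoding, mirroring generated_ids_trimmed in app.py
trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])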