Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -5,31 +5,41 @@ import logging
 from typing import Optional, Union
 import os
 import spaces
-from dotenv import load_dotenv
 
-[old line 10 not shown in this diff view]
+
+# Disable torch compilation to avoid dynamo issues
+torch._dynamo.config.disable = True
+torch.backends.cudnn.allow_tf32 = True
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class AtlasOCR:
-    def __init__(self, model_name: str = "atlasia/AtlasOCR", max_tokens: int = 2000):
+    def __init__(self, model_name: str = "atlasia/AtlasOCR-v0", max_tokens: int = 2000):
         """Initialize the AtlasOCR model with proper error handling."""
         try:
             from unsloth import FastVisionModel
 
             logger.info(f"Loading model: {model_name}")
-            [old lines 23-29, the previous model-loading block, not shown in this diff view]
+
+            # Disable compilation for the model
+            with torch._dynamo.config.patch(disable=True):
+                self.model, self.processor = FastVisionModel.from_pretrained(
+                    model_name,
+                    device_map="auto",
+                    load_in_4bit=True,
+                    use_gradient_checkpointing="unsloth",
+                    torch_dtype=torch.float16
+                )
+
+            # Ensure model is not compiled
+            if hasattr(self.model, '_dynamo_compile'):
+                self.model._dynamo_compile = False
+
             self.max_tokens = max_tokens
             self.prompt = ""
-            [old line 32 not shown in this diff view]
+            self.device = next(self.model.parameters()).device
+            logger.info(f"Model loaded successfully on device: {self.device}")
 
         except ImportError:
             logger.error("unsloth not found. Please install it: pip install unsloth")
@@ -81,22 +91,30 @@ class AtlasOCR:
 
         inputs = self.prepare_inputs(image)
 
-        # Move inputs to […]
-        device = […]
-        [old line 86 not shown in this diff view]
+        # Move inputs to the same device as model with explicit device handling
+        device = self.device
+        logger.info(f"Moving inputs to device: {device}")
+
+        # Manually move each tensor to device
+        for key in inputs:
+            if hasattr(inputs[key], 'to'):
+                inputs[key] = inputs[key].to(device)
 
-        # Ensure attention_mask is float32
+        # Ensure attention_mask is float32 and on correct device
         if 'attention_mask' in inputs:
-            inputs['attention_mask'] = inputs['attention_mask'].to(torch.float32)
+            inputs['attention_mask'] = inputs['attention_mask'].to(dtype=torch.float32, device=device)
 
         logger.info(f"Generating text with max_tokens={self.max_tokens}")
-        [old line 93 not shown in this diff view]
+
+        # Disable compilation during generation
+        with torch.no_grad(), torch._dynamo.config.patch(disable=True):
             generated_ids = self.model.generate(
                 **inputs,
                 max_new_tokens=self.max_tokens,
                 use_cache=True,
                 do_sample=False,
-                temperature=0.1
+                temperature=0.1,
+                pad_token_id=self.processor.tokenizer.eos_token_id
             )
 
         generated_ids_trimmed = [
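
For reference, the runtime-side part of this fix follows a general pattern: keep TorchDynamo disabled, move every processor output onto the model's device, cast the attention mask the way the app expects, and decode greedily with an explicit pad token. A minimal sketch of that pattern, assuming generic model, processor, and inputs objects rather than this Space's actual classes (the generate_greedy helper is hypothetical, not part of app.py):

import torch

# Same process-wide switch the commit flips to keep TorchDynamo from tracing anything.
torch._dynamo.config.disable = True

def generate_greedy(model, processor, inputs, max_new_tokens=2000):
    """Hypothetical helper mirroring the device/dtype handling added in this commit."""
    device = next(model.parameters()).device

    # Move every tensor produced by the processor onto the model's device.
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    # The commit casts the attention mask to float32 on the same device; mirrored here.
    if "attention_mask" in inputs:
        inputs["attention_mask"] = inputs["attention_mask"].to(dtype=torch.float32, device=device)

    with torch.no_grad():
        return model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=False,  # greedy decoding
            pad_token_id=processor.tokenizer.eos_token_id,
        )

One side note on the diff itself: with do_sample=False the temperature=0.1 argument has no effect on greedy search, and recent transformers releases warn about unused sampling flags, so it could be dropped in a follow-up commit.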
|