Spaces:
bugs
app.py CHANGED

@@ -15,6 +15,15 @@ import io
 import base64
 import json
 from datetime import datetime
+import os
+import logging
+
+# Add this near the top of the file, after imports
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 
 @st.cache_resource
 def load_model(model_name):
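
The module-level logger added here complements the Streamlit-facing messages: st.error and st.warning surface problems in the UI, while the logger writes a timestamped record to the Space's container logs (logging.basicConfig is a no-op once the root logger already has handlers, so script reruns will not reconfigure it). A minimal, self-contained sketch of the same pattern; the function name and messages are illustrative, not taken from app.py:

import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def risky_step():
    # Illustrative stand-in for a model-loading call that may raise.
    raise RuntimeError("weights not found")

try:
    risky_step()
except Exception as exc:
    # exc_info=True attaches the full traceback to the log record, the same
    # pattern a later hunk adds to load_model's except branch.
    logger.error("Error details: %s", exc, exc_info=True)
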
@@ -27,9 +36,72 @@ def load_model(model_name):
         dict: Dictionary containing model components
     """
     try:
-        if model_name == "Donut":
+        if model_name == "OmniParser":
+            try:
+                # First try loading from HuggingFace Hub with correct repository structure
+                yolo_model = YOLO("microsoft/OmniParser/icon_detect")  # Updated path
+
+                processor = AutoProcessor.from_pretrained(
+                    "microsoft/OmniParser/icon_caption_florence",  # Updated path
+                    trust_remote_code=True
+                )
+
+                caption_model = AutoModelForCausalLM.from_pretrained(
+                    "microsoft/OmniParser/icon_caption_florence",  # Updated path
+                    trust_remote_code=True,
+                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+                )
+
+                if torch.cuda.is_available():
+                    caption_model = caption_model.to("cuda")
+
+                st.success("Successfully loaded OmniParser models")
+                return {
+                    'yolo': yolo_model,
+                    'processor': processor,
+                    'model': caption_model
+                }
+
+            except Exception as e:
+                st.error(f"Failed to load OmniParser from HuggingFace Hub: {str(e)}")
+
+                # Try loading from local weights if available
+                weights_path = "weights"
+                if os.path.exists(os.path.join(weights_path, "icon_detect/model.safetensors")):
+                    st.info("Attempting to load from local weights...")
+
+                    yolo_model = YOLO(os.path.join(weights_path, "icon_detect/model.safetensors"))
+
+                    processor = AutoProcessor.from_pretrained(
+                        os.path.join(weights_path, "icon_caption_florence"),
+                        trust_remote_code=True,
+                        local_files_only=True
+                    )
+
+                    caption_model = AutoModelForCausalLM.from_pretrained(
+                        os.path.join(weights_path, "icon_caption_florence"),
+                        trust_remote_code=True,
+                        local_files_only=True,
+                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+                    )
+
+                    if torch.cuda.is_available():
+                        caption_model = caption_model.to("cuda")
+
+                    st.success("Successfully loaded OmniParser from local weights")
+                    return {
+                        'yolo': yolo_model,
+                        'processor': processor,
+                        'model': caption_model
+                    }
+                else:
+                    st.error("Could not find local weights and HuggingFace Hub loading failed")
+                    raise ValueError("No valid model weights found for OmniParser")
+
+        elif model_name == "Donut":
             processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
             model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
+
             # Configure Donut specific parameters
             model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
             model.config.pad_token_id = processor.tokenizer.pad_token_id
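
One caveat with the Hub-first branch above: ultralytics' YOLO() constructor expects a local weights file or a known model name, so a bare repo path such as "microsoft/OmniParser/icon_detect" may not resolve and would push every load into the except branch. A hedged workaround is to materialize the repo files into the same weights/ directory the local fallback already expects; the repo id and subfolder patterns below are assumptions inferred from the paths used in this diff, not verified against the actual Hub layout:

from huggingface_hub import snapshot_download

def fetch_omniparser_weights(local_dir: str = "weights") -> str:
    # Downloads only the two subfolders the app looks for, into local_dir,
    # so the existing local-weights branch (and verify_weights_directory,
    # added later in this diff) can pick them up.
    return snapshot_download(
        repo_id="microsoft/OmniParser",
        allow_patterns=["icon_detect/*", "icon_caption_florence/*"],
        local_dir=local_dir,
    )

if __name__ == "__main__":
    print(fetch_omniparser_weights())
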
@@ -42,34 +114,13 @@ def load_model(model_name):
             model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
 
             return {'model': model, 'processor': processor}
-
-        elif model_name == "OmniParser":
-            # Load YOLO model for icon detection
-            yolo_model = YOLO("microsoft/OmniParser-icon-detection")
-
-            # Load Florence-2 processor and model for captioning
-            processor = AutoProcessor.from_pretrained(
-                "microsoft/OmniParser-caption",
-                trust_remote_code=True
-            )
-
-            # Load the captioning model
-            caption_model = AutoModelForCausalLM.from_pretrained(
-                "microsoft/OmniParser-caption",
-                trust_remote_code=True
-            )
-
-            return {
-                'yolo': yolo_model,
-                'processor': processor,
-                'model': caption_model
-            }
 
         else:
             raise ValueError(f"Unknown model name: {model_name}")
 
     except Exception as e:
         st.error(f"Error loading model {model_name}: {str(e)}")
+        logger.error(f"Error details: {str(e)}", exc_info=True)
         return None
 
 @spaces.GPU
@@ -357,16 +408,20 @@ if uploaded_file is not None and selected_model:
             st.info("Loading model...")
 
         add_debug(f"Loading {selected_model} model and processor...")
-
+        models_dict = load_model(selected_model)
 
-        if
+        if models_dict is None:
             with result_col:
                 st.error("Failed to load model. Please try again.")
             add_debug("Model loading failed!", "error")
         else:
             add_debug("Model loaded successfully", "success")
-
-
+            # For device info, we need to check which model we're using
+            if selected_model == "OmniParser":
+                model_device = next(models_dict['model'].parameters()).device
+            else:
+                model_device = next(models_dict['model'].parameters()).device
+            add_debug(f"Model device: {model_device}")
 
         # Update progress
         with result_col:
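
Both branches of the new device check read next(models_dict['model'].parameters()).device, so the if/else on selected_model is redundant: every dict returned by load_model keeps its torch module under the 'model' key. A minimal sketch of the single helper that would cover all three models (the helper name is illustrative):

import torch
import torch.nn as nn

def model_device(models_dict: dict) -> torch.device:
    # 'model' is the torch module for OmniParser, Donut and LayoutLMv3 alike.
    return next(models_dict['model'].parameters()).device

# Tiny self-contained check with a stand-in module:
print(model_device({'model': nn.Linear(2, 2)}))  # cpu, or cuda:0 if the module was moved
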
@@ -379,7 +434,7 @@ if uploaded_file is not None and selected_model:
 
             # Analyze document
             add_debug("Starting document analysis...")
-            results = analyze_document(image, selected_model,
+            results = analyze_document(image, selected_model, models_dict)
             add_debug("Analysis completed", "success")
 
             # Update progress
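
The call site now hands the whole models_dict to analyze_document instead of separate components. analyze_document itself is not part of this diff, so the skeleton below is only a guess at the shape its new signature implies; the branching and unpacking are assumptions, not the actual implementation:

from PIL import Image

def analyze_document(image: Image.Image, model_name: str, models_dict: dict) -> dict:
    # Skeleton only; the real analyze_document in app.py is not shown in this diff.
    if models_dict is None:
        return {"error": "model not loaded"}
    if model_name == "OmniParser":
        yolo, processor, model = models_dict['yolo'], models_dict['processor'], models_dict['model']
        # ... run icon detection, then caption the detected regions ...
    else:
        processor, model = models_dict['processor'], models_dict['model']
        # ... run the processor/model pair over the page image ...
    return {"model": model_name, "results": []}
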
@@ -425,6 +480,37 @@ if uploaded_file is not None and selected_model:
             add_debug("Traceback available in logs", "warning")
 
 # Add improved information about usage and limitations
+def verify_weights_directory():
+    """Verify the weights directory structure and files"""
+    weights_path = "weights"
+    required_files = {
+        os.path.join(weights_path, "icon_detect", "model.safetensors"): "YOLO model weights",
+        os.path.join(weights_path, "icon_detect", "model.yaml"): "YOLO model config",
+        os.path.join(weights_path, "icon_caption_florence", "model.safetensors"): "Florence model weights",
+        os.path.join(weights_path, "icon_caption_florence", "config.json"): "Florence model config",
+        os.path.join(weights_path, "icon_caption_florence", "generation_config.json"): "Florence generation config"
+    }
+
+    missing_files = []
+    for file_path, description in required_files.items():
+        if not os.path.exists(file_path):
+            missing_files.append(f"{description} at {file_path}")
+
+    if missing_files:
+        st.warning("Missing required model files:")
+        for missing in missing_files:
+            st.write(f"- {missing}")
+        return False
+
+    return True
+
+# Add this in your app's initialization
+if st.checkbox("Check Model Files"):
+    if verify_weights_directory():
+        st.success("All required model files are present")
+    else:
+        st.error("Some model files are missing. Please ensure all required files are in the weights directory")
+
 st.markdown("""
 ---
 ### Usage Notes:
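
For reference, the directory layout the new check expects: a standalone sketch mirroring the required_files mapping above, runnable outside Streamlit to see which files are missing:

import os

EXPECTED_WEIGHT_FILES = [
    os.path.join("weights", "icon_detect", "model.safetensors"),
    os.path.join("weights", "icon_detect", "model.yaml"),
    os.path.join("weights", "icon_caption_florence", "model.safetensors"),
    os.path.join("weights", "icon_caption_florence", "config.json"),
    os.path.join("weights", "icon_caption_florence", "generation_config.json"),
]

if __name__ == "__main__":
    for path in EXPECTED_WEIGHT_FILES:
        status = "ok     " if os.path.exists(path) else "MISSING"
        print(f"{status} {path}")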