Spaces:
Runtime error
Runtime error
feat: refactor prediction functions and enhance image loading capabilities for improved processing and noise estimation
Browse files- app.py +38 -17
- forensics/__init__.py +2 -2
- forensics/wavelet.py +1 -1
- utils/load.py +51 -0
app.py
CHANGED
|
@@ -13,9 +13,10 @@ from utils.utils import softmax, augment_image
|
|
| 13 |
from forensics.gradient import gradient_processing
|
| 14 |
from forensics.minmax import minmax_process
|
| 15 |
from forensics.ela import ELA
|
| 16 |
-
from forensics.wavelet import
|
| 17 |
from forensics.bitplane import bit_plane_extractor
|
| 18 |
from utils.hf_logger import log_inference_data
|
|
|
|
| 19 |
from agents.ensemble_team import EnsembleMonitorAgent, WeightOptimizationAgent, SystemHealthAgent
|
| 20 |
from agents.smart_agents import ContextualIntelligenceAgent, ForensicAnomalyDetectionAgent
|
| 21 |
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
|
|
@@ -191,9 +192,10 @@ def simple_prediction(img):
|
|
| 191 |
img_byte_arr = io.BytesIO()
|
| 192 |
img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
|
| 193 |
img_byte_arr.seek(0) # Rewind to the beginning of the stream
|
|
|
|
| 194 |
|
| 195 |
result = client.predict(
|
| 196 |
-
input_image=handle_file(
|
| 197 |
api_name="/simple_predict"
|
| 198 |
)
|
| 199 |
return result
|
|
@@ -247,15 +249,34 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
|
|
| 247 |
"Label": f"Error: {str(e)}"
|
| 248 |
}
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
if not isinstance(img, Image.Image):
|
| 254 |
try:
|
| 255 |
img = Image.fromarray(img)
|
| 256 |
except Exception as e:
|
| 257 |
logger.error(f"Error converting input image to PIL: {e}")
|
| 258 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
monitor_agent = EnsembleMonitorAgent()
|
| 261 |
weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
|
|
@@ -406,7 +427,7 @@ def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotat
|
|
| 406 |
yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
|
| 407 |
|
| 408 |
detection_model_eval_playground = gr.Interface(
|
| 409 |
-
fn=
|
| 410 |
inputs=[
|
| 411 |
gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
|
| 412 |
gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
|
|
@@ -426,8 +447,8 @@ detection_model_eval_playground = gr.Interface(
|
|
| 426 |
gr.JSON(label="Raw Model Results", visible=False),
|
| 427 |
gr.Markdown(label="Consensus", value="")
|
| 428 |
],
|
| 429 |
-
title="
|
| 430 |
-
description="
|
| 431 |
api_name="predict",
|
| 432 |
live=True # Enable streaming
|
| 433 |
)
|
|
@@ -436,9 +457,9 @@ community_forensics_preview = gr.Interface(
|
|
| 436 |
fn=lambda: gr.load("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview", src="spaces"),
|
| 437 |
inputs=None,
|
| 438 |
outputs=gr.HTML(), # or gr.Markdown() if it's just text
|
| 439 |
-
title="
|
| 440 |
-
description="
|
| 441 |
-
api_name="
|
| 442 |
)
|
| 443 |
|
| 444 |
leaderboard = gr.Interface(
|
|
@@ -453,13 +474,13 @@ simple_predict_interface = gr.Interface(
|
|
| 453 |
fn=simple_prediction,
|
| 454 |
inputs=gr.Image(type="filepath"),
|
| 455 |
outputs=gr.Text(),
|
| 456 |
-
title="
|
| 457 |
-
description="",
|
| 458 |
api_name="simple_predict"
|
| 459 |
)
|
| 460 |
|
| 461 |
-
|
| 462 |
-
fn=
|
| 463 |
inputs=[gr.Image(type="pil"), gr.Slider(1, 32, value=8, step=1, label="Block Size")],
|
| 464 |
outputs=gr.Image(type="pil"),
|
| 465 |
title="Wavelet-Based Noise Analysis",
|
|
@@ -529,7 +550,7 @@ demo = gr.TabbedInterface(
|
|
| 529 |
[
|
| 530 |
detection_model_eval_playground,
|
| 531 |
simple_predict_interface,
|
| 532 |
-
|
| 533 |
bit_plane_interface,
|
| 534 |
ela_interface,
|
| 535 |
gradient_processing_interface,
|
|
|
|
| 13 |
from forensics.gradient import gradient_processing
|
| 14 |
from forensics.minmax import minmax_process
|
| 15 |
from forensics.ela import ELA
|
| 16 |
+
from forensics.wavelet import noise_estimation
|
| 17 |
from forensics.bitplane import bit_plane_extractor
|
| 18 |
from utils.hf_logger import log_inference_data
|
| 19 |
+
from utils.load import load_image
|
| 20 |
from agents.ensemble_team import EnsembleMonitorAgent, WeightOptimizationAgent, SystemHealthAgent
|
| 21 |
from agents.smart_agents import ContextualIntelligenceAgent, ForensicAnomalyDetectionAgent
|
| 22 |
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
|
|
|
|
| 192 |
img_byte_arr = io.BytesIO()
|
| 193 |
img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
|
| 194 |
img_byte_arr.seek(0) # Rewind to the beginning of the stream
|
| 195 |
+
im = load_image(img)
|
| 196 |
|
| 197 |
result = client.predict(
|
| 198 |
+
input_image=handle_file(im),
|
| 199 |
api_name="/simple_predict"
|
| 200 |
)
|
| 201 |
return result
|
|
|
|
| 249 |
"Label": f"Error: {str(e)}"
|
| 250 |
}
|
| 251 |
|
| 252 |
+
def full_prediction(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
|
| 253 |
+
"""Full prediction run, with a team of ensembles and agents.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
img (url: str, Image.Image, np.ndarray): The input image to classify.
|
| 257 |
+
confidence_threshold (float, optional): The confidence threshold for classification. Defaults to 0.75.
|
| 258 |
+
augment_methods (list, optional): The augmentation methods to use.
|
| 259 |
+
rotate_degrees (int, optional): The degrees to rotate the image.
|
| 260 |
+
noise_level (int, optional): The noise level to use.
|
| 261 |
+
sharpen_strength (int, optional): The sharpen strength to use.
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
dict: A dictionary containing the model details, classification scores, and label.
|
| 265 |
+
"""
|
| 266 |
+
# Ensure img is a PIL Image object
|
| 267 |
+
if img is None:
|
| 268 |
+
raise gr.Error("No image provided. Please upload an image to analyze.")
|
| 269 |
+
|
| 270 |
if not isinstance(img, Image.Image):
|
| 271 |
try:
|
| 272 |
img = Image.fromarray(img)
|
| 273 |
except Exception as e:
|
| 274 |
logger.error(f"Error converting input image to PIL: {e}")
|
| 275 |
+
raise gr.Error("Input image could not be converted to a valid image format. Please try another image.")
|
| 276 |
+
|
| 277 |
+
# Ensure image is in RGB format for consistent processing
|
| 278 |
+
if img.mode != 'RGB':
|
| 279 |
+
img = img.convert('RGB')
|
| 280 |
|
| 281 |
monitor_agent = EnsembleMonitorAgent()
|
| 282 |
weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
|
|
|
|
| 427 |
yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
|
| 428 |
|
| 429 |
detection_model_eval_playground = gr.Interface(
|
| 430 |
+
fn=full_prediction,
|
| 431 |
inputs=[
|
| 432 |
gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
|
| 433 |
gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
|
|
|
|
| 447 |
gr.JSON(label="Raw Model Results", visible=False),
|
| 448 |
gr.Markdown(label="Consensus", value="")
|
| 449 |
],
|
| 450 |
+
title="Multi-Model Ensemble + Agentic Coordinated Deepfake Detection",
|
| 451 |
+
description="The detection of AI-generated images has entered a critical inflection point. While existing solutions struggle with outdated datasets and inflated claims, our approach prioritizes agility, community collaboration, and an offensive approach to deepfake detection.",
|
| 452 |
api_name="predict",
|
| 453 |
live=True # Enable streaming
|
| 454 |
)
|
|
|
|
| 457 |
fn=lambda: gr.load("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview", src="spaces"),
|
| 458 |
inputs=None,
|
| 459 |
outputs=gr.HTML(), # or gr.Markdown() if it's just text
|
| 460 |
+
title="Quick and simple prediction by our strongest model.",
|
| 461 |
+
description="No ensemble, no context, no agents, just a quick and simple prediction by our strongest model.",
|
| 462 |
+
api_name="quick_predict"
|
| 463 |
)
|
| 464 |
|
| 465 |
leaderboard = gr.Interface(
|
|
|
|
| 474 |
fn=simple_prediction,
|
| 475 |
inputs=gr.Image(type="filepath"),
|
| 476 |
outputs=gr.Text(),
|
| 477 |
+
title="Quick and simple prediction by our strongest model.",
|
| 478 |
+
description="No ensemble, no context, no agents, just a quick and simple prediction by our strongest model.",
|
| 479 |
api_name="simple_predict"
|
| 480 |
)
|
| 481 |
|
| 482 |
+
noise_estimation_interface = gr.Interface(
|
| 483 |
+
fn=noise_estimation,
|
| 484 |
inputs=[gr.Image(type="pil"), gr.Slider(1, 32, value=8, step=1, label="Block Size")],
|
| 485 |
outputs=gr.Image(type="pil"),
|
| 486 |
title="Wavelet-Based Noise Analysis",
|
|
|
|
| 550 |
[
|
| 551 |
detection_model_eval_playground,
|
| 552 |
simple_predict_interface,
|
| 553 |
+
noise_estimation_interface,
|
| 554 |
bit_plane_interface,
|
| 555 |
ela_interface,
|
| 556 |
gradient_processing_interface,
|
forensics/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ from .ela import ELA
|
|
| 3 |
# from .exif import exif_full_dump
|
| 4 |
from .gradient import gradient_processing
|
| 5 |
from .minmax import minmax_process
|
| 6 |
-
from .wavelet import
|
| 7 |
|
| 8 |
__all__ = [
|
| 9 |
'bit_plane_extractor',
|
|
@@ -11,5 +11,5 @@ __all__ = [
|
|
| 11 |
# 'exif_full_dump',
|
| 12 |
'gradient_processing',
|
| 13 |
'minmax_process',
|
| 14 |
-
'
|
| 15 |
]
|
|
|
|
| 3 |
# from .exif import exif_full_dump
|
| 4 |
from .gradient import gradient_processing
|
| 5 |
from .minmax import minmax_process
|
| 6 |
+
from .wavelet import noise_estimation
|
| 7 |
|
| 8 |
__all__ = [
|
| 9 |
'bit_plane_extractor',
|
|
|
|
| 11 |
# 'exif_full_dump',
|
| 12 |
'gradient_processing',
|
| 13 |
'minmax_process',
|
| 14 |
+
'noise_estimation'
|
| 15 |
]
|
forensics/wavelet.py
CHANGED
|
@@ -3,7 +3,7 @@ import pywt
|
|
| 3 |
import cv2
|
| 4 |
from PIL import Image
|
| 5 |
|
| 6 |
-
def
|
| 7 |
"""Estimate local noise using wavelet blocking. Returns a PIL image of the noise map."""
|
| 8 |
im = np.array(image.convert('L'))
|
| 9 |
y = np.double(im)
|
|
|
|
| 3 |
import cv2
|
| 4 |
from PIL import Image
|
| 5 |
|
| 6 |
+
def noise_estimation(image: Image.Image, blocksize: int = 8) -> Image.Image:
|
| 7 |
"""Estimate local noise using wavelet blocking. Returns a PIL image of the noise map."""
|
| 8 |
im = np.array(image.convert('L'))
|
| 9 |
y = np.double(im)
|
utils/load.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 4 |
+
from urllib.parse import unquote, urlparse
|
| 5 |
+
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import PIL.ImageOps
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
def load_image(
|
| 11 |
+
image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
|
| 12 |
+
) -> PIL.Image.Image:
|
| 13 |
+
"""
|
| 14 |
+
Loads `image` to a PIL Image.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
image (`str` or `PIL.Image.Image`):
|
| 18 |
+
The image to convert to the PIL Image format.
|
| 19 |
+
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
|
| 20 |
+
A conversion method to apply to the image after loading it. When set to `None` the image will be converted
|
| 21 |
+
"RGB".
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
`PIL.Image.Image`:
|
| 25 |
+
A PIL Image.
|
| 26 |
+
"""
|
| 27 |
+
if isinstance(image, str):
|
| 28 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 29 |
+
image = PIL.Image.open(requests.get(image, stream=True, timeout=600).raw)
|
| 30 |
+
elif os.path.isfile(image):
|
| 31 |
+
image = PIL.Image.open(image)
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(
|
| 34 |
+
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
|
| 35 |
+
)
|
| 36 |
+
elif isinstance(image, PIL.Image.Image):
|
| 37 |
+
image = image
|
| 38 |
+
else:
|
| 39 |
+
raise ValueError(
|
| 40 |
+
"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
image = PIL.ImageOps.exif_transpose(image)
|
| 44 |
+
|
| 45 |
+
if convert_method is not None:
|
| 46 |
+
image = convert_method(image)
|
| 47 |
+
else:
|
| 48 |
+
image = image.convert("RGB")
|
| 49 |
+
|
| 50 |
+
return image
|
| 51 |
+
|