SPOC_AI_HW / app.py
ooki0626's picture
Update app.py
2ad8303 verified
raw
history blame
5.19 kB
import io
import re
from typing import List, Tuple, Dict, Any

import numpy as np
import torch
import gradio as gr
from PIL import Image
# Face detector
from facenet_pytorch import MTCNN
# HF image classifier
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= Config =========
# You can change the model below to another public model on Hugging Face
# Example: prithivMLmods/Deep-Fake-Detector-v2-Model (binary: Deepfake vs Realism)
# MODEL_ID is the Hub identifier handed to both `from_pretrained` calls below.
MODEL_ID = "prithivMLmods/Deep-Fake-Detector-v2-Model"
# DEVICE is used for the MTCNN detector, the classifier and the input tensors.
DEVICE = "cpu" # Use "cuda" if GPU is available
# MAX_SIDE is the default bound for resize_keep_ratio and is applied in analyze().
MAX_SIDE = 640 # Resize to keep the longest side ≤ 640px for efficiency
# =========================
# ---- Utilities ----
def resize_keep_ratio(img: Image.Image, max_side: int = MAX_SIDE) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_side* pixels.

    The aspect ratio is preserved and the function never upscales: an image
    that already fits is returned as-is (the same object, not a copy).

    Args:
        img: Source PIL image.
        max_side: Upper bound, in pixels, for the longer of width and height.

    Returns:
        ``img`` itself when it already fits, otherwise a new image resampled
        with the LANCZOS filter.
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img

    scale = max_side / float(longest)
    # Truncating with int() can drive the short side of a very elongated
    # image (e.g. 10000x1) down to 0, which is not a usable resize target;
    # always keep at least one pixel per dimension.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return img.resize(new_size, Image.LANCZOS)
# "ai" is only meaningful as a standalone token. A plain substring test also
# fires inside ordinary words ("portrait", "hair", "painting"), so require that
# it is not glued to a preceding letter and not followed by a lowercase letter.
# The case-aware lookahead still accepts camelCase labels such as "AIGenerated".
_AI_TOKEN_RE = re.compile(r"(?<![A-Za-z])[Aa][Ii](?![a-z])")
_FAKE_KEYWORDS = ("fake", "synthetic", "deepfake")
_REAL_KEYWORDS = ("real", "authentic", "genuine")


def canonical_label(label: str) -> str:
    """Map a model-specific label to the canonical ``"fake"`` or ``"real"``.

    Matching is case-insensitive. Fake keywords are checked before real ones,
    so a label containing both kinds resolves to ``"fake"``.

    Args:
        label: Raw label string as reported by the classifier.

    Returns:
        ``"fake"`` or ``"real"`` when a known keyword is found; otherwise the
        original ``label`` unchanged (original casing preserved).
    """
    lowered = label.lower()
    if _AI_TOKEN_RE.search(label) or any(k in lowered for k in _FAKE_KEYWORDS):
        return "fake"
    if any(k in lowered for k in _REAL_KEYWORDS):
        return "real"
    # Default fallback if the label doesn't match any known keyword.
    return label
def rank_probs(id2label: Dict[int, str], probs: List[float]) -> List[Tuple[str, float]]:
    """Pair each probability with its label and order the pairs best-first.

    Args:
        id2label: Mapping from class index to label name.
        probs: Probabilities, where position ``i`` belongs to class ``i``.

    Returns:
        ``(label, probability)`` tuples sorted by descending probability;
        entries with equal probability keep their class-index order.
    """
    ranked = []
    for class_idx, prob in enumerate(probs):
        ranked.append((id2label[class_idx], float(prob)))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
# ---- Load models (once) ----
# These statements run at import time, so the detector, the preprocessor and
# the classifier are created once and reused by classify_pil() and analyze().
# NOTE(review): keep_all=True presumably makes MTCNN report every detected face
# rather than a single one — confirm against the facenet-pytorch docs.
mtcnn = MTCNN(keep_all=True, device=DEVICE)
# Preprocessor and classifier are both resolved from the same MODEL_ID.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
clf = AutoModelForImageClassification.from_pretrained(MODEL_ID).to(DEVICE)
# Class-index -> label-name mapping from the classifier config; classify_pil()
# passes it to rank_probs() to name each probability.
id2label = clf.config.id2label
# ---- Core inference ----
@torch.no_grad()
def classify_pil(img: Image.Image) -> Dict[str, Any]:
    """Classify a single PIL image and summarise the class probabilities.

    Runs under ``torch.no_grad()``; uses the module-level ``processor``,
    ``clf`` and ``id2label``.

    Args:
        img: Image to classify (a face crop or a whole picture).

    Returns:
        A dict with:
          * ``"top"``: up to three ``(label, probability)`` pairs, best first.
          * ``"fake_prob"``: probability of the highest-ranked label that
            ``canonical_label`` maps to ``"fake"``, or ``None`` if there is none.
          * ``"real_prob"``: same for ``"real"``.
    """
    inputs = processor(images=img, return_tensors="pt")
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    logits = clf(**inputs).logits

    # Index the batch dimension instead of calling squeeze(): squeeze() drops
    # every size-1 dimension, so a single-class head would give a bare float
    # and rank_probs() would fail on len().
    probs = torch.softmax(logits[0], dim=-1).tolist()
    ranked = rank_probs(id2label, probs)

    # `ranked` is sorted best-first, so the first label that maps to a
    # category is also the most probable one for that category.
    fake_p = next((p for lbl, p in ranked if canonical_label(lbl) == "fake"), None)
    real_p = next((p for lbl, p in ranked if canonical_label(lbl) == "real"), None)

    return {
        "top": ranked[:3],
        "fake_prob": fake_p,
        "real_prob": real_p,
    }
def analyze(img: Image.Image) -> Dict[str, Any]:
    """Detect faces, classify each one (or the whole image) and aggregate.

    Pipeline:
      1. Convert to RGB and downscale so the longest side is at most MAX_SIDE.
      2. Detect faces with MTCNN and crop each box, clamped to the image.
      3. Classify every crop; if there are no usable crops, classify the
         whole image instead.
      4. Take the median of the per-result fake scores and map it to a label.

    Args:
        img: Uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        ``{"error": ...}`` when ``img`` is ``None``; otherwise a dict with
        ``"label"``, ``"overall_fake_probability"`` (rounded to 3 decimals),
        ``"faces_detected"`` and ``"per_face_results"``.
    """
    if img is None:
        return {"error": "No image uploaded."}

    img = img.convert("RGB")
    img = resize_keep_ratio(img, MAX_SIDE)

    # 1) Face detection
    boxes, _ = mtcnn.detect(img)
    crops = []
    if boxes is not None:
        for x1, y1, x2, y2 in boxes:
            # Clamp each box to the image and drop degenerate (empty) boxes.
            left, top = max(0, int(x1)), max(0, int(y1))
            right, bottom = min(img.width, int(x2)), min(img.height, int(y2))
            if right > left and bottom > top:
                crops.append(img.crop((left, top, right, bottom)))

    if crops:
        # 2) Classify each detected face
        results = [
            {"face": idx, **classify_pil(face)}
            for idx, face in enumerate(crops, 1)
        ]
    else:
        # 3) If no face is detected, classify the whole image
        results = [{"face": None, **classify_pil(img)}]

    # Aggregate: use the median of the fake scores across all results.
    fake_scores = []
    for res in results:
        if res.get("fake_prob") is not None:
            fake_scores.append(res["fake_prob"])
        elif res.get("real_prob") is not None:
            # No label maps to "fake": treat everything that is not "real"
            # as fake. (Checking whether top-1 is "fake" can never succeed
            # here, because fake_prob is None exactly when no label maps to
            # "fake", so that check would always score 0.0.)
            fake_scores.append(1.0 - res["real_prob"])
        # Otherwise the model's labels carry no fake/real information, so
        # this result does not contribute a score.

    # With no usable score at all, stay neutral rather than guess.
    overall_fake = float(np.median(fake_scores)) if fake_scores else 0.5

    if overall_fake >= 0.6:
        label = "Likely AI/Deepfake"
    elif overall_fake >= 0.4:
        label = "Uncertain"
    else:
        label = "Likely Real"

    return {
        "label": label,
        "overall_fake_probability": round(overall_fake, 3),
        "faces_detected": len(crops),
        "per_face_results": results,
    }
# ---- Gradio UI ----
# Intro text rendered as Markdown at the top of the page. The garbled
# (mis-decoded UTF-8) emoji and dashes in the heading and footer are repaired.
_INTRO_MD = """
# 🕵️ FakeSpotter — Image Deepfake Detector (CPU)
Upload an image. If a face is detected, each face is analyzed; otherwise, the whole image is classified.
**No EXIF is used.** Model can be swapped by editing `MODEL_ID` in the code.
> Classroom demo — not a forensic tool.
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)
    # Input image and JSON result in one row.
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload image")
        out = gr.JSON(label="Results")
    # The button runs analyze() on the uploaded image and shows its dict.
    gr.Button("Analyze").click(analyze, inputs=inp, outputs=out)

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()