# app.py
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ---- Config ----
MODEL_ID = "SadraCoding/SDXL-Deepfake-Detector"
THRESHOLD = 0.65 # >= -> "Likely Manipulated"
IMAGE_SIZE = 224 # ViT input size
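# MediaPipe-based face cropping is optional: if the import or detector setup
# below fails, _mp_face stays None and crop_face() falls back to the full image.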
try:
    import mediapipe as mp
    _mp_face = mp.solutions.face_detection.FaceDetection(
        model_selection=0, min_detection_confidence=0.4
    )
except Exception:
    _mp_face = None
# ---- Face crop ----
def crop_face(pil_img, pad=0.25):
"""
Crop the most prominent face using MediaPipe. If MP missing or no face found,
return the original image.
"""
if _mp_face is None:
return pil_img
img = np.array(pil_img.convert("RGB"))
h, w = img.shape[:2]
res = _mp_face.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
if not res.detections:
return pil_img
det = max(
res.detections,
key=lambda d: d.location_data.relative_bounding_box.width
)
b = det.location_data.relative_bounding_box
x, y, bw, bh = b.xmin, b.ymin, b.width, b.height
x1 = int(max(0, (x - pad*bw) * w)); y1 = int(max(0, (y - pad*bh) * h))
x2 = int(min(w, (x + bw + pad*bw) * w)); y2 = int(min(h, (y + bh + pad*bh) * h))
face = Image.fromarray(img[y1:y2, x1:x2])
if face.size[0] < 20 or face.size[1] < 20:
return pil_img
return face
def face_oval_mask(img_pil, shrink=0.80):
    w, h = img_pil.size
    mask = Image.new("L", (w, h), 0)
    draw = ImageDraw.Draw(mask)
    dx, dy = int((1 - shrink) * w / 2), int((1 - shrink) * h / 2)
    draw.ellipse((dx, dy, w - dx, h - dy), fill=255)
    return np.array(mask, dtype=np.float32) / 255.0
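# NOTE: face_oval_mask() is not called by the inference path below; it is left
# available if you want to mask or weight the face oval before scoring.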
# ---- HF model load ----
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()
torch.set_grad_enabled(False)
# Resolve which index corresponds to "fake"
_FAKE_KEYS = ("artificial", "fake", "deepfake", "manipulated", "spoof", "forged")
def _fake_index_from_config(cfg) -> int | None:
    # Prefer id2label
    id2label = getattr(cfg, "id2label", None)
    if id2label:
        # Keys may be ints or numeric strings; int() handles both.
        normalized = {int(k): str(v).lower() for k, v in id2label.items()}
        for idx, lab in normalized.items():
            if any(k in lab for k in _FAKE_KEYS):
                return idx
    # Fallback: invert label2id
    label2id = getattr(cfg, "label2id", None)
    if label2id:
        inv = {int(v): str(k).lower() for k, v in label2id.items()}
        for idx, lab in inv.items():
            if any(k in lab for k in _FAKE_KEYS):
                return idx
    return None
_FAKE_IDX = _fake_index_from_config(model.config)
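# Per the model card, id2label is {0: "artificial", 1: "human"}, so _FAKE_IDX
# is expected to resolve to 0.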
# ---- Inference ----
def predict_fake_prob(pil_img: Image.Image) -> float:
"""
Returns P(fake) in [0,1].
Model labels per card: 0 -> 'artificial' (fake), 1 -> 'human' (real).
"""
# Face-focus to reduce background bias
face = crop_face(pil_img)
face = face.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
inputs = processor(images=face, return_tensors="pt")
logits = model(**inputs).logits # (1, C)
if logits.shape[-1] == 1:
# Binary sigmoid head (unlikely for this model, but safe)
return torch.sigmoid(logits.squeeze(0))[0].item()
# Softmax multi-class (expected 2 classes)
probs = torch.softmax(logits.squeeze(0), dim=-1).detach().cpu().numpy()
# Use explicit mapping if available
if _FAKE_IDX is not None and 0 <= _FAKE_IDX < probs.shape[0]:
return float(probs[_FAKE_IDX])
# Known mapping from the model card: 0=artificial (fake), 1=human
if probs.shape[0] == 2:
return float(probs[0]) # class-0 is fake
# Last resort
return float(probs.max())
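# Quick sanity check (sketch; "sample.jpg" is a hypothetical local path):
#   img = Image.open("sample.jpg")
#   print(f"P(fake) = {predict_fake_prob(img):.3f}")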
# ---- UI helpers ----
def result_card(prob_fake: float) -> str:
label = "Likely Manipulated" if prob_fake >= THRESHOLD else "Likely Authentic"
pct = prob_fake * 100.0
color = "#d84a4a" if label.startswith("Likely Manipulated") else "#2e7d32"
bar_bg = "#e9ecef"
return f"""
<div style="max-width:860px;margin:0 auto;">
<div style="border:1px solid #e5e7eb;border-radius:14px;padding:18px 20px;background:#fff;
box-shadow: 0 2px 10px rgba(16,24,40,.04);">
<div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:10px;">
<div style="font-size:18px;color:#111827;font-weight:600;">Deepfake likelihood</div>
<div style="font-weight:700;color:{color};">{pct:.1f}% — {label}</div>
</div>
<div style="width:100%;height:10px;background:{bar_bg};border-radius:999px;overflow:hidden;">
<div style="height:100%;width:{pct:.4f}%;background:{color};"></div>
</div>
<div style="font-size:12px;color:#6b7280;margin-top:8px;">
Model: {MODEL_ID} · Threshold: {int(THRESHOLD*100)}%
</div>
</div>
</div>
"""
# ---- Gradio handlers ----
def analyze(pil_img: Image.Image):
    if pil_img is None:
        return result_card(0.0)
    p_fake = predict_fake_prob(pil_img)
    return result_card(p_fake)
# ---- UI ----
CUSTOM_CSS = """
.gradio-container {max-width: 980px !important;}
.sleek-card {
border: 1px solid #e5e7eb; border-radius: 16px; background: #fff;
box-shadow: 0 2px 10px rgba(16,24,40,.04); padding: 18px;
}
"""
with gr.Blocks(title="Deepfake Detector (SDXL ViT)", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "<h2 style='text-align:center;margin-bottom:6px;'>Deepfake Detector (SDXL ViT)</h2>"
    )
    with gr.Row():
        with gr.Column(scale=6, elem_classes=["sleek-card"]):
            inp = gr.Image(
                type="pil",
                label="Upload / Paste Image",
                sources=["upload", "webcam", "clipboard"],
                height=420,
                show_label=True,
                interactive=True,
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Column(scale=6):
            out = gr.HTML()
    btn.click(analyze, inputs=inp, outputs=out)
    inp.change(analyze, inputs=inp, outputs=out)
if __name__ == "__main__":
    demo.launch()
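# Run locally with `python app.py`; Gradio serves the demo at http://127.0.0.1:7860 by default.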