Spaces:
Sleeping
Sleeping
| # app.py | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| import cv2 | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# ---- Config ----
MODEL_ID = "SadraCoding/SDXL-Deepfake-Detector"  # identifier passed to from_pretrained()
THRESHOLD = 0.65  # P(fake) >= THRESHOLD -> "Likely Manipulated"
IMAGE_SIZE = 224  # ViT input size

# MediaPipe is an optional dependency. When it cannot be imported or the
# detector cannot be built, `_mp_face` stays None and face cropping is skipped.
try:
    import mediapipe as mp

    _mp_face = mp.solutions.face_detection.FaceDetection(
        model_selection=0, min_detection_confidence=0.4
    )
except Exception as exc:  # deliberate best-effort: never block app start-up
    import logging

    # Report the reason instead of degrading silently, so a broken install
    # is visible in the logs.
    logging.getLogger(__name__).warning(
        "MediaPipe face detection unavailable (%s); images will be analysed uncropped.",
        exc,
    )
    _mp_face = None
# ---- Face crop ----
def crop_face(pil_img, pad=0.25):
    """Crop the most prominent face using MediaPipe.

    The detection with the widest relative bounding box is treated as the
    most prominent face. Its box is expanded by ``pad`` on every side and
    clamped to the image bounds.

    Args:
        pil_img: Input PIL image.
        pad: Margin added around the detected box, as a fraction of the box
            width (horizontally) and height (vertically).

    Returns:
        The padded face crop as an RGB PIL image, or ``pil_img`` unchanged
        when MediaPipe is unavailable, no face is detected, or the crop would
        be smaller than 20 pixels in either dimension.
    """
    if _mp_face is None:
        return pil_img

    rgb = np.array(pil_img.convert("RGB"))
    height, width = rgb.shape[:2]

    # FaceDetection.process() expects an RGB array, which is exactly what PIL
    # produced above; swapping to BGR here would feed the detector
    # channel-reversed pixels.
    result = _mp_face.process(rgb)
    if not result.detections:
        return pil_img

    widest = max(
        result.detections,
        key=lambda d: d.location_data.relative_bounding_box.width,
    )
    box = widest.location_data.relative_bounding_box
    margin_x = pad * box.width
    margin_y = pad * box.height

    # Relative coordinates -> pixel coordinates, clamped to the image.
    left = int(max(0, (box.xmin - margin_x) * width))
    top = int(max(0, (box.ymin - margin_y) * height))
    right = int(min(width, (box.xmin + box.width + margin_x) * width))
    bottom = int(min(height, (box.ymin + box.height + margin_y) * height))

    # Check the size before slicing so that a tiny, empty or inverted box
    # falls back to the full image instead of producing a bogus crop.
    if right - left < 20 or bottom - top < 20:
        return pil_img
    return Image.fromarray(rgb[top:bottom, left:right])
def face_oval_mask(img_pil, shrink=0.80):
    """Build a float32 mask holding 1.0 inside a centred ellipse and 0.0 outside.

    The ellipse is inscribed in the image rectangle after insetting it by
    ``(1 - shrink) / 2`` of the width and height on each side. The returned
    array has one value per pixel of ``img_pil``.
    """
    width, height = img_pil.size
    inset_x = int((1 - shrink) * width / 2)
    inset_y = int((1 - shrink) * height / 2)

    canvas = Image.new("L", (width, height), 0)
    ImageDraw.Draw(canvas).ellipse(
        (inset_x, inset_y, width - inset_x, height - inset_y), fill=255
    )
    # Scale the 8-bit mask into the [0, 1] range.
    return np.array(canvas, dtype=np.float32) / 255.0
# ---- HF model load ----
# The image processor and the classifier weights are both resolved from
# MODEL_ID. The model is used for inference only: eval mode, gradients off.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()
torch.set_grad_enabled(False)
| # Resolve which index corresponds to "fake" | |
| _FAKE_KEYS = ("artificial", "fake", "deepfake", "manipulated", "spoof", "forged") | |
| def _fake_index_from_config(cfg) -> int | None: | |
| # Prefer id2label | |
| id2label = getattr(cfg, "id2label", None) | |
| if id2label: | |
| try: | |
| normalized = {int(k): str(v).lower() for k, v in id2label.items()} | |
| except Exception: | |
| # sometimes keys already ints | |
| normalized = {int(k): str(v).lower() for k, v in id2label.items()} | |
| for idx, lab in normalized.items(): | |
| if any(k in lab for k in _FAKE_KEYS): | |
| return idx | |
| # Fallback: invert label2id | |
| label2id = getattr(cfg, "label2id", None) | |
| if label2id: | |
| inv = {int(v): str(k).lower() for k, v in label2id.items()} | |
| for idx, lab in inv.items(): | |
| if any(k in lab for k in _FAKE_KEYS): | |
| return idx | |
| return None | |
# Index of the "fake" class according to the loaded model's config (None if
# the config labels do not identify one).
_FAKE_IDX = _fake_index_from_config(model.config)


# ---- Inference ----
def predict_fake_prob(pil_img: Image.Image) -> float:
    """Return P(fake) in [0, 1] for an image.

    The image is face-cropped (when a face is found), resized to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and classified by the loaded model.

    Class selection, in order of preference:
      1. a single-logit head is treated as a sigmoid score;
      2. the index resolved from the model config (``_FAKE_IDX``);
      3. for two classes, index 0 (per the model card: 0 -> 'artificial',
         1 -> 'human');
      4. otherwise the largest class probability.

    Args:
        pil_img: Input PIL image.

    Returns:
        The estimated probability that the image is fake.
    """
    # Face-focus to reduce background bias
    face = crop_face(pil_img).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    inputs = processor(images=face, return_tensors="pt")

    # Autograd mode is thread-local in PyTorch, so a module-level
    # set_grad_enabled(False) does not cover UI handlers that run on worker
    # threads. Disable gradients explicitly around the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits.squeeze(0)  # (1, C) -> (C,)

        if logits.shape[-1] == 1:
            # Binary sigmoid head (unlikely for this model, but safe)
            return torch.sigmoid(logits)[0].item()

        # Softmax multi-class (expected 2 classes)
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()

    num_classes = probs.shape[0]

    # Use explicit mapping if available
    if _FAKE_IDX is not None and 0 <= _FAKE_IDX < num_classes:
        return float(probs[_FAKE_IDX])

    # Known mapping from the model card: 0=artificial (fake), 1=human
    if num_classes == 2:
        return float(probs[0])

    # Last resort
    return float(probs.max())
# ---- UI helpers ----
def result_card(prob_fake: float) -> str:
    """Render the verdict for a fake-probability as an HTML card.

    The card shows the probability as a percentage with a verdict label, a
    progress bar of the same width, and a footer naming the model and the
    decision threshold.

    Args:
        prob_fake: Probability that the image is fake; values at or above
            ``THRESHOLD`` are labelled "Likely Manipulated".

    Returns:
        An HTML fragment as a string.
    """
    is_manipulated = prob_fake >= THRESHOLD
    label = "Likely Manipulated" if is_manipulated else "Likely Authentic"
    color = "#d84a4a" if is_manipulated else "#2e7d32"
    pct = prob_fake * 100.0
    bar_bg = "#e9ecef"
    # round(), not int(): int() truncates float error, so a threshold such as
    # 0.29 (0.29 * 100 == 28.999...) would be displayed as 28%.
    threshold_pct = round(THRESHOLD * 100)
    return f"""
<div style="max-width:860px;margin:0 auto;">
<div style="border:1px solid #e5e7eb;border-radius:14px;padding:18px 20px;background:#fff;
box-shadow: 0 2px 10px rgba(16,24,40,.04);">
<div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:10px;">
<div style="font-size:18px;color:#111827;font-weight:600;">Deepfake likelihood</div>
<div style="font-weight:700;color:{color};">{pct:.1f}% — {label}</div>
</div>
<div style="width:100%;height:10px;background:{bar_bg};border-radius:999px;overflow:hidden;">
<div style="height:100%;width:{pct:.4f}%;background:{color};"></div>
</div>
<div style="font-size:12px;color:#6b7280;margin-top:8px;">
Model: {MODEL_ID} · Threshold: {threshold_pct}%
</div>
</div>
</div>
"""
# ---- Gradio handlers ----
def analyze(pil_img: Image.Image):
    """Gradio callback: turn the current image into a result card.

    With no image the card is rendered for a probability of 0.0; otherwise
    the model's fake-probability for the image is used.
    """
    p_fake = 0.0 if pil_img is None else predict_fake_prob(pil_img)
    return result_card(p_fake)
# ---- UI ----
CUSTOM_CSS = """
.gradio-container {max-width: 980px !important;}
.sleek-card {
border: 1px solid #e5e7eb; border-radius: 16px; background: #fff;
box-shadow: 0 2px 10px rgba(16,24,40,.04); padding: 18px;
}
"""

with gr.Blocks(
    title="Deepfake Detector (SDXL ViT)",
    css=CUSTOM_CSS,
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        "<h2 style='text-align:center;margin-bottom:6px;'>Deepfake Detector (SDXL ViT)</h2>"
    )

    with gr.Row():
        # Left column: image input and the explicit trigger button.
        with gr.Column(scale=6, elem_classes=["sleek-card"]):
            inp = gr.Image(
                type="pil",
                label="Upload / Paste Image",
                sources=["upload", "webcam", "clipboard"],
                height=420,
                show_label=True,
                interactive=True,
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")

        # Right column: rendered result card.
        with gr.Column(scale=6):
            out = gr.HTML()

    # Run the analysis both on button click and whenever the image changes.
    for _register in (btn.click, inp.change):
        _register(analyze, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()