# app.py import numpy as np import gradio as gr from PIL import Image, ImageDraw import cv2 import torch from transformers import AutoImageProcessor, AutoModelForImageClassification # ---- Config ---- MODEL_ID = "SadraCoding/SDXL-Deepfake-Detector" THRESHOLD = 0.65 # >= -> "Likely Manipulated" IMAGE_SIZE = 224 # ViT input size try: import mediapipe as mp _mp_face = mp.solutions.face_detection.FaceDetection( model_selection=0, min_detection_confidence=0.4 ) except Exception: _mp_face = None # ---- Face crop ---- def crop_face(pil_img, pad=0.25): """ Crop the most prominent face using MediaPipe. If MP missing or no face found, return the original image. """ if _mp_face is None: return pil_img img = np.array(pil_img.convert("RGB")) h, w = img.shape[:2] res = _mp_face.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) if not res.detections: return pil_img det = max( res.detections, key=lambda d: d.location_data.relative_bounding_box.width ) b = det.location_data.relative_bounding_box x, y, bw, bh = b.xmin, b.ymin, b.width, b.height x1 = int(max(0, (x - pad*bw) * w)); y1 = int(max(0, (y - pad*bh) * h)) x2 = int(min(w, (x + bw + pad*bw) * w)); y2 = int(min(h, (y + bh + pad*bh) * h)) face = Image.fromarray(img[y1:y2, x1:x2]) if face.size[0] < 20 or face.size[1] < 20: return pil_img return face def face_oval_mask(img_pil, shrink=0.80): w, h = img_pil.size mask = Image.new("L", (w, h), 0) draw = ImageDraw.Draw(mask) dx, dy = int((1 - shrink) * w / 2), int((1 - shrink) * h / 2) draw.ellipse((dx, dy, w - dx, h - dy), fill=255) return np.array(mask, dtype=np.float32) / 255.0 # ---- HF model load ---- processor = AutoImageProcessor.from_pretrained(MODEL_ID) model = AutoModelForImageClassification.from_pretrained(MODEL_ID) model.eval() torch.set_grad_enabled(False) # Resolve which index corresponds to "fake" _FAKE_KEYS = ("artificial", "fake", "deepfake", "manipulated", "spoof", "forged") def _fake_index_from_config(cfg) -> int | None: # Prefer id2label id2label = getattr(cfg, "id2label", None) if id2label: try: normalized = {int(k): str(v).lower() for k, v in id2label.items()} except Exception: # sometimes keys already ints normalized = {int(k): str(v).lower() for k, v in id2label.items()} for idx, lab in normalized.items(): if any(k in lab for k in _FAKE_KEYS): return idx # Fallback: invert label2id label2id = getattr(cfg, "label2id", None) if label2id: inv = {int(v): str(k).lower() for k, v in label2id.items()} for idx, lab in inv.items(): if any(k in lab for k in _FAKE_KEYS): return idx return None _FAKE_IDX = _fake_index_from_config(model.config) # ---- Inference ---- def predict_fake_prob(pil_img: Image.Image) -> float: """ Returns P(fake) in [0,1]. Model labels per card: 0 -> 'artificial' (fake), 1 -> 'human' (real). """ # Face-focus to reduce background bias face = crop_face(pil_img) face = face.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE)) inputs = processor(images=face, return_tensors="pt") logits = model(**inputs).logits # (1, C) if logits.shape[-1] == 1: # Binary sigmoid head (unlikely for this model, but safe) return torch.sigmoid(logits.squeeze(0))[0].item() # Softmax multi-class (expected 2 classes) probs = torch.softmax(logits.squeeze(0), dim=-1).detach().cpu().numpy() # Use explicit mapping if available if _FAKE_IDX is not None and 0 <= _FAKE_IDX < probs.shape[0]: return float(probs[_FAKE_IDX]) # Known mapping from the model card: 0=artificial (fake), 1=human if probs.shape[0] == 2: return float(probs[0]) # class-0 is fake # Last resort return float(probs.max()) # ---- UI helpers ---- def result_card(prob_fake: float) -> str: label = "Likely Manipulated" if prob_fake >= THRESHOLD else "Likely Authentic" pct = prob_fake * 100.0 color = "#d84a4a" if label.startswith("Likely Manipulated") else "#2e7d32" bar_bg = "#e9ecef" return f"""