Spaces:
Running
Running
| import io | |
| from typing import List, Tuple, Dict, Any | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| # Face detector | |
| from facenet_pytorch import MTCNN | |
| # HF image classifier | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# ========= Config =========
# Classifier checkpoints combined by soft voting (their class probabilities
# are averaged). One or two more binary deepfake classifiers can be appended;
# the ensemble code reads label names from the first model only, so every
# model is assumed to emit the same classes in the same order.
MODEL_IDS = [
    "prithivMLmods/Deep-Fake-Detector-v2-Model",
    # Example of an additional model (placeholder -- verify the id before use):
    # "HuggingFaceM4/dfdc_deit_base_patch16_224",
]

MAX_SIDE = 896        # longest image side after downscaling (keeps detail)
FACE_MIN_SIZE = 112   # face crops whose shorter side is below this (px) are skipped
FACE_MARGIN = 0.20    # fractional margin added around a face box before square cropping
DETECT_THRESH = 0.80  # minimum face-detector confidence for a face to be kept
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42             # fixed RNG seed for reproducibility
# =========================

# Seed both RNGs up front (independent calls; order does not matter).
np.random.seed(SEED)
torch.manual_seed(SEED)
# ---- Utilities ----
def resize_keep_ratio(img: Image.Image, max_side: int = MAX_SIDE) -> Image.Image:
    """Downscale *img* so its longer side is at most *max_side*, keeping aspect ratio.

    Images already within the limit are returned unchanged (the same object).

    Target sizes are rounded rather than truncated. Float error in
    ``w * scale`` could otherwise leave the long side at ``max_side - 1``.
    Each side is also kept at >= 1 px, so very elongated images do not
    request an invalid zero-sized resize.

    Args:
        img: PIL image to resize.
        max_side: upper bound for the longer side, in pixels.

    Returns:
        The original image, or a LANCZOS-resampled copy.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / float(longest)
    new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
    return img.resize(new_size, Image.LANCZOS)
def to_square_with_margin(box, W, H, margin=FACE_MARGIN):
    """Turn a face bounding box into a square crop box with extra margin.

    Args:
        box: ``(x1, y1, x2, y2)`` face box in pixel coordinates (e.g. one row
            of the detector's ``boxes`` output).
        W: width of the image the box belongs to.
        H: height of the image the box belongs to.
        margin: fractional padding added to the longer side of the box.

    Returns:
        Integer ``(x1, y1, x2, y2)`` with ``x2 - x1 == y2 - y1``, lying inside
        ``[0, W] x [0, H]``. The side is capped at ``min(W, H)``. When the
        padded square would cross the image border, it is shifted back inside
        rather than clipped, so the crop stays square. A degenerate box yields
        a zero-size result (``x2 == x1``), which callers skip.
    """
    x1, y1, x2, y2 = (float(v) for v in box)
    side = max(x2 - x1, y2 - y1) * (1.0 + margin)
    # Cap so the square fits in the image; never negative for malformed boxes.
    side_px = int(round(min(max(side, 0.0), W, H)))
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    left = int(round(cx - side_px / 2.0))
    top = int(round(cy - side_px / 2.0))
    # Shift (not clip) the window back inside the image; W - side_px >= 0
    # and H - side_px >= 0 because of the cap above.
    left = min(max(left, 0), W - side_px)
    top = min(max(top, 0), H - side_px)
    return (left, top, left + side_px, top + side_px)
def canonical_label(label: str) -> str:
    """Map a classifier's raw label name onto ``"fake"`` or ``"real"``.

    Matching is a case-insensitive substring search on the stripped label.
    Fake hints are checked before real hints. A label that matches neither
    group is returned unchanged (original casing and whitespace).
    Note: the short hint "ai" also matches inside unrelated words.
    """
    normalized = label.lower().strip()
    hint_groups = (
        ("fake", ("fake", "ai", "synthetic", "deepfake", "generated", "manipulated", "forged")),
        ("real", ("real", "authentic", "genuine", "original", "live")),
    )
    for category, hints in hint_groups:
        if any(hint in normalized for hint in hints):
            return category
    return label
def rank_probs(id2label: Dict[int, str], probs: np.ndarray) -> List[Tuple[str, float]]:
    """Pair each class probability with its label, highest probability first.

    The sort is stable, so tied classes keep their index order.
    """
    labelled = [(id2label[idx], float(value)) for idx, value in enumerate(probs)]
    labelled.sort(key=lambda pair: pair[1], reverse=True)
    return labelled
def entropy(p: np.ndarray) -> float:
    """Shannon entropy (natural log) of a probability vector, used as an uncertainty score.

    Entries are clipped to [1e-8, 1] first so ``log(0)`` never occurs.
    """
    clipped = np.clip(p, 1e-8, 1.0)
    return float(-np.sum(clipped * np.log(clipped)))
def jpeg_recompress(pil_img: Image.Image, quality: int = 85) -> Image.Image:
    """Round-trip *pil_img* through in-memory JPEG encoding (a TTA view).

    The decoded result is converted to RGB. ``convert`` forces the pixel
    data to load before the buffer is closed.
    """
    with io.BytesIO() as stream:
        pil_img.save(stream, format="JPEG", quality=quality, optimize=True)
        stream.seek(0)
        decoded = Image.open(stream).convert("RGB")
    return decoded
def mild_center_crop_resize(pil_img: Image.Image, ratio: float = 0.92) -> Image.Image:
    """Keep the central *ratio* of each side, then scale back to the original size (a TTA view)."""
    width, height = pil_img.size
    crop_w = int(width * ratio)
    crop_h = int(height * ratio)
    x0 = (width - crop_w) // 2
    y0 = (height - crop_h) // 2
    crop_box = (x0, y0, x0 + crop_w, y0 + crop_h)
    zoomed = pil_img.crop(crop_box)
    return zoomed.resize((width, height), Image.LANCZOS)
# ===== Image quality evaluation (to reduce false positives on webcam photos) =====
def _np_gray(pil_img: Image.Image) -> np.ndarray:
    """Return the image as a 2-D array of 8-bit grayscale ("L" mode) values."""
    grayscale = pil_img.convert("L")
    return np.array(grayscale)
| def _conv2d_same_reflect(img: np.ndarray, kernel: np.ndarray) -> np.ndarray: | |
| """Lightweight 2D convolution (reflect padding, no cv2/scipy).""" | |
| kh, kw = kernel.shape | |
| ph, pw = kh // 2, kw // 2 | |
| img_pad = np.pad(img, ((ph, ph), (pw, pw)), mode="reflect") | |
| out = np.zeros_like(img, dtype=np.float32) | |
| for i in range(out.shape[0]): | |
| for j in range(out.shape[1]): | |
| region = img_pad[i:i+kh, j:j+kw] | |
| out[i, j] = float((region * kernel).sum()) | |
| return out | |
def variance_of_laplacian(gray: np.ndarray) -> float:
    """Sharpness score: variance of the 4-neighbour Laplacian response.

    Low values indicate a blurry image.
    """
    laplacian_kernel = np.array(
        [[0, 1, 0],
         [1, -4, 1],
         [0, 1, 0]],
        dtype=np.float32,
    )
    response = _conv2d_same_reflect(gray.astype(np.float32), laplacian_kernel)
    return float(response.var())
def jpeg_size_ratio(pil_img: Image.Image, quality: int = 85) -> float:
    """Compression proxy: JPEG-encoded size divided by PNG-encoded size.

    Returns 1.0 when the PNG encoding is empty, to avoid dividing by zero.
    """
    def encoded_size(**save_options) -> int:
        sink = io.BytesIO()
        pil_img.save(sink, **save_options)
        return len(sink.getvalue())

    png_bytes = encoded_size(format="PNG")
    jpeg_bytes = encoded_size(format="JPEG", quality=quality, optimize=True)
    return 1.0 if png_bytes == 0 else float(jpeg_bytes) / float(png_bytes)
def image_quality_metrics(pil_img: Image.Image) -> Dict[str, float]:
    """Compute simple quality statistics for *pil_img*.

    Returns a dict with:
        sharp    -- variance of the Laplacian of the grayscale image (higher = sharper)
        bright   -- mean grayscale intensity
        contrast -- standard deviation of grayscale intensity
        comp     -- JPEG (quality 85) to PNG encoded-size ratio
    """
    gray = _np_gray(pil_img)
    sharpness = variance_of_laplacian(gray)
    brightness = float(gray.mean())
    contrast = float(gray.std())
    compression = jpeg_size_ratio(pil_img, 85)
    return {"sharp": sharpness, "bright": brightness, "contrast": contrast, "comp": compression}
def quality_bucket(m: Dict[str, float]) -> str:
    """Classify quality metrics (from image_quality_metrics) as "poor", "ok" or "good".

    "poor": sharpness < 60, brightness outside [40, 215], or contrast < 25.
    "good": sharpness >= 120, brightness in [50, 200] and contrast >= 35.
    Anything else is "ok". The "poor" test takes precedence.
    """
    if m["sharp"] < 60 or m["bright"] < 40 or m["bright"] > 215 or m["contrast"] < 25:
        return "poor"
    if m["sharp"] >= 120 and 50 <= m["bright"] <= 200 and m["contrast"] >= 35:
        return "good"
    return "ok"
def logit(p, eps=1e-6):
    """Log-odds of *p*, with *p* clamped into [eps, 1 - eps] so the result stays finite."""
    if p < eps:
        p = eps
    elif p > 1 - eps:
        p = 1 - eps
    odds = p / (1 - p)
    return float(np.log(odds))
def sigmoid(x):
    """Logistic function 1 / (1 + e^-x), returned as a Python float."""
    denominator = 1 + np.exp(-x)
    return float(1 / denominator)
def calibrate_fake_prob(p: float, qlvl: str) -> float:
    """Rescale a fake probability according to image quality.

    Works in logit space: ``sigmoid((logit(p) + bias) / temperature)``.
    A temperature > 1 pulls the score toward 0.5. A negative bias lowers the
    fake probability.
    """
    if qlvl == "poor":
        temperature, bias = 1.6, -0.4   # strongest flattening and downward shift
    elif qlvl == "good":
        temperature, bias = 0.9, 0.1    # slight sharpening and upward shift
    else:
        temperature, bias = 1.2, -0.1
    return sigmoid((logit(p) + bias) / temperature)
# ---- Load models ----
# Face detector; keep_all=True returns every detected face, not just the best one.
mtcnn = MTCNN(keep_all=True, device=DEVICE)

# (model_id, processor, classifier) triples, one per entry in MODEL_IDS.
_models = []
for mid in MODEL_IDS:
    # Prefer the fast image processor; fall back to the default one if that fails.
    try:
        processor = AutoImageProcessor.from_pretrained(mid, use_fast=True)
    except Exception:
        processor = AutoImageProcessor.from_pretrained(mid)
    clf = AutoModelForImageClassification.from_pretrained(mid).to(DEVICE).eval()
    _models.append((mid, processor, clf))
# ---- Core inference ----
def classify_images_ensemble(pils: List[Image.Image]) -> Dict[str, Any]:
    """Classify each image with every model in ``_models`` (soft voting + TTA).

    For every input image, four test-time-augmentation views are scored by
    each model: the original, a horizontal mirror, a JPEG q=85 re-encode, and
    a 92% centre crop. Softmax probabilities are averaged over the views,
    then over the models.

    Args:
        pils: PIL images to classify (face crops or the full image).

    Returns:
        ``{"per_image": [...]}`` with one dict per input, in input order:
        ``top`` (up to three ``(label, prob)`` pairs, highest first),
        ``fake_prob`` / ``real_prob`` (Python floats rounded to 4 decimals,
        or ``None``) and ``entropy`` of the ensemble distribution.

    Note:
        Label names are read from the first model's ``config.id2label``.
        Every model in the ensemble is assumed to emit the same classes in
        the same order.
    """
    if not pils:
        return {"per_image": []}

    def tta_views(img):
        return [
            img,
            ImageOps.mirror(img),
            jpeg_recompress(img, 85),
            mild_center_crop_resize(img, 0.92),
        ]

    use_amp = DEVICE == "cuda"
    id2label = _models[0][2].config.id2label
    per_image_results = []

    # inference_mode: no autograd graph is recorded. Without it the model
    # outputs require grad, which wastes memory and makes ``.numpy()`` raise.
    with torch.inference_mode():
        for img in pils:
            views = tta_views(img)
            model_probs = []
            for _mid, processor, clf in _models:
                inputs = processor(images=views, return_tensors="pt")
                inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
                # torch.autocast supersedes the deprecated torch.cuda.amp.autocast;
                # mixed precision is only enabled on CUDA.
                with torch.autocast(device_type=DEVICE, enabled=use_amp):
                    logits = clf(**inputs).logits
                # Softmax in float32 even if autocast produced half-precision logits.
                probs = torch.softmax(logits.float(), dim=-1)
                # Average over TTA views -> one distribution per model.
                model_probs.append(probs.mean(dim=0))
            # Soft voting: average the per-model distributions.
            probs_np = torch.stack(model_probs).mean(dim=0).cpu().numpy()
            ranked = rank_probs(id2label, probs_np)

            # Collapse all labels of each category to its maximum probability.
            fake_p = real_p = None
            for idx, p in enumerate(probs_np):
                category = canonical_label(id2label[idx])
                if category == "fake":
                    fake_p = max(fake_p or 0.0, float(p))
                elif category == "real":
                    real_p = max(real_p or 0.0, float(p))
            if fake_p is None and real_p is None:
                # NOTE(review): no label maps to fake/real, so this derives a
                # pair from the top-1 class; the heuristic is kept from the original.
                top_label, top_prob = ranked[0]
                if canonical_label(top_label) == "fake":
                    fake_p, real_p = top_prob, 1 - top_prob
                else:
                    real_p, fake_p = top_prob, 1 - top_prob

            per_image_results.append({
                "top": ranked[:3],
                # Plain Python floats so the result is JSON-serialisable.
                "fake_prob": None if fake_p is None else round(fake_p, 4),
                "real_prob": None if real_p is None else round(real_p, 4),
                "entropy": round(entropy(probs_np), 4),
            })
    return {"per_image": per_image_results}
def analyze(img: Image.Image) -> Dict[str, Any]:
    """Estimate whether an image is a deepfake by fusing per-face and full-image scores.

    Pipeline:
      1. Downscale to MAX_SIDE and measure image quality.
      2. Detect faces and keep confident, large-enough square crops.
      3. Classify each kept crop and the full image with the ensemble.
      4. Apply quality-dependent calibration.
      5. Fuse the scores (faces weighted above the full image) and threshold
         with quality-dependent cut-offs.

    Args:
        img: the uploaded PIL image, or ``None``.

    Returns:
        ``{"error": ...}`` when no image is given. Otherwise a dict with:
        ``label`` ("Likely Deepfake" / "Uncertain" / "Likely Real"),
        ``overall_fake_probability`` (rounded to 4 decimals), and
        ``faces_detected`` (number of face crops that were classified).
    """
    if img is None:
        return {"error": "No image uploaded."}
    img = resize_keep_ratio(img.convert("RGB"), MAX_SIDE)
    q_metrics = image_quality_metrics(img)
    q_level = quality_bucket(q_metrics)

    # --- Face detection & filtering ---
    boxes, probs = mtcnn.detect(img)
    crops, crop_boxes, kept_scores = [], [], []
    if boxes is not None and probs is not None:
        W, H = img.size
        for b, sc in zip(boxes, probs):
            if sc is None or sc < DETECT_THRESH:
                continue
            x1s, y1s, x2s, y2s = to_square_with_margin(b, W, H, FACE_MARGIN)
            if x2s <= x1s or y2s <= y1s:
                continue
            face = img.crop((x1s, y1s, x2s, y2s))
            if min(face.size) < FACE_MIN_SIZE:
                continue
            crops.append(face)
            crop_boxes.append((x1s, y1s, x2s, y2s))
            kept_scores.append(float(sc))

    # --- Per-face inference, calibrated with each crop's own quality level ---
    per_face_results = []
    if crops:
        preds = classify_images_ensemble(crops)["per_image"]
        for i, (face, pred, box, sc) in enumerate(zip(crops, preds, crop_boxes, kept_scores), 1):
            fm = image_quality_metrics(face)
            fl = quality_bucket(fm)
            fp = pred.get("fake_prob")
            if fp is not None:
                pred["fake_prob_raw"] = fp
                pred["fake_prob"] = round(calibrate_fake_prob(fp, fl), 4)
            pred["quality"] = {"level": fl, **fm}
            per_face_results.append({
                "face_index": i,
                "box": {"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]},
                "det_score": round(sc, 4),
                **pred,
            })

    # --- Full-image inference (weaker signal) ---
    full_pred = classify_images_ensemble([img])["per_image"][0]
    if full_pred.get("fake_prob") is not None:
        full_pred["fake_prob_raw"] = full_pred["fake_prob"]
        full_pred["fake_prob"] = round(calibrate_fake_prob(full_pred["fake_prob"], q_level), 4)
    full_pred["quality"] = {"level": q_level, **q_metrics}

    # --- Score fusion (faces outweigh the full image) ---
    face_scores = [r["fake_prob"] for r in per_face_results if r.get("fake_prob") is not None]
    full_score = full_pred.get("fake_prob")
    faces_robust = float(np.median(face_scores)) if face_scores else None
    # Compare with None explicitly. A calibrated score of 0.0 is a valid
    # ("very real") score; the previous truthiness test silently dropped it.
    if faces_robust is not None and full_score is not None:
        overall_fake = 0.8 * faces_robust + 0.2 * full_score
    elif faces_robust is not None:
        overall_fake = faces_robust
    elif full_score is not None:
        overall_fake = full_score
    else:
        overall_fake = 0.5  # no usable score at all
    if face_scores:
        # Blend in the upper quartile. This raises the result when some
        # faces score above the median.
        q3 = float(np.quantile(face_scores, 0.75))
        overall_fake = float(0.7 * overall_fake + 0.3 * q3)

    # --- Quality-dependent decision thresholds: (fake, uncertain) ---
    if q_level == "poor":
        th_fake, th_unc = 0.85, 0.65
    elif q_level == "good":
        th_fake, th_unc = 0.70, 0.55
    else:
        th_fake, th_unc = 0.75, 0.60
    if overall_fake >= th_fake:
        label = "Likely Deepfake"
    elif overall_fake >= th_unc:
        label = "Uncertain"
    else:
        label = "Likely Real"

    return {
        "label": label,
        "overall_fake_probability": round(overall_fake, 4),
        "faces_detected": len(per_face_results),
    }
def visualize_faces(img: Image.Image):
    """Return 160x160 preview thumbnails of the detected face crops.

    The image is downscaled to MAX_SIDE. Faces are kept only when detector
    confidence >= DETECT_THRESH, the square-with-margin box is non-empty, and
    the crop's shorter side >= FACE_MIN_SIZE. Returns ``[]`` when there is no
    image or no face.
    """
    if img is None:
        return []
    rgb = resize_keep_ratio(img.convert("RGB"), MAX_SIDE)
    boxes, scores = mtcnn.detect(rgb)
    if boxes is None or scores is None:
        return []
    width, height = rgb.size
    previews = []
    for box, score in zip(boxes, scores):
        if score is None or score < DETECT_THRESH:
            continue
        left, top, right, bottom = to_square_with_margin(box, width, height, FACE_MARGIN)
        if right <= left or bottom <= top:
            continue
        crop = rgb.crop((left, top, right, bottom))
        if min(crop.size) >= FACE_MIN_SIZE:
            previews.append(crop.resize((160, 160), Image.LANCZOS))
    return previews
# ---- Gradio UI ----
with gr.Blocks() as demo:
    # Title fixed: it was mojibake ("π΅οΈ ... β") from UTF-8 text decoded as ISO-8859-7.
    gr.Markdown("""
# 🕵️ FakeSpotter — Deepfake Image Detector
**Ensemble + TTA + Quality Awareness**
- Multi-model ensemble
- Face-focused detection + fallback to whole image
- Image quality guard to reduce false positives (e.g., webcam noise)
> Educational, not forensic-grade.
""")
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            thumbs = gr.Gallery(label="Detected Face Crops (preview)", columns=6, height="auto")
            out = gr.JSON(label="Results")
    btn = gr.Button("Analyze")
    # One click drives two independent handlers: face previews and the JSON result.
    btn.click(visualize_faces, inputs=inp, outputs=thumbs)
    btn.click(analyze, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()