# SPOC_AI_HW / app.py
import io
from typing import List, Tuple, Dict, Any
from PIL import Image, ImageOps
import numpy as np
import torch
import gradio as gr
# Face detector
from facenet_pytorch import MTCNN
# HF image classifier
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= Config =========
# Multi-model ensemble (soft voting). You can add one or two more binary
# deepfake classifiers here; softmax outputs are averaged across models, so
# added models should expose a compatible fake/real label space.
MODEL_IDS = [
"prithivMLmods/Deep-Fake-Detector-v2-Model",
# Example for additional model:
# "HuggingFaceM4/dfdc_deit_base_patch16_224",
]
MAX_SIDE = 896 # Cap on the longest side; larger images are downscaled (keeps detail, bounds compute)
FACE_MIN_SIZE = 112 # Faces smaller than this in pixels are skipped (avoid artifacts)
FACE_MARGIN = 0.20 # Margin added when cropping face (square crop)
DETECT_THRESH = 0.80 # Face detection confidence threshold
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42 # For reproducibility
# =========================
torch.manual_seed(SEED)
np.random.seed(SEED)
# ---- Utilities ----
def resize_keep_ratio(img: Image.Image, max_side: int = MAX_SIDE) -> Image.Image:
"""Resize while preserving aspect ratio."""
w, h = img.size
m = max(w, h)
if m <= max_side:
return img
scale = max_side / float(m)
return img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
def to_square_with_margin(box, W, H, margin=FACE_MARGIN):
"""Convert face bounding box to a square crop with margin."""
x1, y1, x2, y2 = box
w = x2 - x1
h = y2 - y1
cx = (x1 + x2) / 2.0
cy = (y1 + y2) / 2.0
side = max(w, h) * (1.0 + margin)
x1s = int(max(0, cx - side/2))
y1s = int(max(0, cy - side/2))
x2s = int(min(W, cx + side/2))
y2s = int(min(H, cy + side/2))
# Make it square
sq_w = x2s - x1s
sq_h = y2s - y1s
if sq_w != sq_h:
diff = abs(sq_w - sq_h)
if sq_w < sq_h:
x1s = max(0, x1s - diff // 2)
x2s = min(W, x2s + diff - diff // 2)
else:
y1s = max(0, y1s - diff // 2)
y2s = min(H, y2s + diff - diff // 2)
return (x1s, y1s, x2s, y2s)
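# Note: near image borders the re-expansion above can itself be clamped, so
# the crop may stay slightly non-square; the classifier's processor resizes
# crops anyway, so this is harmless in practice.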
def canonical_label(label: str) -> str:
"""Standardize label names into fake/real."""
l = label.lower().strip()
fake_keys = ["fake", "ai", "synthetic", "deepfake", "generated", "manipulated", "forged"]
real_keys = ["real", "authentic", "genuine", "original", "live"]
if any(k in l for k in fake_keys): return "fake"
if any(k in l for k in real_keys): return "real"
return label
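# e.g., canonical_label("Deepfake") -> "fake", while canonical_label("LABEL_0")
# returns "LABEL_0" unchanged (unmapped labels pass through; see the fallback
# in classify_images_ensemble).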
def rank_probs(id2label: Dict[int, str], probs: np.ndarray) -> List[Tuple[str, float]]:
"""Sort probabilities by descending value."""
pairs = [(id2label[i], float(probs[i])) for i in range(len(probs))]
return sorted(pairs, key=lambda x: x[1], reverse=True)
def entropy(p: np.ndarray) -> float:
"""Probability entropy as uncertainty measure."""
p = np.clip(p, 1e-8, 1.0)
return float(-(p * np.log(p)).sum())
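# For a two-class head this ranges from 0 (fully confident) to ln(2) ~= 0.693
# (uniform), so values near 0.69 flag maximally uncertain predictions.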
def jpeg_recompress(pil_img: Image.Image, quality: int = 85) -> Image.Image:
"""JPEG recompression to simulate compression noise (TTA)."""
buf = io.BytesIO()
pil_img.save(buf, format="JPEG", quality=quality, optimize=True)
buf.seek(0)
return Image.open(buf).convert("RGB")
def mild_center_crop_resize(pil_img: Image.Image, ratio: float = 0.92) -> Image.Image:
"""Slight center crop and resize (TTA)."""
w, h = pil_img.size
nw, nh = int(w * ratio), int(h * ratio)
left = (w - nw) // 2
top = (h - nh) // 2
return pil_img.crop((left, top, left + nw, top + nh)).resize((w, h), Image.LANCZOS)
# ===== Image quality evaluation (to reduce false positives on webcam photos) =====
def _np_gray(pil_img: Image.Image) -> np.ndarray:
return np.array(pil_img.convert("L"))
def _conv2d_same_reflect(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
"""Lightweight 2D convolution (reflect padding, no cv2/scipy)."""
kh, kw = kernel.shape
ph, pw = kh // 2, kw // 2
img_pad = np.pad(img, ((ph, ph), (pw, pw)), mode="reflect")
out = np.zeros_like(img, dtype=np.float32)
for i in range(out.shape[0]):
for j in range(out.shape[1]):
region = img_pad[i:i+kh, j:j+kw]
out[i, j] = float((region * kernel).sum())
return out
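# A vectorized sketch of the same convolution (hypothetical helper, not wired
# in above): the pixel-wise Python loop is O(H*W*kh*kw); summing kernel-shifted
# views of the padded array lets NumPy do the inner work instead.
def _conv2d_same_reflect_fast(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    img_pad = np.pad(img.astype(np.float32), ((ph, ph), (pw, pw)), mode="reflect")
    H, W = img.shape
    out = np.zeros((H, W), dtype=np.float32)
    for di in range(kh):
        for dj in range(kw):
            # Each shifted H x W view contributes one kernel tap.
            out += float(kernel[di, dj]) * img_pad[di:di + H, dj:dj + W]
    return out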
def variance_of_laplacian(gray: np.ndarray) -> float:
"""Sharpness measure (blur detection)."""
k = np.array([[0, 1, 0],[1,-4, 1],[0, 1, 0]], dtype=np.float32)
lap = _conv2d_same_reflect(gray.astype(np.float32), k)
return float(lap.var())
def jpeg_size_ratio(pil_img: Image.Image, quality: int = 85) -> float:
"""Compression ratio proxy."""
buf_png = io.BytesIO(); pil_img.save(buf_png, format="PNG"); s_png = len(buf_png.getvalue())
buf_jpg = io.BytesIO(); pil_img.save(buf_jpg, format="JPEG", quality=quality, optimize=True); s_jpg = len(buf_jpg.getvalue())
if s_png == 0: return 1.0
return float(s_jpg) / float(s_png)
def image_quality_metrics(pil_img: Image.Image) -> Dict[str, float]:
g = _np_gray(pil_img)
return {
"sharp": variance_of_laplacian(g),
"bright": float(g.mean()),
"contrast": float(g.std()),
"comp": jpeg_size_ratio(pil_img, 85)
}
def quality_bucket(m: Dict[str,float]) -> str:
"""Classify image quality level."""
poor = (m["sharp"] < 60) or (m["bright"] < 40) or (m["bright"] > 215) or (m["contrast"] < 25)
good = (m["sharp"] >= 120) and (50 <= m["bright"] <= 200) and (m["contrast"] >= 35)
if poor: return "poor"
if good: return "good"
return "ok"
def logit(p, eps=1e-6): p = min(max(p, eps), 1 - eps); return float(np.log(p/(1-p)))
def sigmoid(x): return float(1/(1+np.exp(-x)))
def calibrate_fake_prob(p: float, qlvl: str) -> float:
"""Quality-adaptive probability scaling."""
if qlvl == "poor": t, b = 1.6, -0.4
elif qlvl == "good": t, b = 0.9, +0.1
else: t, b = 1.2, -0.1
z = (logit(p) + b) / t
return sigmoid(z)
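# Example: a raw fake probability of 0.80 becomes ~0.65 under "poor" quality
# (t=1.6, b=-0.4) but ~0.84 under "good" (t=0.9, b=+0.1): low-quality images
# need stronger evidence before they count toward "fake".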
# ---- Load models ----
mtcnn = MTCNN(keep_all=True, device=DEVICE)
_models = []
for mid in MODEL_IDS:
try: processor = AutoImageProcessor.from_pretrained(mid, use_fast=True)
except Exception: processor = AutoImageProcessor.from_pretrained(mid)
clf = AutoModelForImageClassification.from_pretrained(mid).to(DEVICE).eval()
_models.append((mid, processor, clf))
# ---- Core inference ----
@torch.no_grad()
def classify_images_ensemble(pils: List[Image.Image]) -> Dict[str, Any]:
"""Run classification on multiple faces and return per-image predictions."""
per_image_results = []
    def tta_views(img):
        return [
            img,                                 # original view
            ImageOps.mirror(img),                # horizontal flip
            jpeg_recompress(img, 85),            # simulated compression noise
            mild_center_crop_resize(img, 0.92),  # slight zoom
        ]
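    # Each model's softmax is averaged over these four views (TTA) before the
    # cross-model average below.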
use_amp = (DEVICE == "cuda")
for img in pils:
views = tta_views(img)
model_probs_accum = []
for (mid, processor, clf) in _models:
inputs = processor(images=views, return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            with torch.autocast(device_type=DEVICE, enabled=use_amp):  # device-agnostic AMP; no-op on CPU
logits = clf(**inputs).logits
probs = torch.softmax(logits, dim=-1)
model_probs_accum.append(probs.mean(dim=0).unsqueeze(0))
probs_ens = torch.cat(model_probs_accum, dim=0).mean(dim=0)
probs_np = probs_ens.float().cpu().numpy()
        # Assumes every ensemble member shares the first model's label order;
        # heterogeneous label spaces would need per-model remapping first.
        id2label = _models[0][2].config.id2label
ranked = rank_probs(id2label, probs_np)
fake_p = real_p = None
for i in range(len(probs_np)):
cat = canonical_label(id2label[i])
if cat == "fake": fake_p = max(fake_p or 0, probs_np[i])
elif cat == "real": real_p = max(real_p or 0, probs_np[i])
        if fake_p is None and real_p is None:
            # No label matched the fake/real keyword lists (e.g., LABEL_0/LABEL_1);
            # since canonical_label() already failed for every class, treat the
            # top-1 class as "real" and its complement as "fake".
            top1 = ranked[0]
            real_p, fake_p = top1[1], 1 - top1[1]
        per_image_results.append({
            "top": ranked[:3],
            # Cast to Python float so the dict stays JSON-serializable.
            "fake_prob": None if fake_p is None else round(float(fake_p), 4),
            "real_prob": None if real_p is None else round(float(real_p), 4),
            "entropy": round(entropy(probs_np), 4)
        })
return {"per_image": per_image_results}
def analyze(img: Image.Image) -> Dict[str, Any]:
"""Detect deepfake in faces and full image, quality-aware fusion."""
if img is None: return {"error": "No image uploaded."}
img = resize_keep_ratio(img.convert("RGB"), MAX_SIDE)
q_metrics = image_quality_metrics(img)
q_level = quality_bucket(q_metrics)
boxes, probs = mtcnn.detect(img)
crops, crop_boxes, kept_scores = [], [], []
# Face detection & filtering
if boxes is not None and probs is not None:
W, H = img.size
for (b, sc) in zip(boxes, probs):
if sc is None or sc < DETECT_THRESH: continue
x1s, y1s, x2s, y2s = to_square_with_margin(b, W, H, FACE_MARGIN)
if x2s<=x1s or y2s<=y1s: continue
face = img.crop((x1s, y1s, x2s, y2s))
if min(face.size) < FACE_MIN_SIZE: continue
crops.append(face); crop_boxes.append((x1s,y1s,x2s,y2s)); kept_scores.append(float(sc))
# Face inference
per_face_results = []
if crops:
preds = classify_images_ensemble(crops)["per_image"]
for i,(pred,box,sc) in enumerate(zip(preds, crop_boxes, kept_scores),1):
fm = image_quality_metrics(crops[i-1])
fl = quality_bucket(fm)
fp = pred.get("fake_prob")
if fp is not None:
pred["fake_prob_raw"]=fp
pred["fake_prob"]=round(calibrate_fake_prob(fp, fl),4)
pred["quality"]={"level":fl,**fm}
per_face_results.append({
"face_index":i,"box":{"x1":box[0],"y1":box[1],"x2":box[2],"y2":box[3]},
"det_score":round(sc,4),**pred
})
# Full-image inference (weak expert)
full_pred = classify_images_ensemble([img])["per_image"][0]
if full_pred.get("fake_prob") is not None:
full_pred["fake_prob_raw"]=full_pred["fake_prob"]
full_pred["fake_prob"]=round(calibrate_fake_prob(full_pred["fake_prob"], q_level),4)
full_pred["quality"]={"level":q_level,**q_metrics}
# Score fusion (faces > full image)
face_scores=[r["fake_prob"] for r in per_face_results if r.get("fake_prob") is not None]
if not face_scores and full_pred.get("fake_prob") is None:
overall_fake=0.5
else:
        faces_robust=float(np.median(face_scores)) if face_scores else None
        full_score=full_pred.get("fake_prob",None)
        # Compare with None explicitly: a legitimate score of 0.0 is falsy and
        # would otherwise drop out of the weighted fusion.
        if faces_robust is not None and full_score is not None:
            overall_fake=0.8*faces_robust+0.2*full_score
        elif faces_robust is not None: overall_fake=faces_robust
        else: overall_fake=full_score
if face_scores:
q3=float(np.quantile(face_scores,0.75))
overall_fake=float(0.7*overall_fake+0.3*q3)
# Dynamic thresholding based on quality
if q_level=="poor": th_fake,th_unc=0.85,0.65
elif q_level=="good": th_fake,th_unc=0.70,0.55
else: th_fake,th_unc=0.75,0.60
if overall_fake>=th_fake: label="Likely Deepfake"
elif overall_fake>=th_unc: label="Uncertain"
else: label="Likely Real"
    # Surface the per-face and full-image detail computed above alongside the
    # summary, so the JSON output reflects the whole analysis.
    return {
        "label":label,
        "overall_fake_probability":round(overall_fake,4),
        "faces_detected":len(per_face_results),
        "image_quality":{"level":q_level,**{k:round(float(v),2) for k,v in q_metrics.items()}},
        "faces":per_face_results,
        "full_image":full_pred,
    }
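# Minimal smoke test (hypothetical local file; the Gradio UI below is the
# intended entry point):
#   print(analyze(Image.open("selfie.jpg")))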
def visualize_faces(img: Image.Image):
"""Return cropped faces for preview."""
if img is None: return []
img=resize_keep_ratio(img.convert("RGB"),MAX_SIDE)
boxes,probs=mtcnn.detect(img)
thumbs=[]
if boxes is not None and probs is not None:
W,H=img.size
for(b,sc) in zip(boxes,probs):
if sc is None or sc<DETECT_THRESH: continue
x1s,y1s,x2s,y2s=to_square_with_margin(b,W,H,FACE_MARGIN)
if x2s<=x1s or y2s<=y1s: continue
face=img.crop((x1s,y1s,x2s,y2s))
if min(face.size)<FACE_MIN_SIZE: continue
thumbs.append(face.resize((160,160),Image.LANCZOS))
return thumbs
# ---- Gradio UI ----
with gr.Blocks() as demo:
gr.Markdown("""
# πŸ•΅οΈ FakeSpotter β€” Deepfake Image Detector
**Ensemble + TTA + Quality Awareness**
- Multi-model ensemble
- Face-focused detection + fallback to whole image
- Image quality guard to reduce false positives (e.g., webcam noise)
> Educational, not forensic-grade.
""")
with gr.Row():
inp=gr.Image(type="pil",label="Upload Image")
with gr.Column():
thumbs=gr.Gallery(label="Detected Face Crops (preview)",columns=6,height="auto")
out=gr.JSON(label="Results")
btn=gr.Button("Analyze")
    btn.click(visualize_faces, inputs=inp, outputs=thumbs)
    btn.click(analyze, inputs=inp, outputs=out)
if __name__=="__main__":
demo.launch()