File size: 12,899 Bytes
7957cc3
2ad8303
8e37de3
2ad8303
 
b005330
 
2ad8303
 
b005330
2ad8303
 
 
 
a5f333a
8e37de3
 
a5f333a
 
8e37de3
a5f333a
 
 
 
8e37de3
a5f333a
2ad8303
 
8e37de3
 
 
2ad8303
 
a5f333a
7957cc3
 
 
 
 
 
b005330
8e37de3
a5f333a
8e37de3
 
 
 
 
 
 
 
 
 
a5f333a
 
8e37de3
 
 
 
 
 
 
 
 
 
 
 
2ad8303
a5f333a
8e37de3
 
 
a5f333a
 
 
2ad8303
8e37de3
a5f333a
2ad8303
 
 
8e37de3
a5f333a
8e37de3
 
 
 
a5f333a
8e37de3
 
 
 
 
a5f333a
 
8e37de3
 
 
 
 
 
a5f333a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ad8303
8e37de3
 
a5f333a
 
d6d1301
8e37de3
2ad8303
 
 
8e37de3
a5f333a
8e37de3
 
a5f333a
 
 
 
 
 
8e37de3
 
 
 
 
a5f333a
8e37de3
 
 
 
a5f333a
 
 
8e37de3
a5f333a
 
8e37de3
 
a5f333a
 
8e37de3
a5f333a
 
 
 
8e37de3
a5f333a
 
 
8e37de3
 
 
 
 
 
 
 
b005330
2ad8303
a5f333a
 
2ad8303
a5f333a
 
 
2ad8303
8e37de3
 
 
a5f333a
8e37de3
 
 
a5f333a
8e37de3
a5f333a
8e37de3
a5f333a
 
8e37de3
a5f333a
8e37de3
a5f333a
 
 
 
 
 
 
 
 
 
8e37de3
a5f333a
 
8e37de3
2ad8303
a5f333a
8e37de3
 
a5f333a
 
 
7957cc3
a5f333a
 
 
 
8e37de3
a5f333a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7957cc3
2ad8303
a5f333a
 
 
2ad8303
b005330
8e37de3
a5f333a
 
 
 
 
8e37de3
a5f333a
 
 
 
 
 
 
 
8e37de3
 
a5f333a
b005330
a5f333a
 
 
 
 
 
 
 
 
2ad8303
a5f333a
8e37de3
a5f333a
 
b005330
a5f333a
 
 
 
 
b005330
a5f333a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import io
from typing import List, Tuple, Dict, Any
from PIL import Image, ImageOps
import numpy as np
import torch
import gradio as gr

# Face detector
from facenet_pytorch import MTCNN

# HF image classifier
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ========= Config =========
# Classifier ensemble combined by soft voting: class probabilities from every
# model are averaged. Each entry is a Hugging Face model id loaded with
# AutoModelForImageClassification. Any additional model should be a binary
# deepfake classifier whose label names map to fake/real via canonical_label().
MODEL_IDS = [
    "prithivMLmods/Deep-Fake-Detector-v2-Model",
    # Example for an additional model (illustrative; verify the id and its label names before enabling):
    # "HuggingFaceM4/dfdc_deit_base_patch16_224",
]
MAX_SIDE = 896           # Longest image side (px) after downscaling; smaller images are not upscaled
FACE_MIN_SIZE = 112      # Face crops whose shorter side (px) is below this are skipped (avoid artifacts)
FACE_MARGIN = 0.20       # Fractional enlargement of the face box's longer side when building the square crop
DETECT_THRESH = 0.80     # Minimum MTCNN face confidence for a detection to be kept
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # "cuda" also enables mixed precision at inference
SEED = 42                # Seed for the torch and NumPy global RNGs set below (reproducibility)
# =========================

torch.manual_seed(SEED)
np.random.seed(SEED)

# ---- Utilities ----
def resize_keep_ratio(img: Image.Image, max_side: int = MAX_SIDE) -> Image.Image:
    """Downscale *img* so that its longest side is at most *max_side* pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged; they are never upscaled.

    Args:
        img: Source image.
        max_side: Upper bound, in pixels, for ``max(width, height)``.

    Returns:
        *img* itself if it already fits, otherwise a LANCZOS-resampled copy.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / float(longest)
    # Clamp to at least 1 px: with extreme aspect ratios (e.g. 5000x2) the short
    # side would otherwise truncate to 0, and PIL rejects a zero-sized resize.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.LANCZOS)

def to_square_with_margin(box, W, H, margin=FACE_MARGIN):
    """Convert a face bounding box into a square crop window with margin.

    The square's side is the box's longer edge enlarged by ``margin`` (a
    fraction, e.g. 0.2 means +20%). The square is centred on the box. If it
    would extend past the image it is slid back inside rather than clipped, so
    the result is always square. Only when the square is larger than the image
    itself is its side reduced to ``min(W, H)``.

    Args:
        box: ``(x1, y1, x2, y2)`` face box in pixel coordinates. Values may be
            floats and may lie partly outside the image.
        W: Image width in pixels.
        H: Image height in pixels.
        margin: Relative enlargement of the box's longer side.

    Returns:
        ``(x1, y1, x2, y2)`` integer crop coordinates with
        ``x2 - x1 == y2 - y1``, ``0 <= x1``, ``x2 <= W``, ``0 <= y1`` and
        ``y2 <= H``. The side is 0 for degenerate boxes, which callers skip.
    """
    x1, y1, x2, y2 = (float(v) for v in box)
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    side = int(round(max(x2 - x1, y2 - y1) * (1.0 + margin)))
    # A square can never be larger than the image's shorter dimension.
    side = max(0, min(side, int(W), int(H)))
    # Centre on the face, then slide the window inside the image (instead of
    # clipping it) so faces near an edge still produce a square crop.
    left = min(max(int(round(cx - side / 2.0)), 0), int(W) - side)
    top = min(max(int(round(cy - side / 2.0)), 0), int(H) - side)
    return (left, top, left + side, top + side)

def canonical_label(label: str) -> str:
    """Map a classifier label name onto the canonical ``"fake"`` / ``"real"``.

    Matching is case-insensitive. Most keywords match as substrings, so
    "Deepfake" maps to fake and "Realism" maps to real. The short keyword "ai"
    must appear as a standalone token (e.g. "AI", "ai_image"). As a substring
    it would misfire on unrelated words such as "detail" or "trained". Fake
    keywords take precedence over real ones.

    Args:
        label: Raw label string, e.g. a value of ``config.id2label``.

    Returns:
        ``"fake"``, ``"real"``, or the unchanged *label* when nothing matches.
    """
    text = label.lower().strip()
    # Split on every non-alphanumeric character so "ai" only matches a whole token.
    tokens = "".join(ch if ch.isalnum() else " " for ch in text).split()
    fake_keys = ("fake", "synthetic", "deepfake", "generated", "manipulated", "forged")
    real_keys = ("real", "authentic", "genuine", "original", "live")
    if "ai" in tokens or any(k in text for k in fake_keys):
        return "fake"
    if any(k in text for k in real_keys):
        return "real"
    return label

def rank_probs(id2label: Dict[int, str], probs: np.ndarray) -> List[Tuple[str, float]]:
    """Pair every class probability with its label, highest probability first.

    Probabilities are converted to Python floats. Ties keep their class-index
    order because the sort is stable.
    """
    labelled = [(id2label[idx], float(p)) for idx, p in enumerate(probs)]
    labelled.sort(key=lambda item: item[1], reverse=True)
    return labelled

def entropy(p: np.ndarray) -> float:
    """Shannon entropy (in nats) of a probability vector, used as an uncertainty score.

    Entries are clamped into ``[1e-8, 1.0]`` so zero probabilities do not
    produce ``log(0)``.
    """
    safe = np.clip(p, 1e-8, 1.0)
    neg_entropy = np.sum(safe * np.log(safe))
    return float(-neg_entropy)

def jpeg_recompress(pil_img: Image.Image, quality: int = 85) -> Image.Image:
    """Round-trip *pil_img* through in-memory JPEG encoding (a TTA view).

    The re-decoded image carries JPEG compression artifacts for the given
    ``quality`` and is returned in RGB mode.
    """
    encoded = io.BytesIO()
    pil_img.save(encoded, format="JPEG", quality=quality, optimize=True)
    decoded = Image.open(io.BytesIO(encoded.getvalue()))
    return decoded.convert("RGB")

def mild_center_crop_resize(pil_img: Image.Image, ratio: float = 0.92) -> Image.Image:
    """Crop the central ``ratio`` fraction of the image and scale it back to full size (a TTA view)."""
    width, height = pil_img.size
    crop_w = int(width * ratio)
    crop_h = int(height * ratio)
    x0 = (width - crop_w) // 2
    y0 = (height - crop_h) // 2
    window = (x0, y0, x0 + crop_w, y0 + crop_h)
    return pil_img.crop(window).resize(pil_img.size, Image.LANCZOS)

# ===== Image quality evaluation (to reduce false positives on webcam photos) =====
def _np_gray(pil_img: Image.Image) -> np.ndarray:
    """Return the image converted to PIL grayscale ("L" mode) as a NumPy array."""
    grayscale = pil_img.convert("L")
    return np.array(grayscale)

def _conv2d_same_reflect(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Filter a 2-D array with *kernel* using reflect padding ("same" output size).

    Computes ``out[i, j] = sum(padded[i:i+kh, j:j+kw] * kernel)``, i.e. a
    cross-correlation (the kernel is not flipped). For symmetric kernels such
    as the Laplacian this equals convolution. No cv2/scipy dependency.

    Args:
        img: 2-D input array.
        kernel: 2-D filter weights of shape ``(kh, kw)``.

    Returns:
        float32 array with the same shape as *img*.
    """
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(img, ((ph, ph), (pw, pw)), mode="reflect")
    rows, cols = img.shape
    # Accumulate kernel-weighted shifted views of the padded image: kh*kw
    # vectorized NumPy operations instead of one Python iteration per pixel.
    # float64 accumulation keeps the sum accurate before the final cast.
    acc = np.zeros((rows, cols), dtype=np.float64)
    for di in range(kh):
        for dj in range(kw):
            acc += kernel[di, dj] * padded[di:di + rows, dj:dj + cols]
    return acc.astype(np.float32)

def variance_of_laplacian(gray: np.ndarray) -> float:
    """Sharpness score: variance of the 4-neighbour Laplacian response.

    Larger values indicate more high-frequency detail (a sharper image);
    blurry images yield small values.
    """
    laplacian_kernel = np.array(
        [[0, 1, 0],
         [1, -4, 1],
         [0, 1, 0]],
        dtype=np.float32,
    )
    response = _conv2d_same_reflect(gray.astype(np.float32), laplacian_kernel)
    return float(np.var(response))

def jpeg_size_ratio(pil_img: Image.Image, quality: int = 85) -> float:
    """Ratio of the image's JPEG-encoded size (at ``quality``) to its PNG-encoded size.

    Used as a rough compressibility proxy. Returns 1.0 if the PNG encoding is empty.
    """
    def _encoded_size(**save_kwargs) -> int:
        sink = io.BytesIO()
        pil_img.save(sink, **save_kwargs)
        return len(sink.getvalue())

    png_bytes = _encoded_size(format="PNG")
    jpg_bytes = _encoded_size(format="JPEG", quality=quality, optimize=True)
    if png_bytes == 0:
        return 1.0
    return float(jpg_bytes) / float(png_bytes)

def image_quality_metrics(pil_img: Image.Image) -> Dict[str, float]:
    """Compute simple quality statistics for *pil_img*.

    Returns a dict with:
      - "sharp": variance of the Laplacian of the grayscale image (higher = sharper),
      - "bright": mean gray level,
      - "contrast": standard deviation of gray levels,
      - "comp": JPEG(quality 85) / PNG encoded-size ratio.
    """
    gray = _np_gray(pil_img)
    metrics: Dict[str, float] = {}
    metrics["sharp"] = variance_of_laplacian(gray)
    metrics["bright"] = float(gray.mean())
    metrics["contrast"] = float(gray.std())
    metrics["comp"] = jpeg_size_ratio(pil_img, 85)
    return metrics

def quality_bucket(m: Dict[str,float]) -> str:
    """Bucket quality metrics into "poor", "ok" or "good".

    "poor" wins if the image is blurry (sharp < 60), too dark (bright < 40),
    too bright (bright > 215) or flat (contrast < 25). Otherwise "good"
    requires sharp >= 120, brightness within [50, 200] and contrast >= 35.
    Everything else is "ok".
    """
    if m["sharp"] < 60 or m["bright"] < 40 or m["bright"] > 215 or m["contrast"] < 25:
        return "poor"
    if m["sharp"] >= 120 and 50 <= m["bright"] <= 200 and m["contrast"] >= 35:
        return "good"
    return "ok"

def logit(p, eps=1e-6):
    """Log-odds ``log(p / (1 - p))`` with *p* clamped to ``[eps, 1 - eps]`` so the result stays finite."""
    if p < eps:
        p = eps
    elif p > 1 - eps:
        p = 1 - eps
    return float(np.log(p / (1 - p)))
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, returned as a Python float."""
    denominator = 1 + np.exp(-x)
    return float(1 / denominator)

def calibrate_fake_prob(p: float, qlvl: str) -> float:
    """Rescale a fake probability in logit space according to image quality.

    Computes ``sigmoid((logit(p) + b) / t)``. The (t, b) pair depends on the
    quality level: "poor" -> (1.6, -0.4), "good" -> (0.9, +0.1), and any other
    level -> (1.2, -0.1). A temperature t > 1 pulls scores toward 0.5, and a
    negative bias b shifts them toward "real". Low-quality images are
    therefore judged less aggressively.
    """
    temperature, bias = {
        "poor": (1.6, -0.4),
        "good": (0.9, +0.1),
    }.get(qlvl, (1.2, -0.1))
    return sigmoid((logit(p) + bias) / temperature)

# ---- Load models ----
# Runs at import time: builds the face detector and loads every classifier in
# MODEL_IDS onto DEVICE. from_pretrained presumably fetches weights from the
# Hugging Face Hub on first use, so network access may be required.
# keep_all=True asks MTCNN for all detected faces (analyze() iterates over every returned box).
mtcnn = MTCNN(keep_all=True, device=DEVICE)
# List of (model_id, image_processor, model) tuples consumed by classify_images_ensemble.
_models = []
for mid in MODEL_IDS:
    # Prefer the fast image processor; fall back to the default one if it
    # cannot be constructed for this model / installed library versions.
    try: processor = AutoImageProcessor.from_pretrained(mid, use_fast=True)
    except Exception: processor = AutoImageProcessor.from_pretrained(mid)
    # .eval() switches the model to inference mode (e.g. disables dropout).
    clf = AutoModelForImageClassification.from_pretrained(mid).to(DEVICE).eval()
    _models.append((mid, processor, clf))

# ---- Core inference ----
def _align_to_reference_labels(probs: torch.Tensor, id2label: Dict[int, str],
                               ref_id2label: Dict[int, str]) -> torch.Tensor:
    """Reorder one model's class-probability vector into the reference label order.

    Soft voting averages probability vectors element-wise. That is only valid
    if index ``i`` denotes the same class in every model. Labels are matched by
    ``canonical_label`` ("fake"/"real"), or else by their lower-cased, stripped name.

    Args:
        probs: 1-D tensor of class probabilities from one model.
        id2label: That model's ``config.id2label``.
        ref_id2label: ``config.id2label`` of the reference (first) model.

    Returns:
        *probs* itself when the label lists already agree, else a reordered copy.

    Raises:
        ValueError: If the two label sets cannot be matched one-to-one.
    """
    own = [id2label[i] for i in range(probs.shape[-1])]
    ref = [ref_id2label[i] for i in range(len(ref_id2label))]
    if own == ref:
        return probs

    def label_key(name: str) -> str:
        canon = canonical_label(name)
        return canon if canon in ("fake", "real") else name.lower().strip()

    own_keys = [label_key(n) for n in own]
    ref_keys = [label_key(n) for n in ref]
    if len(set(ref_keys)) != len(ref_keys) or sorted(own_keys) != sorted(ref_keys):
        raise ValueError(
            f"Cannot align ensemble labels {own} with reference labels {ref}; "
            "all models in MODEL_IDS must share the same fake/real classes."
        )
    order = [own_keys.index(k) for k in ref_keys]
    return probs[order]


@torch.no_grad()
def classify_images_ensemble(pils: List[Image.Image]) -> Dict[str, Any]:
    """Classify each image with every loaded model and soft-vote the results.

    Each model scores four test-time-augmentation views of every image:
    the original, a horizontal mirror, a JPEG re-encode at quality 85, and a
    92% centre crop. Softmax probabilities are averaged over the views,
    aligned to the first model's label order, and then averaged over models.

    Args:
        pils: RGB PIL images (face crops or full frames).

    Returns:
        ``{"per_image": [...]}`` with one dict per input image, holding:
          - ``"top"``: up to three ``(label, prob)`` pairs, best first;
          - ``"fake_prob"`` / ``"real_prob"``: floats rounded to 4 places, or
            ``None`` if no label maps to that category;
          - ``"entropy"``: entropy of the averaged distribution.

    Raises:
        RuntimeError: If images are given but no classifier is loaded.
        ValueError: If the ensemble members' label sets cannot be aligned.
    """
    per_image_results = []
    if not pils:
        return {"per_image": per_image_results}
    if not _models:
        raise RuntimeError("No classifier models are loaded; check MODEL_IDS.")

    ref_id2label = _models[0][2].config.id2label
    use_amp = (DEVICE == "cuda")

    def tta_views(img):
        return [
            img,
            ImageOps.mirror(img),
            jpeg_recompress(img, 85),
            mild_center_crop_resize(img, 0.92),
        ]

    for img in pils:
        views = tta_views(img)
        per_model_probs = []
        for _mid, processor, clf in _models:
            inputs = processor(images=views, return_tensors="pt")
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            # torch.autocast supersedes the deprecated torch.cuda.amp.autocast;
            # mixed precision is enabled only on CUDA.
            with torch.autocast(device_type=DEVICE, enabled=use_amp):
                logits = clf(**inputs).logits
            # Softmax in float32, since AMP may produce float16 logits.
            view_probs = torch.softmax(logits.float(), dim=-1)
            per_model_probs.append(
                _align_to_reference_labels(view_probs.mean(dim=0), clf.config.id2label, ref_id2label)
            )

        probs_np = torch.stack(per_model_probs).mean(dim=0).cpu().numpy()
        ranked = rank_probs(ref_id2label, probs_np)

        # Collapse the label set onto fake/real, keeping the max probability per category.
        fake_p = real_p = None
        for i, p in enumerate(probs_np):
            p = float(p)
            cat = canonical_label(ref_id2label[i])
            if cat == "fake":
                fake_p = p if fake_p is None else max(fake_p, p)
            elif cat == "real":
                real_p = p if real_p is None else max(real_p, p)

        if fake_p is None and real_p is None:
            # No label maps to fake/real. Fallback: treat the top-ranked label's
            # probability as "real".
            # NOTE(review): this is arbitrary for unknown label sets; confirm intent.
            real_p = ranked[0][1]
            fake_p = 1 - real_p

        per_image_results.append({
            "top": ranked[:3],
            "fake_prob": None if fake_p is None else round(fake_p, 4),
            "real_prob": None if real_p is None else round(real_p, 4),
            "entropy": round(entropy(probs_np), 4),
        })
    return {"per_image": per_image_results}

def _fuse_fake_scores(face_scores: List[float], full_score) -> float:
    """Blend per-face and whole-image fake probabilities into one overall score.

    Args:
        face_scores: Calibrated fake probabilities of the kept faces (may be empty).
        full_score: Calibrated fake probability of the whole image, or None.

    Returns:
        0.5 when no score is available. Otherwise, when both sources exist, an
        80/20 face/full-image blend; when only one exists, that score. If any
        faces exist, the result is finally pulled 30% toward the faces' upper
        quartile.
    """
    if not face_scores and full_score is None:
        return 0.5
    faces_robust = float(np.median(face_scores)) if face_scores else None
    # Compare against None explicitly: a calibrated score of exactly 0.0
    # (a confidently real face) is falsy and must not count as "missing".
    if faces_robust is not None and full_score is not None:
        fused = 0.8 * faces_robust + 0.2 * full_score
    elif faces_robust is not None:
        fused = faces_robust
    else:
        fused = full_score
    if face_scores:
        # Upper-quartile pull so one clearly fake face among several real ones
        # still raises the overall score.
        q3 = float(np.quantile(face_scores, 0.75))
        fused = 0.7 * fused + 0.3 * q3
    return float(fused)


def _verdict_for(overall_fake: float, q_level: str) -> str:
    """Map a fused fake probability to a label using quality-dependent thresholds."""
    # Lower image quality demands a higher score before calling an image fake.
    if q_level == "poor":
        th_fake, th_unc = 0.85, 0.65
    elif q_level == "good":
        th_fake, th_unc = 0.70, 0.55
    else:
        th_fake, th_unc = 0.75, 0.60
    if overall_fake >= th_fake:
        return "Likely Deepfake"
    if overall_fake >= th_unc:
        return "Uncertain"
    return "Likely Real"


def analyze(img: Image.Image) -> Dict[str, Any]:
    """Estimate whether *img* is a deepfake from its face crops and the whole image.

    Steps:
      1. Downscale the image to MAX_SIDE and score its quality.
      2. Detect faces with MTCNN and keep confident, large-enough square crops.
      3. Classify every kept face and the full image with the model ensemble.
      4. Calibrate each fake probability by its own quality level (crop
         quality for faces, image quality for the full frame).
      5. Fuse the scores, weighting faces above the full image, and apply
         quality-dependent thresholds.

    Returns:
        ``{"label", "overall_fake_probability", "faces_detected"}``, or
        ``{"error": ...}`` when no image is given.
    """
    if img is None:
        return {"error": "No image uploaded."}

    img = resize_keep_ratio(img.convert("RGB"), MAX_SIDE)
    q_metrics = image_quality_metrics(img)
    q_level = quality_bucket(q_metrics)

    # Face detection & filtering
    boxes, probs = mtcnn.detect(img)
    crops, crop_boxes, kept_scores = [], [], []
    if boxes is not None and probs is not None:
        W, H = img.size
        for b, sc in zip(boxes, probs):
            if sc is None or sc < DETECT_THRESH:
                continue
            x1s, y1s, x2s, y2s = to_square_with_margin(b, W, H, FACE_MARGIN)
            if x2s <= x1s or y2s <= y1s:
                continue
            face = img.crop((x1s, y1s, x2s, y2s))
            if min(face.size) < FACE_MIN_SIZE:
                continue
            crops.append(face)
            crop_boxes.append((x1s, y1s, x2s, y2s))
            kept_scores.append(float(sc))

    # Face inference; each face is calibrated with its own crop quality.
    per_face_results = []
    if crops:
        preds = classify_images_ensemble(crops)["per_image"]
        for i, (face, pred, box, sc) in enumerate(zip(crops, preds, crop_boxes, kept_scores), 1):
            fm = image_quality_metrics(face)
            fl = quality_bucket(fm)
            fp = pred.get("fake_prob")
            if fp is not None:
                pred["fake_prob_raw"] = fp
                pred["fake_prob"] = round(calibrate_fake_prob(fp, fl), 4)
            pred["quality"] = {"level": fl, **fm}
            per_face_results.append({
                "face_index": i,
                "box": {"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]},
                "det_score": round(sc, 4),
                **pred,
            })

    # Full-image inference (weak expert)
    full_pred = classify_images_ensemble([img])["per_image"][0]
    if full_pred.get("fake_prob") is not None:
        full_pred["fake_prob_raw"] = full_pred["fake_prob"]
        full_pred["fake_prob"] = round(calibrate_fake_prob(full_pred["fake_prob"], q_level), 4)
    full_pred["quality"] = {"level": q_level, **q_metrics}

    # Score fusion (faces > full image) and quality-aware decision
    face_scores = [r["fake_prob"] for r in per_face_results if r.get("fake_prob") is not None]
    overall_fake = _fuse_fake_scores(face_scores, full_pred.get("fake_prob"))
    label = _verdict_for(overall_fake, q_level)

    return {
        "label": label,
        "overall_fake_probability": round(overall_fake, 4),
        "faces_detected": len(per_face_results),
    }

def visualize_faces(img: Image.Image):
    """Return 160x160 previews of the face crops detected in *img*.

    A detection is kept when its confidence is at least DETECT_THRESH. Its
    square crop (with FACE_MARGIN) must be non-empty, and the crop's shorter
    side must be at least FACE_MIN_SIZE pixels. Returns an empty list for no
    image or no detections.
    """
    if img is None:
        return []
    frame = resize_keep_ratio(img.convert("RGB"), MAX_SIDE)
    boxes, scores = mtcnn.detect(frame)
    if boxes is None or scores is None:
        return []
    width, height = frame.size
    previews = []
    for box, score in zip(boxes, scores):
        if score is None or score < DETECT_THRESH:
            continue
        left, top, right, bottom = to_square_with_margin(box, width, height, FACE_MARGIN)
        if right <= left or bottom <= top:
            continue
        face = frame.crop((left, top, right, bottom))
        if min(face.size) < FACE_MIN_SIZE:
            continue
        previews.append(face.resize((160, 160), Image.LANCZOS))
    return previews

# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🕵️ FakeSpotter — Deepfake Image Detector  
    **Ensemble + TTA + Quality Awareness**
    - Multi-model ensemble  
    - Face-focused detection + fallback to whole image  
    - Image quality guard to reduce false positives (e.g., webcam noise)  
    > Educational, not forensic-grade.
    """)

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            thumbs = gr.Gallery(label="Detected Face Crops (preview)", columns=6, height="auto")
            out = gr.JSON(label="Results")

    btn = gr.Button("Analyze")
    # Two independent handlers on the same click: one fills the face preview
    # gallery, the other the JSON verdict. Each runs its own face detection.
    btn.click(visualize_faces, inputs=inp, outputs=thumbs)
    btn.click(analyze, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()