vaniv commited on
Commit
3777334
·
verified ·
1 Parent(s): 7ae7475

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -56
app.py CHANGED
@@ -1,19 +1,33 @@
1
  # app.py
2
- import io
3
  import numpy as np
4
  import gradio as gr
5
  from PIL import Image, ImageDraw
6
  import cv2
7
  import torch
8
- from transformers import AutoImageProcessor, ViTForImageClassification
9
- import mediapipe as mp
10
-
11
- # -------------------- Face crop utilities --------------------
12
- _mp_face = mp.solutions.face_detection.FaceDetection(
13
- model_selection=0, min_detection_confidence=0.4
14
- )
 
 
 
 
 
 
 
 
15
 
 
16
  def crop_face(pil_img, pad=0.25):
 
 
 
 
 
 
17
  img = np.array(pil_img.convert("RGB"))
18
  h, w = img.shape[:2]
19
  res = _mp_face.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
@@ -28,10 +42,12 @@ def crop_face(pil_img, pad=0.25):
28
  x1 = int(max(0, (x - pad*bw) * w)); y1 = int(max(0, (y - pad*bh) * h))
29
  x2 = int(min(w, (x + bw + pad*bw) * w)); y2 = int(min(h, (y + bh + pad*bh) * h))
30
  face = Image.fromarray(img[y1:y2, x1:x2])
31
- return face if face.size[0] > 20 and face.size[1] > 20 else pil_img
 
 
32
 
 
33
  def face_oval_mask(img_pil, shrink=0.80):
34
-
35
  w, h = img_pil.size
36
  mask = Image.new("L", (w, h), 0)
37
  draw = ImageDraw.Draw(mask)
@@ -39,69 +55,73 @@ def face_oval_mask(img_pil, shrink=0.80):
39
  draw.ellipse((dx, dy, w - dx, h - dy), fill=255)
40
  return np.array(mask, dtype=np.float32) / 255.0
41
 
42
- # -------------------- HF model: Deepfake vs Realism --------------------
43
- MODEL_ID = "prithivMLmods/Deep-Fake-Detector-v2-Model"
44
-
45
- # CPU by default
46
- _hf_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
47
- _hf_model = ViTForImageClassification.from_pretrained(MODEL_ID)
48
- _hf_model.eval()
49
  torch.set_grad_enabled(False)
50
 
51
- _FAKE_KEYS = ("fake", "deepfake", "manipulated", "spoof", "forged")
 
52
 
53
- def _deepfake_index_from_config(cfg) -> int | None:
54
- """
55
- Try to find the class index for 'Deepfake' from id2label/label2id.
56
- This model typically has {0:'Realism', 1:'Deepfake'}.
57
- """
58
  # Prefer id2label
59
  id2label = getattr(cfg, "id2label", None)
60
  if id2label:
61
- normalized = {int(k): str(v).lower() for k, v in id2label.items()}
 
 
 
 
62
  for idx, lab in normalized.items():
63
  if any(k in lab for k in _FAKE_KEYS):
64
  return idx
65
-
66
- # Fallback to label2id if present
67
  label2id = getattr(cfg, "label2id", None)
68
  if label2id:
69
  inv = {int(v): str(k).lower() for k, v in label2id.items()}
70
  for idx, lab in inv.items():
71
  if any(k in lab for k in _FAKE_KEYS):
72
  return idx
73
-
74
  return None
75
 
76
- _DEEP_IDX = _deepfake_index_from_config(_hf_model.config)
77
 
78
- def _hf_predict_proba(pil_img: Image.Image) -> float:
 
79
  """
80
- Returns P(Deepfake) in [0,1] using the ViT classifier.
 
81
  """
82
- inputs = _hf_processor(images=pil_img.convert("RGB"), return_tensors="pt")
83
- with torch.no_grad():
84
- logits = _hf_model(**inputs).logits # (1, C)
 
 
 
85
 
86
  if logits.shape[-1] == 1:
87
- # Binary sigmoid head
88
  return torch.sigmoid(logits.squeeze(0))[0].item()
89
 
90
- # Softmax head
91
  probs = torch.softmax(logits.squeeze(0), dim=-1).detach().cpu().numpy()
92
- if _DEEP_IDX is not None and 0 <= _DEEP_IDX < probs.shape[0]:
93
- return float(probs[_DEEP_IDX])
94
 
95
- # Binary fallback
 
 
 
 
96
  if probs.shape[0] == 2:
97
- return float(probs[1])
98
 
99
- # Last resort: take max
100
  return float(probs.max())
101
 
102
- # -------------------- Output card --------------------
103
- def _result_card(label: str, conf: float) -> str:
104
- pct = max(0.0, min(1.0, conf)) * 100.0
 
105
  color = "#d84a4a" if label.startswith("Likely Manipulated") else "#2e7d32"
106
  bar_bg = "#e9ecef"
107
  return f"""
@@ -115,24 +135,21 @@ def _result_card(label: str, conf: float) -> str:
115
  <div style="width:100%;height:10px;background:{bar_bg};border-radius:999px;overflow:hidden;">
116
  <div style="height:100%;width:{pct:.4f}%;background:{color};"></div>
117
  </div>
 
 
 
118
  </div>
119
  </div>
120
  """
121
 
122
- # -------------------- Gradio handler --------------------
123
  def analyze(pil_img: Image.Image):
124
  if pil_img is None:
125
- return _result_card("Likely Authentic", 0.0)
126
-
127
- # Focus on the face to reduce background false positives
128
- face = crop_face(pil_img)
129
- face = face.convert("RGB").resize((224, 224)) # ViT expects 224x224
130
-
131
- p_fake = _hf_predict_proba(face)
132
- label = "Likely Manipulated" if p_fake >= 0.65 else "Likely Authentic"
133
- return _result_card(label, p_fake)
134
 
135
- # -------------------- UI --------------------
136
  CUSTOM_CSS = """
137
  .gradio-container {max-width: 980px !important;}
138
  .sleek-card {
@@ -141,9 +158,10 @@ CUSTOM_CSS = """
141
  }
142
  """
143
 
144
- with gr.Blocks(title="Deepfake Detector (ViT)", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
145
  gr.Markdown(
146
- "<h2 style='text-align:center;margin-bottom:6px;'>Deepfake Detector (ViT)</h2>"
 
147
  )
148
  with gr.Row():
149
  with gr.Column(scale=6, elem_classes=["sleek-card"]):
 
1
  # app.py
 
2
  import numpy as np
3
  import gradio as gr
4
  from PIL import Image, ImageDraw
5
  import cv2
6
  import torch
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+
9
+ # ---- Config ----
10
+ MODEL_ID = "SadraCoding/SDXL-Deepfake-Detector"
11
+ THRESHOLD = 0.65 # >= -> "Likely Manipulated"
12
+ IMAGE_SIZE = 224 # ViT input size
13
+
14
+ # Optional: MediaPipe face detection (app still works if not installed)
15
+ try:
16
+ import mediapipe as mp
17
+ _mp_face = mp.solutions.face_detection.FaceDetection(
18
+ model_selection=0, min_detection_confidence=0.4
19
+ )
20
+ except Exception:
21
+ _mp_face = None
22
 
23
+ # ---- Face crop ----
24
  def crop_face(pil_img, pad=0.25):
25
+ """
26
+ Crop the most prominent face using MediaPipe. If MP missing or no face found,
27
+ return the original image.
28
+ """
29
+ if _mp_face is None:
30
+ return pil_img
31
  img = np.array(pil_img.convert("RGB"))
32
  h, w = img.shape[:2]
33
  res = _mp_face.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
 
42
  x1 = int(max(0, (x - pad*bw) * w)); y1 = int(max(0, (y - pad*bh) * h))
43
  x2 = int(min(w, (x + bw + pad*bw) * w)); y2 = int(min(h, (y + bh + pad*bh) * h))
44
  face = Image.fromarray(img[y1:y2, x1:x2])
45
+ if face.size[0] < 20 or face.size[1] < 20:
46
+ return pil_img
47
+ return face
48
 
49
+ # (Not used for inference; kept if you want to mask background later)
50
  def face_oval_mask(img_pil, shrink=0.80):
 
51
  w, h = img_pil.size
52
  mask = Image.new("L", (w, h), 0)
53
  draw = ImageDraw.Draw(mask)
 
55
  draw.ellipse((dx, dy, w - dx, h - dy), fill=255)
56
  return np.array(mask, dtype=np.float32) / 255.0
57
 
58
+ # ---- HF model load ----
59
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
60
+ model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
61
+ model.eval()
 
 
 
62
  torch.set_grad_enabled(False)
63
 
64
+ # Resolve which index corresponds to "fake"
65
+ _FAKE_KEYS = ("artificial", "fake", "deepfake", "manipulated", "spoof", "forged")
66
 
67
+ def _fake_index_from_config(cfg) -> int | None:
 
 
 
 
68
  # Prefer id2label
69
  id2label = getattr(cfg, "id2label", None)
70
  if id2label:
71
+ try:
72
+ normalized = {int(k): str(v).lower() for k, v in id2label.items()}
73
+ except Exception:
74
+ # sometimes keys already ints
75
+ normalized = {int(k): str(v).lower() for k, v in id2label.items()}
76
  for idx, lab in normalized.items():
77
  if any(k in lab for k in _FAKE_KEYS):
78
  return idx
79
+ # Fallback: invert label2id
 
80
  label2id = getattr(cfg, "label2id", None)
81
  if label2id:
82
  inv = {int(v): str(k).lower() for k, v in label2id.items()}
83
  for idx, lab in inv.items():
84
  if any(k in lab for k in _FAKE_KEYS):
85
  return idx
 
86
  return None
87
 
88
+ _FAKE_IDX = _fake_index_from_config(model.config)
89
 
90
+ # ---- Inference ----
91
+ def predict_fake_prob(pil_img: Image.Image) -> float:
92
  """
93
+ Returns P(fake) in [0,1].
94
+ Model labels per card: 0 -> 'artificial' (fake), 1 -> 'human' (real).
95
  """
96
+ # Face-focus to reduce background bias
97
+ face = crop_face(pil_img)
98
+ face = face.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
99
+
100
+ inputs = processor(images=face, return_tensors="pt")
101
+ logits = model(**inputs).logits # (1, C)
102
 
103
  if logits.shape[-1] == 1:
104
+ # Binary sigmoid head (unlikely for this model, but safe)
105
  return torch.sigmoid(logits.squeeze(0))[0].item()
106
 
107
+ # Softmax multi-class (expected 2 classes)
108
  probs = torch.softmax(logits.squeeze(0), dim=-1).detach().cpu().numpy()
 
 
109
 
110
+ # Use explicit mapping if available
111
+ if _FAKE_IDX is not None and 0 <= _FAKE_IDX < probs.shape[0]:
112
+ return float(probs[_FAKE_IDX])
113
+
114
+ # Known mapping from the model card: 0=artificial (fake), 1=human
115
  if probs.shape[0] == 2:
116
+ return float(probs[0]) # class-0 is fake
117
 
118
+ # Last resort
119
  return float(probs.max())
120
 
121
+ # ---- UI helpers ----
122
+ def result_card(prob_fake: float) -> str:
123
+ label = "Likely Manipulated" if prob_fake >= THRESHOLD else "Likely Authentic"
124
+ pct = prob_fake * 100.0
125
  color = "#d84a4a" if label.startswith("Likely Manipulated") else "#2e7d32"
126
  bar_bg = "#e9ecef"
127
  return f"""
 
135
  <div style="width:100%;height:10px;background:{bar_bg};border-radius:999px;overflow:hidden;">
136
  <div style="height:100%;width:{pct:.4f}%;background:{color};"></div>
137
  </div>
138
+ <div style="font-size:12px;color:#6b7280;margin-top:8px;">
139
+ Model: {MODEL_ID} · Threshold: {int(THRESHOLD*100)}%
140
+ </div>
141
  </div>
142
  </div>
143
  """
144
 
145
+ # ---- Gradio handlers ----
146
  def analyze(pil_img: Image.Image):
147
  if pil_img is None:
148
+ return result_card(0.0)
149
+ p_fake = predict_fake_prob(pil_img)
150
+ return result_card(p_fake)
 
 
 
 
 
 
151
 
152
+ # ---- UI ----
153
  CUSTOM_CSS = """
154
  .gradio-container {max-width: 980px !important;}
155
  .sleek-card {
 
158
  }
159
  """
160
 
161
+ with gr.Blocks(title="Deepfake Detector (SDXL ViT)", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
162
  gr.Markdown(
163
+ "<h2 style='text-align:center;margin-bottom:6px;'>Deepfake Detector (SDXL ViT)</h2>"
164
+ "<p style='text-align:center;color:#6b7280;'>MediaPipe face-crop + Vision Transformer fine-tuned for artificial vs human faces.</p>"
165
  )
166
  with gr.Row():
167
  with gr.Column(scale=6, elem_classes=["sleek-card"]):