vaniv commited on
Commit
7ff83e8
·
verified ·
1 Parent(s): 3597267

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -62
app.py CHANGED
@@ -1,102 +1,113 @@
1
-
2
  import os
3
  import typing as t
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import tensorflow as tf
 
 
 
8
  from PIL import Image
9
 
10
- # Try to load a user-provided Keras model if available; else fallback to a pretrained MobileNetV2.
11
- CUSTOM_MODEL_PATH = "model.h5"
12
- LABELS_PATH = "labels.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Global objects
15
- MODEL = None
16
- USE_IMAGENET_DECODE = False
17
- CLASS_NAMES: t.Optional[t.List[str]] = None
18
- TARGET_SIZE = (224, 224)
19
 
20
- def _maybe_load_labels(path: str) -> t.Optional[t.List[str]]:
21
- if os.path.exists(path):
22
- with open(path, "r", encoding="utf-8") as f:
23
- lines = [x.strip() for x in f.readlines() if x.strip()]
24
- return lines
25
- return None
 
 
 
 
 
 
 
 
26
 
27
  def _load_model():
28
- global MODEL, USE_IMAGENET_DECODE, CLASS_NAMES
 
 
29
  if os.path.exists(CUSTOM_MODEL_PATH):
30
  try:
31
  MODEL = tf.keras.models.load_model(CUSTOM_MODEL_PATH, compile=False)
32
- CLASS_NAMES = _maybe_load_labels(LABELS_PATH)
33
- USE_IMAGENET_DECODE = False
34
  print("Loaded custom model from model.h5")
35
  return
36
  except Exception as e:
37
- print("Failed to load custom model:", e)
 
 
 
 
 
 
 
 
38
 
39
- # Fallback: MobileNetV2 pretrained on ImageNet
40
- MODEL = tf.keras.applications.MobileNetV2(weights="imagenet")
41
- USE_IMAGENET_DECODE = True
42
- CLASS_NAMES = None
43
- print("Loaded MobileNetV2 (ImageNet) fallback.")
44
 
45
  def _preprocess(img: Image.Image) -> np.ndarray:
46
- # Convert to RGB and resize
47
  img = img.convert("RGB").resize(TARGET_SIZE)
48
- arr = np.array(img).astype("float32")
49
- # If it's the MobileNetV2 fallback, apply its preprocess; otherwise just scale 0..1
50
- if USE_IMAGENET_DECODE:
51
- arr = tf.keras.applications.mobilenet_v2.preprocess_input(arr)
52
- else:
53
- arr = arr / 255.0
54
- arr = np.expand_dims(arr, axis=0)
55
- return arr
56
-
57
- def _decode_predictions(preds: np.ndarray, top: int = 3):
58
- # preds: (1, num_classes)
59
- preds = preds[0]
60
- if USE_IMAGENET_DECODE:
61
- decoded = tf.keras.applications.imagenet_utils.decode_predictions(preds[np.newaxis, :], top=top)[0]
62
- # decoded is list of tuples: (class_id, class_name, score)
63
- return [(name, float(score)) for (_, name, score) in decoded]
64
- else:
65
- # For custom model: if CLASS_NAMES provided, map; else show class indices
66
- top_indices = preds.argsort()[-top:][::-1]
67
- out = []
68
- for idx in top_indices:
69
- label = CLASS_NAMES[idx] if (CLASS_NAMES is not None and idx < len(CLASS_NAMES)) else f"class_{idx}"
70
- out.append((label, float(preds[idx])))
71
- return out
72
 
73
  def predict(image: Image.Image):
74
  if image is None:
75
- return [], None
76
  x = _preprocess(image)
77
- preds = MODEL.predict(x)
78
- top3 = _decode_predictions(preds, top=3)
79
- # Also return a bar plot-friendly structure for Gradio's Label component
80
- scores = {label: score for (label, score) in top3}
81
- return scores, image
 
82
 
83
- # Initialize
84
  _load_model()
85
 
86
- with gr.Blocks(title="Image Classifier (Keras/TF)") as demo:
87
- gr.Markdown("# Image Classifier\nUpload an image to classify using a Keras model.\n\n"
88
- "- Drop in your own `model.h5` (and optional `labels.txt`) to switch from ImageNet to your custom model.\n"
89
- "- For custom models, ensure input size is 224x224x3 or adjust code.\n")
90
 
91
  with gr.Row():
92
  with gr.Column(scale=1):
93
  inp = gr.Image(type="pil", label="Upload image")
94
  btn = gr.Button("Predict")
95
  with gr.Column(scale=1):
96
- out_label = gr.Label(num_top_classes=3, label="Top Predictions")
97
  out_img = gr.Image(type="pil", label="Preview")
 
98
 
99
- btn.click(fn=predict, inputs=inp, outputs=[out_label, out_img])
100
 
101
  if __name__ == "__main__":
102
  demo.launch()
 
 
1
  import os
2
  import typing as t
3
 
4
  import gradio as gr
5
  import numpy as np
6
  import tensorflow as tf
7
+ from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
8
+ MaxPooling2D, Flatten, Dropout, Dense, LeakyReLU)
9
+ from tensorflow.keras.models import Model
10
  from PIL import Image
11
 
12
+ # Paths
13
+ CUSTOM_MODEL_PATH = "model.h5" # optional: full Keras model
14
+ MESO_WEIGHTS_PATH = "weights/Meso4_DF" # your weights-only file
15
+ LABELS = ["real", "fake"] # index 0..1 (we'll compute both scores)
16
+
17
+ # Globals
18
+ MODEL: t.Optional[tf.keras.Model] = None
19
+ IS_MESO = False
20
+ TARGET_SIZE = (256, 256) # your notebook used 256×256
21
+ THRESHOLD = 0.5 # sigmoid > 0.5 => fake
22
+
23
+ def build_meso4() -> tf.keras.Model:
24
+ x = Input(shape=(TARGET_SIZE[0], TARGET_SIZE[1], 3))
25
+ x1 = Conv2D(8, (3, 3), padding='same', activation='relu')(x)
26
+ x1 = BatchNormalization()(x1)
27
+ x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)
28
+
29
+ x2 = Conv2D(8, (5, 5), padding='same', activation='relu')(x1)
30
+ x2 = BatchNormalization()(x2)
31
+ x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)
32
 
33
+ x3 = Conv2D(16, (5, 5), padding='same', activation='relu')(x2)
34
+ x3 = BatchNormalization()(x3)
35
+ x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)
 
 
36
 
37
+ x4 = Conv2D(16, (5, 5), padding='same', activation='relu')(x3)
38
+ x4 = BatchNormalization()(x4)
39
+ x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)
40
+
41
+ y = Flatten()(x4)
42
+ y = Dropout(0.5)(y)
43
+ y = Dense(16)(y)
44
+ y = LeakyReLU(alpha=0.1)(y)
45
+ y = Dropout(0.5)(y)
46
+ y = Dense(1, activation='sigmoid')(y)
47
+
48
+ model = Model(inputs=x, outputs=y)
49
+ model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
50
+ return model
51
 
52
  def _load_model():
53
+ """Load a full Keras model if present; otherwise build Meso4 and load weights."""
54
+ global MODEL, IS_MESO
55
+ # 1) Full model (optional)
56
  if os.path.exists(CUSTOM_MODEL_PATH):
57
  try:
58
  MODEL = tf.keras.models.load_model(CUSTOM_MODEL_PATH, compile=False)
59
+ IS_MESO = False
 
60
  print("Loaded custom model from model.h5")
61
  return
62
  except Exception as e:
63
+ print("Failed to load model.h5:", e)
64
+
65
+ # 2) Meso4 + weights (your case)
66
+ if os.path.exists(MESO_WEIGHTS_PATH):
67
+ MODEL = build_meso4()
68
+ MODEL.load_weights(MESO_WEIGHTS_PATH)
69
+ IS_MESO = True
70
+ print("Loaded Meso4 with weights:", MESO_WEIGHTS_PATH)
71
+ return
72
 
73
+ # 3) Hard fail (don’t silently switch to ImageNet; this is a deepfake app)
74
+ raise RuntimeError(
75
+ "No model found. Upload either model.h5 or weights/Meso4_DF to the Space."
76
+ )
 
77
 
78
  def _preprocess(img: Image.Image) -> np.ndarray:
 
79
  img = img.convert("RGB").resize(TARGET_SIZE)
80
+ arr = np.array(img).astype("float32") / 255.0
81
+ return np.expand_dims(arr, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def predict(image: Image.Image):
84
  if image is None:
85
+ return {"real": 0.0, "fake": 0.0}, None, "Upload an image."
86
  x = _preprocess(image)
87
+ prob_fake = float(MODEL.predict(x, verbose=0)[0][0])
88
+ prob_real = 1.0 - prob_fake
89
+ label = "fake" if prob_fake >= THRESHOLD else "real"
90
+ msg = f"Prediction: {label.upper()} | fake={prob_fake:.2f}, real={prob_real:.2f}"
91
+ # Return both scores for the Label component
92
+ return {"real": prob_real, "fake": prob_fake}, image, msg
93
 
94
+ # Init
95
  _load_model()
96
 
97
+ with gr.Blocks(title="Deepfake Detector (Meso4)") as demo:
98
+ gr.Markdown("# Deepfake Detector (Meso4)\n"
99
+ "Upload a face image (or a frame from a video). The model outputs real vs fake.")
 
100
 
101
  with gr.Row():
102
  with gr.Column(scale=1):
103
  inp = gr.Image(type="pil", label="Upload image")
104
  btn = gr.Button("Predict")
105
  with gr.Column(scale=1):
106
+ out_label = gr.Label(num_top_classes=2, label="Scores")
107
  out_img = gr.Image(type="pil", label="Preview")
108
+ out_text = gr.Markdown()
109
 
110
+ btn.click(fn=predict, inputs=inp, outputs=[out_label, out_img, out_text])
111
 
112
  if __name__ == "__main__":
113
  demo.launch()