feat: add simple prediction model with preprocessing and postprocessing functions
Browse files- app_test.py +31 -9
app_test.py
CHANGED
|
@@ -166,7 +166,38 @@ register_model_with_metadata(
|
|
| 166 |
display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
|
| 167 |
)
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
| 171 |
"""Predict using a specific model.
|
| 172 |
|
|
@@ -403,15 +434,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
|
|
| 403 |
|
| 404 |
return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
|
| 405 |
|
| 406 |
-
def simple_prediction(img):
|
| 407 |
-
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
| 408 |
-
result = client.predict(
|
| 409 |
-
input_image=handle_file(img),
|
| 410 |
-
api_name="/simple_predict"
|
| 411 |
-
)
|
| 412 |
-
return result
|
| 413 |
-
|
| 414 |
-
|
| 415 |
detection_model_eval_playground = gr.Interface(
|
| 416 |
fn=ensemble_prediction,
|
| 417 |
inputs=[
|
|
|
|
| 166 |
display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
|
| 167 |
)
|
| 168 |
|
| 169 |
+
def preprocess_simple_prediction(image):
|
| 170 |
+
# The simple_prediction function expects a PIL image (filepath is handled internally)
|
| 171 |
+
return image
|
| 172 |
+
|
| 173 |
+
def postprocess_simple_prediction(result, class_names):
|
| 174 |
+
scores = {name: 0.0 for name in class_names}
|
| 175 |
+
fake_prob = result.get("Fake Probability")
|
| 176 |
+
if fake_prob is not None:
|
| 177 |
+
# Assume class_names = ["AI", "REAL"]
|
| 178 |
+
scores["AI"] = float(fake_prob)
|
| 179 |
+
scores["REAL"] = 1.0 - float(fake_prob)
|
| 180 |
+
return scores
|
| 181 |
|
| 182 |
+
def simple_prediction(img):
|
| 183 |
+
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
| 184 |
+
result = client.predict(
|
| 185 |
+
input_image=handle_file(img),
|
| 186 |
+
api_name="/simple_predict"
|
| 187 |
+
)
|
| 188 |
+
return result
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
register_model_with_metadata(
|
| 192 |
+
"simple_prediction",
|
| 193 |
+
simple_prediction,
|
| 194 |
+
preprocess_simple_prediction,
|
| 195 |
+
postprocess_simple_prediction,
|
| 196 |
+
["AI", "REAL"],
|
| 197 |
+
display_name="Community Forensics",
|
| 198 |
+
contributor="Jeongsoo Park",
|
| 199 |
+
model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT"
|
| 200 |
+
)
|
| 201 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
| 202 |
"""Predict using a specific model.
|
| 203 |
|
|
|
|
| 434 |
|
| 435 |
return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
|
| 436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
detection_model_eval_playground = gr.Interface(
|
| 438 |
fn=ensemble_prediction,
|
| 439 |
inputs=[
|