Venkata Nagasai Kesani committed on
Commit
b5ee9ec
·
1 Parent(s): 4f89adf

Improve prediction logic and fix false positives

Browse files
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -1,30 +1,49 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
 
5
  MODEL_REPO = "fusingAIandSec/malicious-url-detector"
6
-
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
8
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
 
9
  labels = ["benign", "defacement", "phishing", "malware"]
10
 
 
 
 
 
 
 
 
 
11
  def predict_url(url):
 
12
  inputs = tokenizer(url, return_tensors="pt", truncation=True, padding=True)
13
  with torch.no_grad():
14
  outputs = model(**inputs)
15
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
16
- predicted_label = labels[probs.argmax().item()]
17
- prob_values = {labels[i]: round(float(probs[0][i]), 4) for i in range(len(labels))}
18
- return f"🧠 **Prediction:** {predicted_label}\n\n**Confidence:** {prob_values}"
 
 
 
 
 
 
 
 
 
19
 
 
20
  demo = gr.Interface(
21
  fn=predict_url,
22
- inputs=gr.Textbox(label="Enter a URL"),
23
- outputs="markdown",
24
  title="🔍 Malicious URL Detector",
25
- description="Classifies URLs as benign, defacement, phishing, or malware.",
26
- theme="default"
27
  )
28
 
29
- if __name__ == "__main__":
30
- demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ import re
5
 
6
# Repository identifier passed to both ``from_pretrained`` calls below, so the
# tokenizer and the classifier always come from the same place.
MODEL_REPO = "fusingAIandSec/malicious-url-detector"

# Loaded once at import time and reused by every call to predict_url().
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)

# Class names, looked up by the index of the model's output score.
labels = ["benign", "defacement", "phishing", "malware"]
12
 
13
# URL normalization helper
def normalize_url(url: str) -> str:
    """Return *url* trimmed, prefixed with a scheme, and lower-cased.

    Surrounding whitespace is stripped. If the text does not already start
    with ``http://`` or ``https://`` (in any letter case), ``https://`` is
    prepended. The whole result is then lower-cased.

    Args:
        url: Raw URL text as typed by the user.

    Returns:
        The normalized URL string.
    """
    url = url.strip()
    # Test the scheme case-insensitively. A case-sensitive check would treat
    # "HTTP://a.com" as scheme-less and produce "https://http://a.com".
    if not re.match(r"^https?://", url, flags=re.IGNORECASE):
        url = "https://" + url
    return url.lower()
19
+
20
# Minimum top-class probability required before a non-benign verdict is
# reported; anything less certain is downgraded to "benign".
MALICIOUS_CONFIDENCE_THRESHOLD = 0.85


# Prediction function
def predict_url(url):
    """Classify *url* and return a ``(prediction, confidence)`` pair of strings.

    The URL is normalized, tokenized and scored by the module-level model.
    The highest-scoring label is reported unless it is a non-benign label
    whose probability is below ``MALICIOUS_CONFIDENCE_THRESHOLD``, in which
    case the reported label is "benign". The confidence string always shows
    the unmodified per-label probabilities, rounded to four decimals.

    Args:
        url: URL text to classify. ``None``, empty, or whitespace-only input
            is rejected with a prompt instead of being scored.

    Returns:
        A tuple of two strings: the prediction line and the confidence line,
        one for each of the interface's two text outputs.
    """
    # Nothing to classify. Without this guard a blank input would be
    # normalized to the bare scheme "https://" and scored like a real URL.
    if not url or not url.strip():
        return "Please enter a URL.", ""

    url = normalize_url(url)
    inputs = tokenizer(url, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Only one URL is scored per call, so take row 0 of the softmax output.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

    # Convert to readable dictionary
    confidence = {label: round(prob, 4) for label, prob in zip(labels, probs)}
    pred_idx = torch.argmax(outputs.logits, dim=-1).item()
    pred_label = labels[pred_idx]
    max_prob = max(probs)

    # Apply the threshold to reduce false phishing/defacement/malware
    # verdicts: a low-confidence malicious label is reported as benign.
    if (
        pred_label in ("phishing", "defacement", "malware")
        and max_prob < MALICIOUS_CONFIDENCE_THRESHOLD
    ):
        pred_label = "benign"

    return f"🧠 Prediction: {pred_label}", f"Confidence: {confidence}"
39
 
40
# Gradio interface: one URL text box in, two text fields out (the prediction
# line and the confidence line returned by predict_url).
demo = gr.Interface(
    fn=predict_url,
    inputs=gr.Textbox(label="Enter a URL", placeholder="https://example.com"),
    outputs=["text", "text"],
    title="🔍 Malicious URL Detector",
    description="Classifies URLs as benign, defacement, phishing, or malware. Now with smart confidence logic!",
)

# Start the server only when this file is run as a script, so that importing
# the module (e.g. to reuse predict_url) does not launch the web app.
if __name__ == "__main__":
    demo.launch()