AiAnber committed on
Commit
0d7fb76
·
verified ·
1 Parent(s): b745e60

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +255 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,258 @@
1
- import altair as alt
2
- import numpy as np
 
3
  import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import time
3
+ from datetime import datetime
4
  import pandas as pd
5
  import streamlit as st
6
+ from pathlib import Path
7
+ from typing import Dict, List, Tuple
8
+ import langdetect
9
+
10
+ # Optional ML imports
11
+ try:
12
+ from transformers import pipeline, Pipeline
13
+ HF_AVAILABLE = True
14
+ except Exception:
15
+ HF_AVAILABLE = False
16
+
17
+ from pydub import AudioSegment
18
+ import altair as alt
19
+
20
+ # -----------------------------------------------------------
21
+ # CONFIGURATION
22
+ # -----------------------------------------------------------
23
+ st.set_page_config(
24
+ page_title="🕵🏻Speech Threat Detection Dashboard",
25
+ layout="wide",
26
+ initial_sidebar_state="expanded",
27
+ )
28
+
29
+ # Styling header
30
+ st.markdown("""
31
+ <style>
32
+ .big-font { font-size:32px; font-weight:700; }
33
+ .muted { color: #9AA0A6; }
34
+ .card { background: linear-gradient(135deg, rgba(10,25,47,0.95), rgba(23,43,77,0.95)); padding: 18px; border-radius: 12px; color: white; box-shadow: 0 6px 30px rgba(8,10,20,0.45); }
35
+ </style>
36
+ """, unsafe_allow_html=True)
37
+
38
+ st.markdown('<div class="card"><span class="big-font">Speech Threat Detection Dashboard</span> <span class="muted"> — upload audio or paste text </span></div>', unsafe_allow_html=True)
39
+
40
+ UPLOAD_DIR = Path("uploads")
41
+ DB_CSV = Path("db.csv")
42
+ UPLOAD_DIR.mkdir(exist_ok=True)
43
+ if not DB_CSV.exists():
44
+ pd.DataFrame(columns=["timestamp","filename","mode","transcription","predicted_label","scores"]).to_csv(DB_CSV, index=False)
45
+
46
+ LABELS = [
47
+ "physical threat",
48
+ "cyber threat",
49
+ "hate speech",
50
+ "political extremist threat",
51
+ "neutral"
52
+ ]
53
+
54
+ LABEL_MAP = {
55
+ "LABEL_0": "hate speech",
56
+ "LABEL_1": "self-harm",
57
+ "LABEL_2": "cyber threat",
58
+ "LABEL_3": "neutral / daily life",
59
+ "LABEL_4": "physical threat",
60
+ "LABEL_5": "political extremist threat"
61
+ }
62
+
63
+ # -----------------------------------------------------------
64
+ # HELPER FUNCTIONS
65
+ # -----------------------------------------------------------
66
+ def save_audio_file(uploaded_file) -> Path:
67
+ filename = f"{int(time.time())}_{uploaded_file.name}"
68
+ out_path = UPLOAD_DIR / filename
69
+ with open(out_path, "wb") as f:
70
+ f.write(uploaded_file.read())
71
+ return out_path
72
+
73
+ def normalize_audio_to_wav(path: Path) -> Path:
74
+ sound = AudioSegment.from_file(path)
75
+ sound = sound.set_frame_rate(16000).set_channels(1).set_sample_width(2)
76
+ wav_path = path.with_suffix(".wav")
77
+ sound.export(wav_path, format="wav")
78
+ return wav_path
79
+
80
+ @st.cache_resource(show_spinner=False)
81
+ def get_asr_pipeline() -> Tuple[str, "Pipeline"]:
82
+ """Load Hugging Face Whisper ASR model"""
83
+ asr = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")
84
+ return ("hf", asr)
85
+
86
+ def hf_transcribe_with_pipeline(asr_pipeline_tuple, path: Path, lang_choice: str = "Auto") -> str:
87
+ """Transcribe audio with Whisper and restrict to English/Urdu"""
88
+ asr_pipeline = asr_pipeline_tuple[1] # Extract the pipeline from the tuple
89
+
90
+ lang_token = None
91
+ if lang_choice == "English only":
92
+ lang_token = "<|en|>"
93
+ elif lang_choice == "Urdu only":
94
+ lang_token = "<|ur|>"
95
+
96
+ kwargs = {"generate_kwargs": {"language": lang_token}} if lang_token else {}
97
+ output = asr_pipeline(str(path), **kwargs)
98
+ text = output["text"].strip() if isinstance(output, dict) else str(output).strip()
99
+
100
+ # Restrict to English or Urdu only
101
+ try:
102
+ detected = langdetect.detect(text)
103
+ if detected not in ["en", "ur"]:
104
+ return "[❌ Unsupported language detected — please use Urdu or English.]"
105
+ except Exception:
106
+ pass
107
+
108
+ return text
109
+
110
+ def classify_text(text: str, classifier: Pipeline, labels: List[str]) -> Dict:
111
+ try:
112
+ result = classifier(text, labels, multi_label=False, hypothesis_template="This text is about {}.")
113
+ labels_out, scores_out = result["labels"], result["scores"]
114
+ top_label = labels_out[scores_out.index(max(scores_out))]
115
+ return {"label": top_label, "scores": dict(zip(labels_out, scores_out))}
116
+ except Exception as e:
117
+ st.error(f"Classification failed: {e}")
118
+ return {"label": "neutral", "scores": {}}
119
+
120
+ def log_to_db(record: Dict):
121
+ df = pd.read_csv(DB_CSV)
122
+ df = pd.concat([df, pd.DataFrame([record])], ignore_index=True)
123
+ df.to_csv(DB_CSV, index=False)
124
+
125
+ # -----------------------------------------------------------
126
+ # SIDEBAR CONFIGURATION
127
+ # -----------------------------------------------------------
128
+ st.sidebar.title("⚙️ Configuration")
129
+
130
+ asr_pipeline_tuple = get_asr_pipeline()
131
+
132
+ asr_language = st.sidebar.selectbox(
133
+ "Transcription Language Restriction",
134
+ ["Auto", "English only", "Urdu only"],
135
+ help="Restrict transcription to English or Urdu only"
136
+ )
137
+
138
+ if asr_pipeline_tuple[0] == "hf":
139
+ st.sidebar.markdown("**ASR Method:** `hf`")
140
+ else:
141
+ st.sidebar.markdown("**ASR Method:** `none` (Hugging Face models not available)")
142
+
143
+ # Only zero-shot models remain
144
+ model_choices = {
145
+ "✔Pretrained-RoBERTa": "roberta-large-mnli",
146
+ "✔Pretrained-MultiClassification": "facebook/bart-large-mnli",
147
+ "✔XLM-R": "joeddav/xlm-roberta-large-xnli"
148
+ }
149
+
150
+ model_display = st.sidebar.selectbox("Choose a Model", list(model_choices.keys()))
151
+ model_path = model_choices[model_display]
152
+
153
+ with st.sidebar:
154
+ st.markdown("---")
155
+ st.markdown("**Active Labels:**")
156
+ for lbl in LABELS:
157
+ st.markdown(f"- {lbl}")
158
+
159
+ # Load classifier
160
+ with st.spinner("Loading model..."):
161
+ try:
162
+ classifier = pipeline("zero-shot-classification", model=model_path)
163
+ except Exception as e:
164
+ st.error(f"Model load failed: {e}")
165
+ st.stop()
166
+
167
+ # -----------------------------------------------------------
168
+ # MAIN INTERFACE
169
+ # -----------------------------------------------------------
170
+ st.markdown("""
171
+ <style>
172
+ .subtitle {color:#999;}
173
+ </style>
174
+ """, unsafe_allow_html=True)
175
+
176
+ st.markdown('<div class="subtitle">Upload or enter text to detect threat categories</div>', unsafe_allow_html=True)
177
+ st.write("")
178
+
179
+ tab1, tab2 = st.tabs(["⏳ Processing", "📊 Analysis"])
180
+
181
+ # -----------------------------------------------------------
182
+ # TAB 1: PROCESSING
183
+ # -----------------------------------------------------------
184
+ with tab1:
185
+ input_mode = st.radio("Input mode", ["Upload audio", "Paste text"])
186
+ transcription_text = ""
187
+ saved_file_path = None
188
+
189
+ if input_mode == "Upload audio":
190
+ uploaded_file = st.file_uploader("Upload audio file", type=["wav","mp3","m4a","flac","ogg"])
191
+ if uploaded_file:
192
+ saved_path = save_audio_file(uploaded_file)
193
+ saved_file_path = saved_path
194
+ st.audio(saved_path)
195
+ wav_path = normalize_audio_to_wav(saved_path)
196
+ st.info("Transcribing audio...")
197
+ try:
198
+ transcription_text = hf_transcribe_with_pipeline(asr_pipeline_tuple, wav_path, asr_language)
199
+ except Exception as e:
200
+ st.error(f"Transcription failed: {e}")
201
+ else:
202
+ transcription_text = st.text_area("Enter or paste text", height=180)
203
+
204
+ if transcription_text:
205
+ st.markdown("### Transcription")
206
+ txt = st.text_area("Editable text", value=transcription_text, height=180)
207
+ if st.button("📝 Classify"):
208
+ with st.spinner("Analyzing text..."):
209
+ result = classify_text(txt, classifier, LABELS)
210
+ record = {
211
+ "timestamp": datetime.utcnow().isoformat(),
212
+ "filename": saved_file_path.name if saved_file_path else "text_input",
213
+ "mode": "audio" if saved_file_path else "text",
214
+ "transcription": txt,
215
+ "predicted_label": result["label"],
216
+ "scores": result["scores"]
217
+ }
218
+ log_to_db(record)
219
+ st.success(f"**Predicted Category:** {result['label']}")
220
+ df_scores = pd.DataFrame(result["scores"].items(), columns=["Label", "Score"]).sort_values("Score", ascending=False)
221
+ st.bar_chart(df_scores.set_index("Label"))
222
+
223
+ # -----------------------------------------------------------
224
+ # TAB 2: ANALYSIS
225
+ # -----------------------------------------------------------
226
+ with tab2:
227
+ st.subheader("📈📊 Analytical Overview")
228
+ if not DB_CSV.exists() or os.path.getsize(DB_CSV) == 0:
229
+ st.info("No data available yet. Run a few classifications first.")
230
+ else:
231
+ df = pd.read_csv(DB_CSV)
232
+ if df.empty:
233
+ st.info("No records yet.")
234
+ else:
235
+ st.metric("Total Records", len(df))
236
+ cat_counts = df["predicted_label"].value_counts().reset_index()
237
+ cat_counts.columns = ["Label", "Count"]
238
+ chart = alt.Chart(cat_counts).mark_bar().encode(
239
+ x="Label:N", y="Count:Q", tooltip=["Label", "Count"]
240
+ ).properties(height=300)
241
+ st.altair_chart(chart, use_container_width=True)
242
+
243
+ st.markdown("### Recent Entries")
244
+ st.dataframe(df.sort_values("timestamp", ascending=False).head(30))
245
+
246
+ st.markdown("### Upload Trends Over Time")
247
+ df["ts_day"] = pd.to_datetime(df["timestamp"], errors="coerce").dt.date
248
+ ts = df.groupby(["ts_day","predicted_label"]).size().reset_index(name="count")
249
+ line_chart = alt.Chart(ts).mark_line(point=True).encode(
250
+ x="ts_day:T", y="count:Q", color="predicted_label:N"
251
+ ).properties(height=300)
252
+ st.altair_chart(line_chart, use_container_width=True)
253
+
254
+ csv = df.to_csv(index=False).encode("utf-8")
255
+ st.download_button("⬇️ Download Full Log", csv, "threat_log.csv", "text/csv")
256
 
257
+ st.markdown("---")
258
+ st.caption("© 2025 Intelligence Threat Detection Suite. Built for multilingual zero-shot NLP analysis.")