import os
from pathlib import Path

os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"

# --- Fix Streamlit permission issue ---
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit_cache"
os.environ["STREAMLIT_RUNTIME_DIR"] = "/tmp/streamlit_runtime"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
import time
from datetime import datetime
import pandas as pd
import streamlit as st
from typing import Dict, List, Tuple
import langdetect
# Optional ML imports
HF_IMPORT_ERROR = None
try:
    from transformers import pipeline
    try:
        from transformers.pipelines import Pipeline
    except ImportError:
        Pipeline = object  # fallback when Pipeline isn't importable from this layout
    HF_AVAILABLE = True
except Exception as e:
    pipeline = None
    Pipeline = object  # ensure the name exists for the annotations below
    HF_AVAILABLE = False
    # Defer the error message: st.set_page_config() must be the first Streamlit call.
    HF_IMPORT_ERROR = str(e)
from pydub import AudioSegment
import altair as alt
# -----------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------
st.set_page_config(
    page_title="🕵🏻 Speech Threat Detection Dashboard",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Report a deferred transformers import failure now that set_page_config has run.
if not HF_AVAILABLE:
    st.error(f"Transformers unavailable: {HF_IMPORT_ERROR}")
# Styling header
st.markdown("""
<style>
.big-font { font-size:32px; font-weight:700; }
.muted { color: #9AA0A6; }
.card { background: linear-gradient(135deg, rgba(10,25,47,0.95), rgba(23,43,77,0.95)); padding: 18px; border-radius: 12px; color: white; box-shadow: 0 6px 30px rgba(8,10,20,0.45); }
</style>
""", unsafe_allow_html=True)
st.markdown('<div class="card"><span class="big-font">Speech Threat Detection Dashboard</span> <span class="muted"> — upload audio or paste text </span></div>', unsafe_allow_html=True)
UPLOAD_DIR = Path("/tmp/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
DB_CSV = Path("/tmp/db.csv")
if not DB_CSV.exists():
    pd.DataFrame(
        columns=["timestamp", "filename", "mode", "transcription", "predicted_label", "scores"]
    ).to_csv(DB_CSV, index=False)
LABELS = [
    "physical threat",
    "cyber threat",
    "hate speech",
    "political extremist threat",
    "neutral",
]
LABEL_MAP = {
    "LABEL_0": "hate speech",
    "LABEL_1": "self-harm",
    "LABEL_2": "cyber threat",
    "LABEL_3": "neutral / daily life",
    "LABEL_4": "physical threat",
    "LABEL_5": "political extremist threat",
}
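# Note: LABEL_MAP appears to mirror the six output heads of the fine-tuned
# AiAnber/xlm-roberta-threat-detector model. It includes two labels
# ("self-harm", "neutral / daily life") that are absent from the zero-shot
# LABELS list above, so the two label sets are intentionally not identical.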
# -----------------------------------------------------------
# HELPER FUNCTIONS
# -----------------------------------------------------------
def save_audio_file(uploaded_file) -> Path:
    """Persist an uploaded file under /tmp/uploads with a timestamped name."""
    filename = f"{int(time.time())}_{uploaded_file.name}"
    out_path = UPLOAD_DIR / filename
    with open(out_path, "wb") as f:
        f.write(uploaded_file.read())
    return out_path

def normalize_audio_to_wav(path: Path) -> Path:
    """Convert any supported audio file to 16 kHz mono 16-bit WAV."""
    sound = AudioSegment.from_file(path)
    sound = sound.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    wav_path = UPLOAD_DIR / f"{path.stem}.wav"
    sound.export(wav_path, format="wav")
    return wav_path
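# Note: pydub shells out to ffmpeg for non-WAV formats, so ffmpeg must be
# installed in the Space (e.g. via packages.txt) for mp3/m4a/flac/ogg uploads.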
@st.cache_resource(show_spinner=False)
def get_asr_pipeline() -> Tuple[str, "Pipeline"]:
    """Load the Hugging Face Whisper ASR model, or signal that none is available."""
    if not HF_AVAILABLE:
        return ("none", None)
    asr = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")
    return ("hf", asr)
@st.cache_resource(show_spinner=False)
def get_classifier_pipeline(model_name: str):
    """Load a zero-shot classifier, or the custom fine-tuned threat detector."""
    try:
        if model_name == "AiAnber/xlm-roberta-threat-detector":
            classifier = pipeline(
                "text-classification",
                model=model_name,
                tokenizer=model_name,
                return_all_scores=True,
            )
        else:
            classifier = pipeline("zero-shot-classification", model=model_name)
        return classifier
    except Exception as e:
        st.error(f"Model load failed: {e}")
        st.stop()
def hf_transcribe_with_pipeline(asr_pipeline_tuple, path: Path, lang_choice: str = "Auto") -> str:
    """Transcribe audio with Whisper, optionally restricted to English or Urdu."""
    asr_pipeline = asr_pipeline_tuple[1]  # unpack (method, pipeline)
    lang_token = None
    if lang_choice == "English only":
        lang_token = "<|en|>"
    elif lang_choice == "Urdu only":
        lang_token = "<|ur|>"
    kwargs = {"generate_kwargs": {"language": lang_token}} if lang_token else {}
    output = asr_pipeline(str(path), **kwargs)
    text = output["text"].strip() if isinstance(output, dict) else str(output).strip()
    # Restrict to English or Urdu only
    try:
        detected = langdetect.detect(text)
        if detected not in ("en", "ur"):
            return "[❌ Unsupported language detected — please use Urdu or English.]"
    except Exception:
        pass  # langdetect can fail on very short text; keep the transcript
    return text
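# Example (hypothetical file path):
#   hf_transcribe_with_pipeline(asr_pipeline_tuple, Path("/tmp/uploads/clip.wav"), "Urdu only")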
def classify_text(text: str, classifier, labels: List[str]) -> Dict:
    try:
        if "zero-shot" in classifier.task:
            # For zero-shot models like RoBERTa or BART
            result = classifier(
                text, labels, multi_label=False,
                hypothesis_template="This text is about {}.",
            )
            labels_out, scores_out = result["labels"], result["scores"]
        else:
            # For custom fine-tuned text classification models
            outputs = classifier(text)
            # Handle both single and batch outputs
            if isinstance(outputs, list):
                outputs = outputs[0]  # unwrap batch
            if isinstance(outputs, list):
                # Handle return_all_scores=True (list of dicts)
                labels_out = [LABEL_MAP.get(o["label"], o["label"]) for o in outputs]
                scores_out = [o["score"] for o in outputs]
            else:
                # Single dict output
                labels_out = [LABEL_MAP.get(outputs["label"], outputs["label"])]
                scores_out = [outputs["score"]]
        # Pick the top scoring label
        top_label = labels_out[scores_out.index(max(scores_out))]
        return {"label": top_label, "scores": dict(zip(labels_out, scores_out))}
    except Exception as e:
        st.error(f"Classification failed: {e}")
        return {"label": "neutral", "scores": {}}
def log_to_db(record: Dict):
    """Append one classification record to the CSV log."""
    df = pd.read_csv(DB_CSV)
    df = pd.concat([df, pd.DataFrame([record])], ignore_index=True)
    df.to_csv(DB_CSV, index=False)
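# Note: this read-append-rewrite cycle is not safe under concurrent sessions.
# Assuming a single-writer Space it is fine; otherwise an appending write
# (to_csv with mode="a", header=False) plus a file lock would be safer.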
# -----------------------------------------------------------
# SIDEBAR CONFIGURATION
# -----------------------------------------------------------
st.sidebar.title("⚙️ Configuration")
asr_pipeline_tuple = get_asr_pipeline() # Get the tuple
asr_language = st.sidebar.selectbox(
    "Transcription Language Restriction",
    ["Auto", "English only", "Urdu only"],
    help="Restrict transcription to English or Urdu only",
)
if asr_pipeline_tuple[0] == "hf":  # Check the method
    st.sidebar.markdown("**ASR Method:** `hf`")
else:
    st.sidebar.markdown("**ASR Method:** `none` (Hugging Face models not available)")
# Model selection mode
model_type = st.sidebar.radio("Select Model Type", ["Zero-shot Models", "Custom Models"])
if model_type == "Zero-shot Models":
    model_choices = {
        "✔Pretrained-RoBERTa": "roberta-large-mnli",
        "✔Pretrained-MultiClassification": "facebook/bart-large-mnli",
        "✔XLM-R": "joeddav/xlm-roberta-large-xnli",
    }
else:
    model_choices = {
        "✔IB: XLM-R-Fine-tuned": "AiAnber/xlm-roberta-threat-detector",
    }
model_display = st.sidebar.selectbox("Choose a Model", list(model_choices.keys()))
model_path = model_choices[model_display]
with st.sidebar:
    st.markdown("---")
    st.markdown("**Active Labels:**")
    for lbl in LABELS:
        st.markdown(f"- {lbl}")
# Load the classifier through the cached helper (avoids duplicating its logic here)
with st.spinner("Loading model..."):
    classifier = get_classifier_pipeline(model_path)
# -----------------------------------------------------------
# MAIN INTERFACE
# -----------------------------------------------------------
st.markdown("""
<style>
.subtitle {color:#999;}
</style>
""", unsafe_allow_html=True)
st.markdown('<div class="subtitle">Upload or enter text to detect threat categories</div>', unsafe_allow_html=True)
st.write("")
tab1, tab2 = st.tabs(["⏳ Processing", "📊 Analysis"])
# -----------------------------------------------------------
# TAB 1: PROCESSING
# -----------------------------------------------------------
with tab1:
    input_mode = st.radio("Input mode", ["Upload audio", "Paste text"])
    transcription_text = ""
    saved_file_path = None
    if input_mode == "Upload audio":
        uploaded_file = st.file_uploader("Upload audio file", type=["wav", "mp3", "m4a", "flac", "ogg"])
        if uploaded_file:
            saved_path = save_audio_file(uploaded_file)
            saved_file_path = saved_path
            st.audio(str(saved_path))
            wav_path = normalize_audio_to_wav(saved_path)
            st.info("Transcribing audio...")
            try:
                if asr_pipeline_tuple[0] == "hf":  # Check the method
                    # Pass the tuple and the sidebar language restriction
                    transcription_text = hf_transcribe_with_pipeline(
                        asr_pipeline_tuple, wav_path, lang_choice=asr_language
                    )
                else:
                    st.warning("No ASR available.")
            except Exception as e:
                st.error(f"Transcription failed: {e}")
    else:
        transcription_text = st.text_area("Enter or paste text", height=180)
    if transcription_text:
        st.markdown("### Transcription")
        txt = st.text_area("Editable text", value=transcription_text, height=180)
        if st.button("📝 Classify"):
            with st.spinner("Analyzing text..."):
                result = classify_text(txt, classifier, LABELS)
            record = {
                "timestamp": datetime.utcnow().isoformat(),
                "filename": saved_file_path.name if saved_file_path else "text_input",
                "mode": "audio" if saved_file_path else "text",
                "transcription": txt,
                "predicted_label": result["label"],
                "scores": result["scores"],
            }
            log_to_db(record)
            st.success(f"**Predicted Category:** {result['label']}")
            df_scores = (
                pd.DataFrame(result["scores"].items(), columns=["Label", "Score"])
                .sort_values("Score", ascending=False)
            )
            st.bar_chart(df_scores.set_index("Label"))
# -----------------------------------------------------------
# TAB 2: ANALYSIS
# -----------------------------------------------------------
with tab2:
    st.subheader("📈📊 Analytical Overview")
    if not DB_CSV.exists() or os.path.getsize(DB_CSV) == 0:
        st.info("No data available yet. Run a few classifications first.")
    else:
        df = pd.read_csv(DB_CSV)
        if df.empty:
            st.info("No records yet.")
        else:
            st.metric("Total Records", len(df))
            cat_counts = df["predicted_label"].value_counts().reset_index()
            cat_counts.columns = ["Label", "Count"]
            chart = alt.Chart(cat_counts).mark_bar().encode(
                x="Label:N", y="Count:Q", tooltip=["Label", "Count"]
            ).properties(height=300)
            st.altair_chart(chart, use_container_width=True)
            st.markdown("### Recent Entries")
            st.dataframe(df.sort_values("timestamp", ascending=False).head(30))
            st.markdown("### Upload Trends Over Time")
            df["ts_day"] = pd.to_datetime(df["timestamp"], errors="coerce").dt.date
            ts = df.groupby(["ts_day", "predicted_label"]).size().reset_index(name="count")
            line_chart = alt.Chart(ts).mark_line(point=True).encode(
                x="ts_day:T", y="count:Q", color="predicted_label:N"
            ).properties(height=300)
            st.altair_chart(line_chart, use_container_width=True)
            csv = df.to_csv(index=False).encode("utf-8")
            st.download_button("⬇️ Download Full Log", csv, "threat_log.csv", "text/csv")
st.markdown("---")
st.caption("© 2025 — Intelligence Threat Detection Suite. Built for multilingual speech analysis.") |