"""Streamlit dashboard: transcribe speech with Whisper and classify threat category.

The app accepts either an uploaded audio file (transcribed with a Hugging Face
Whisper pipeline) or pasted text, classifies it with a zero-shot or fine-tuned
model, logs each prediction to a CSV file and shows simple analytics.
"""

import json
import os
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Redirect cache/config directories to a writable location. These assignments
# come before the third-party imports below (same ordering as the original
# script) so the libraries see them when they are first imported.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"

# --- Fix Streamlit permission issue ---
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit_cache"
os.environ["STREAMLIT_RUNTIME_DIR"] = "/tmp/streamlit_runtime"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)

import altair as alt
import langdetect
import pandas as pd
import streamlit as st
from pydub import AudioSegment

# Optional ML imports. The failure is only *recorded* here; it is reported to
# the user after st.set_page_config(), which Streamlit documents as having to
# be the first Streamlit command on a page.
try:
    from transformers import pipeline

    try:
        from transformers.pipelines import Pipeline
    except ImportError:
        Pipeline = object  # fallback
    HF_AVAILABLE = True
    HF_IMPORT_ERROR = None
except Exception as e:
    pipeline = None
    Pipeline = object  # ensure name exists
    HF_AVAILABLE = False
    HF_IMPORT_ERROR = str(e)  # `e` is unbound after the except block

# langdetect is randomised by default; a fixed seed makes detection repeatable.
langdetect.DetectorFactory.seed = 0

# -----------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------
st.set_page_config(
    page_title="🕵🏻Speech Threat Detection Dashboard",
    layout="wide",
    initial_sidebar_state="expanded",
)

if not HF_AVAILABLE:
    st.error(f"Transformers unavailable: {HF_IMPORT_ERROR}")

# Styling header
# NOTE(review): the CSS/HTML markup that presumably lived in these two strings
# is missing from this copy of the file — restore it from version control.
st.markdown(""" """, unsafe_allow_html=True)
st.markdown(
    "\nSpeech Threat Detection Dashboard — upload audio or paste text\n",
    unsafe_allow_html=True,
)

UPLOAD_DIR = Path("/tmp/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

DB_CSV = Path("/tmp/db.csv")
DB_COLUMNS = [
    "timestamp",
    "filename",
    "mode",
    "transcription",
    "predicted_label",
    "scores",
]
if not DB_CSV.exists():
    pd.DataFrame(columns=DB_COLUMNS).to_csv(DB_CSV, index=False)

# Candidate labels passed to the zero-shot classifiers.
LABELS = [
    "physical threat",
    "cyber threat",
    "hate speech",
    "political extremist threat",
    "neutral",
]

# Raw class ids emitted by the fine-tuned model -> display names.
# NOTE(review): this mapping contains "self-harm" and "neutral / daily life",
# which do not appear in LABELS — confirm against the model card.
LABEL_MAP = {
    "LABEL_0": "hate speech",
    "LABEL_1": "self-harm",
    "LABEL_2": "cyber threat",
    "LABEL_3": "neutral / daily life",
    "LABEL_4": "physical threat",
    "LABEL_5": "political extremist threat",
}

CUSTOM_MODEL_ID = "AiAnber/xlm-roberta-threat-detector"
# Historical alias accepted by get_classifier_pipeline().
CUSTOM_MODEL_ALIAS = "custom_xlm_roberta"


# -----------------------------------------------------------
# HELPER FUNCTIONS
# -----------------------------------------------------------
def save_audio_file(uploaded_file) -> Path:
    """Persist an uploaded file under UPLOAD_DIR and return its path.

    Args:
        uploaded_file: Streamlit ``UploadedFile`` (or any object with ``name``
            and ``getvalue()``/``read()``).

    Returns:
        Path of the written file, named ``<unix-seconds>_<basename>``.
    """
    # Keep only the final path component so a crafted name cannot escape
    # UPLOAD_DIR.
    safe_name = Path(uploaded_file.name.replace("\\", "/")).name or "upload"
    out_path = UPLOAD_DIR / f"{int(time.time())}_{safe_name}"
    # getvalue() returns the full buffer regardless of the current stream
    # position; read() can return b"" if the stream was already consumed.
    if hasattr(uploaded_file, "getvalue"):
        data = uploaded_file.getvalue()
    else:
        data = uploaded_file.read()
    out_path.write_bytes(data)
    return out_path


def normalize_audio_to_wav(path: Path) -> Path:
    """Convert an audio file to 16 kHz, mono, 16-bit WAV inside UPLOAD_DIR.

    The output gets a ``_16k`` suffix so a ``.wav`` upload is not overwritten
    by its own normalised copy.
    """
    sound = AudioSegment.from_file(path)
    sound = sound.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    wav_path = UPLOAD_DIR / f"{path.stem}_16k.wav"
    sound.export(wav_path, format="wav")
    return wav_path


@st.cache_resource(show_spinner=False)
def get_asr_pipeline() -> Tuple[str, Optional["Pipeline"]]:
    """Load the Hugging Face Whisper ASR model (cached across reruns).

    Returns:
        ``("hf", pipeline)`` when transformers is importable, otherwise
        ``("none", None)`` so the UI can report that ASR is unavailable.
    """
    if not HF_AVAILABLE:
        return ("none", None)
    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v2",
        # Whisper handles 30 s windows; chunking lets longer uploads through.
        chunk_length_s=30,
    )
    return ("hf", asr)


# max_entries=1: only one (large) classifier is kept in memory at a time.
@st.cache_resource(show_spinner=False, max_entries=1)
def _load_classifier(model_name: str):
    """Build the classification pipeline for *model_name* (cached)."""
    if model_name in (CUSTOM_MODEL_ALIAS, CUSTOM_MODEL_ID):
        return pipeline(
            "text-classification",
            model=CUSTOM_MODEL_ID,
            tokenizer=CUSTOM_MODEL_ID,
            top_k=None,  # all class scores; replaces deprecated return_all_scores
        )
    return pipeline("zero-shot-classification", model=model_name)


def get_classifier_pipeline(model_name: str):
    """Load a zero-shot or custom classifier, stopping the app on failure.

    Args:
        model_name: Hugging Face model id, or the ``"custom_xlm_roberta"``
            alias for the fine-tuned threat detector.

    Returns:
        The loaded pipeline. On any load error an ``st.error`` is shown and
        the script run is halted with ``st.stop()``.
    """
    try:
        if not HF_AVAILABLE:
            raise RuntimeError("transformers is not installed")
        return _load_classifier(model_name)
    except Exception as e:
        st.error(f"Model load failed: {e}")
        st.stop()


def hf_transcribe_with_pipeline(
    asr_pipeline_tuple, path: Path, lang_choice: str = "Auto"
) -> str:
    """Transcribe audio with Whisper and restrict to English/Urdu.

    Args:
        asr_pipeline_tuple: Value returned by ``get_asr_pipeline()``.
        path: Audio file to transcribe.
        lang_choice: ``"Auto"``, ``"English only"`` or ``"Urdu only"``.

    Returns:
        The stripped transcription, or a fixed "unsupported language" message
        when langdetect identifies a language other than English or Urdu.
    """
    asr_pipeline = asr_pipeline_tuple[1]  # Extract the pipeline from the tuple

    lang_tokens = {"English only": "<|en|>", "Urdu only": "<|ur|>"}
    lang_token = lang_tokens.get(lang_choice)
    kwargs = {"generate_kwargs": {"language": lang_token}} if lang_token else {}

    output = asr_pipeline(str(path), **kwargs)
    if isinstance(output, dict):
        text = output["text"].strip()
    else:
        text = str(output).strip()

    # Restrict to English or Urdu only. Detection is best-effort: text with no
    # detectable features (empty, digits only) is passed through unchanged.
    try:
        detected = langdetect.detect(text)
    except langdetect.LangDetectException:
        return text
    if detected not in ("en", "ur"):
        return "[❌ Unsupported language detected — please use Urdu or English.]"
    return text


def _flatten_classifier_output(outputs) -> List[Dict]:
    """Normalise text-classification output to a flat list of score dicts.

    Handles ``[[{...}, ...]]`` (legacy return_all_scores), ``[{...}, ...]``
    (top_k=None) and a single ``{...}``.
    """
    while isinstance(outputs, list) and outputs and isinstance(outputs[0], list):
        outputs = outputs[0]
    if isinstance(outputs, dict):
        return [outputs]
    return list(outputs)


def classify_text(text: str, classifier, labels: List[str]) -> Dict:
    """Classify *text* and return the top label with all scores.

    Args:
        text: Text to classify.
        classifier: A zero-shot or text-classification pipeline.
        labels: Candidate labels (used by zero-shot pipelines only).

    Returns:
        ``{"label": <top label>, "scores": {label: score, ...}}``. On failure
        an ``st.error`` is shown and ``{"label": "neutral", "scores": {}}`` is
        returned; an empty ``scores`` dict therefore signals failure.
    """
    try:
        if "zero-shot" in (getattr(classifier, "task", "") or ""):
            # For zero-shot models like RoBERTa or BART
            result = classifier(
                text,
                labels,
                multi_label=False,
                hypothesis_template="This text is about {}.",
            )
            labels_out, scores_out = result["labels"], result["scores"]
        else:
            # For custom fine-tuned text classification models
            entries = _flatten_classifier_output(classifier(text))
            labels_out = [LABEL_MAP.get(e["label"], e["label"]) for e in entries]
            scores_out = [e["score"] for e in entries]

        # float() keeps the scores JSON-serialisable whatever numeric type
        # the pipeline returned.
        scores = {lbl: float(score) for lbl, score in zip(labels_out, scores_out)}
        top_label = max(scores, key=scores.get)
        return {"label": top_label, "scores": scores}
    except Exception as e:
        st.error(f"Classification failed: {e}")
        return {"label": "neutral", "scores": {}}


def log_to_db(record: Dict):
    """Append one prediction record to the CSV log.

    The row is appended instead of re-reading and rewriting the whole file.
    A dict under ``"scores"`` is stored as a JSON string. Keys outside
    DB_COLUMNS are not written.
    """
    row = dict(record)
    if isinstance(row.get("scores"), dict):
        row["scores"] = json.dumps(row["scores"], ensure_ascii=False)
    frame = pd.DataFrame([row]).reindex(columns=DB_COLUMNS)
    needs_header = not DB_CSV.exists() or DB_CSV.stat().st_size == 0
    frame.to_csv(DB_CSV, mode="a", header=needs_header, index=False)


# -----------------------------------------------------------
# SIDEBAR CONFIGURATION
# -----------------------------------------------------------
st.sidebar.title("⚙️ Configuration")

asr_pipeline_tuple = get_asr_pipeline()  # Get the tuple

asr_language = st.sidebar.selectbox(
    "Transcription Language Restriction",
    ["Auto", "English only", "Urdu only"],
    help="Restrict transcription to English or Urdu only",
)

if asr_pipeline_tuple[0] == "hf":  # Check the method
    st.sidebar.markdown("**ASR Method:** `hf`")
else:
    st.sidebar.markdown("**ASR Method:** `none` (Hugging Face models not available)")

# Model selection mode
model_type = st.sidebar.radio("Select Model Type", ["Zero-shot Models", "Custom Models"])

if model_type == "Zero-shot Models":
    model_choices = {
        "✔Pretrained-RoBERTa": "roberta-large-mnli",
        "✔Pretrained-MultiClassification": "facebook/bart-large-mnli",
        "✔XLM-R": "joeddav/xlm-roberta-large-xnli",
    }
else:
    model_choices = {
        "✔IB: XLM-R-Fine-tuned": CUSTOM_MODEL_ID,
    }

model_display = st.sidebar.selectbox("Choose a Model", list(model_choices.keys()))
model_path = model_choices[model_display]

with st.sidebar:
    st.markdown("---")
    st.markdown("**Active Labels:**")
    for lbl in LABELS:
        st.markdown(f"- {lbl}")

# Load classifier (cached, so Streamlit reruns do not reload the model).
with st.spinner("Loading model..."):
    classifier = get_classifier_pipeline(model_path)

# -----------------------------------------------------------
# MAIN INTERFACE
# -----------------------------------------------------------
# NOTE(review): as with the header, the markup for these strings appears to
# have been stripped from this copy of the file.
st.markdown(""" """, unsafe_allow_html=True)
st.markdown(
    "\nUpload or enter text to detect threat categories\n",
    unsafe_allow_html=True,
)
st.write("")

tab1, tab2 = st.tabs(["⏳ Processing", "📊 Analysis"])

# -----------------------------------------------------------
# TAB 1: PROCESSING
# -----------------------------------------------------------
with tab1:
    input_mode = st.radio("Input mode", ["Upload audio", "Paste text"])
    transcription_text = ""
    saved_file_path = None

    if input_mode == "Upload audio":
        uploaded_file = st.file_uploader(
            "Upload audio file", type=["wav", "mp3", "m4a", "flac", "ogg"]
        )
        if uploaded_file:
            # Streamlit reruns the whole script on every widget interaction
            # (including the Classify button). Cache the last transcription so
            # the upload is not re-saved and re-transcribed each time.
            cache_key = (uploaded_file.name, uploaded_file.size, asr_language)
            cached = st.session_state.get("asr_cache")
            if cached and cached["key"] == cache_key and cached["path"].exists():
                saved_file_path = cached["path"]
                transcription_text = cached["text"]
                st.audio(str(saved_file_path))
            else:
                saved_file_path = save_audio_file(uploaded_file)
                st.audio(str(saved_file_path))
                if asr_pipeline_tuple[0] == "hf":  # Check the method
                    try:
                        wav_path = normalize_audio_to_wav(saved_file_path)
                        with st.spinner("Transcribing audio..."):
                            transcription_text = hf_transcribe_with_pipeline(
                                asr_pipeline_tuple, wav_path, asr_language
                            )
                        st.session_state["asr_cache"] = {
                            "key": cache_key,
                            "path": saved_file_path,
                            "text": transcription_text,
                        }
                    except Exception as e:
                        st.error(f"Transcription failed: {e}")
                else:
                    st.warning("No ASR available.")
    else:
        transcription_text = st.text_area("Enter or paste text", height=180)

    if transcription_text:
        st.markdown("### Transcription")
        txt = st.text_area("Editable text", value=transcription_text, height=180)

        if st.button("📝 Classify"):
            with st.spinner("Analyzing text..."):
                result = classify_text(txt, classifier, LABELS)

            # Empty scores means classify_text failed (it already showed the
            # error); do not log a bogus "neutral" record in that case.
            if result["scores"]:
                record = {
                    # Naive UTC ISO timestamp: same format as existing rows.
                    "timestamp": datetime.now(timezone.utc)
                    .replace(tzinfo=None)
                    .isoformat(),
                    "filename": saved_file_path.name if saved_file_path else "text_input",
                    "mode": "audio" if saved_file_path else "text",
                    "transcription": txt,
                    "predicted_label": result["label"],
                    "scores": result["scores"],
                }
                log_to_db(record)

                st.success(f"**Predicted Category:** {result['label']}")
                df_scores = pd.DataFrame(
                    result["scores"].items(), columns=["Label", "Score"]
                ).sort_values("Score", ascending=False)
                st.bar_chart(df_scores.set_index("Label"))

# -----------------------------------------------------------
# TAB 2: ANALYSIS
# -----------------------------------------------------------
with tab2:
    st.subheader("📈📊 Analytical Overview")

    if not DB_CSV.exists() or DB_CSV.stat().st_size == 0:
        st.info("No data available yet. Run a few classifications first.")
    else:
        df = pd.read_csv(DB_CSV)
        if df.empty:
            st.info("No records yet.")
        else:
            st.metric("Total Records", len(df))

            cat_counts = df["predicted_label"].value_counts().reset_index()
            cat_counts.columns = ["Label", "Count"]
            chart = (
                alt.Chart(cat_counts)
                .mark_bar()
                .encode(x="Label:N", y="Count:Q", tooltip=["Label", "Count"])
                .properties(height=300)
            )
            st.altair_chart(chart, use_container_width=True)

            st.markdown("### Recent Entries")
            st.dataframe(df.sort_values("timestamp", ascending=False).head(30))

            st.markdown("### Upload Trends Over Time")
            # Work on a separate frame so the derived day column does not leak
            # into the downloadable log; normalize() keeps a real datetime
            # dtype for Altair's temporal axis.
            trend = df[["timestamp", "predicted_label"]].copy()
            trend["ts_day"] = pd.to_datetime(
                trend["timestamp"], errors="coerce"
            ).dt.normalize()
            ts = (
                trend.groupby(["ts_day", "predicted_label"])
                .size()
                .reset_index(name="count")
            )
            line_chart = (
                alt.Chart(ts)
                .mark_line(point=True)
                .encode(x="ts_day:T", y="count:Q", color="predicted_label:N")
                .properties(height=300)
            )
            st.altair_chart(line_chart, use_container_width=True)

            csv = df.to_csv(index=False).encode("utf-8")
            st.download_button("⬇️ Download Full Log", csv, "threat_log.csv", "text/csv")

st.markdown("---")
st.caption("© 2025 — Intelligence Threat Detection Suite. Built for multilingual speech analysis.")