Spaces:
Sleeping
Sleeping
| # streamlit_app.py | |
import base64
import importlib
import io
import os
import pathlib
import urllib.request

import numpy as np
import streamlit as st
from PIL import Image
from matplotlib import cm
# ------- wiring to your detector -------
# We prefer the wav2vec2 detector; fall back to the CNN one if needed.
BACKENDS_TRY = ["app.inference_wav2vec", "app.inference"]
Detector = None
err = None  # last import failure; shown to the user if no backend loads
for mod in BACKENDS_TRY:
    try:
        Detector = importlib.import_module(mod).Detector
        break
    except Exception as e:
        # Broad on purpose: a missing module, a missing heavy dependency, or a
        # module without a `Detector` attribute all mean "try the next backend".
        err = e
if Detector is None:
    st.error(
        "Could not import Detector from app/. Make sure your repo contains "
        "app/inference_wav2vec.py (or app/inference.py)."
        f"\n\nLast error: {err!r}"
    )
    st.stop()
# ------- config / weights -------
def _config_value(name, default):
    """Look up *name* in the environment first, then in Streamlit secrets.

    Equivalent to ``os.environ.get(name, st.secrets.get(name, default))``,
    except that ``st.secrets`` is only consulted when the environment variable
    is absent. If secrets cannot be read at all (Streamlit may raise when no
    secrets file exists), *default* is returned instead of crashing the app.
    """
    value = os.environ.get(name)
    if value is not None:
        return value
    try:
        return st.secrets.get(name, default)
    except Exception:  # secrets unavailable/unparseable: best-effort fallback
        return default


def ensure_weights():
    """Return the local path of the model weights, downloading them if needed.

    The path is ``MODEL_WEIGHTS_PATH`` (env var, then Streamlit secrets,
    default ``app/models/weights/wav2vec2_classifier.pth``). When
    ``MODEL_WEIGHTS_URL`` is non-empty and nothing exists at that path yet,
    the file is downloaded first. Without a URL the path is returned as-is,
    whether or not the file exists.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on a failed download
        (e.g. ``urllib.error.URLError``); no partial file is left behind.
    """
    wp = _config_value("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth")
    url = _config_value("MODEL_WEIGHTS_URL", "")
    if url and not os.path.exists(wp):
        target = pathlib.Path(wp)
        target.parent.mkdir(parents=True, exist_ok=True)
        # Download next to the target and rename into place. A crash or network
        # error mid-download would otherwise leave a truncated file at `wp`,
        # which the exists() check above would then accept on every later run.
        partial = target.with_name(target.name + ".part")
        # NOTE(review): this spinner renders before st.set_page_config further
        # down; Streamlit versions that require set_page_config to be the first
        # command will raise on the download run.
        with st.spinner("Downloading model weights…"):
            try:
                urllib.request.urlretrieve(url, str(partial))
                os.replace(partial, target)
            finally:
                partial.unlink(missing_ok=True)
    return wp
@st.cache_resource(show_spinner=False)
def load_detector():
    """Resolve the weights and construct the Detector, cached across reruns.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, every click or upload would rebuild the model from disk.
    ``show_spinner=False`` stops the cache from rendering a spinner element
    before ``st.set_page_config`` runs further down the script.

    NOTE(review): ``st.cache_resource`` shares the returned object across all
    sessions; confirm ``predict_proba``/``explain`` are safe to call
    concurrently.
    """
    wp = ensure_weights()
    return Detector(weights_path=wp)


det = load_detector()
# ------- helpers -------
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Render an importance map as PNG bytes using the magma colormap.

    Values are converted to float32 and clipped to [0, 1] before colour
    mapping, so anything outside that range saturates at the ends of the
    colormap. The alpha channel is dropped, giving an RGB image.
    """
    heat = np.clip(np.asarray(cam, dtype=np.float32), 0.0, 1.0)
    rgba = cm.magma(heat)
    pixels = (rgba[..., :3] * 255).astype(np.uint8)
    with io.BytesIO() as out:
        Image.fromarray(pixels).save(out, format="PNG")
        return out.getvalue()
def analyze(wav_bytes: bytes, source_hint: str):
    """Score one audio clip with the module-level detector ``det``.

    Returns a ``(proba, exp)`` pair: the result of ``det.predict_proba``
    followed by the result of ``det.explain``, both invoked with the same
    bytes and ``source_hint`` keyword (prediction first).
    """
    hint_kwargs = {"source_hint": source_hint}
    probabilities = det.predict_proba(wav_bytes, **hint_kwargs)
    explanation = det.explain(wav_bytes, **hint_kwargs)
    return probabilities, explanation
# ------- UI -------
# NOTE(review): older Streamlit releases require st.set_page_config to be the
# first Streamlit command of a run; anything rendered above this line (e.g. the
# weight-download spinner) makes it raise on those versions.
st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
st.title("🛡️ Voice Guard — Human vs AI Speech (Streamlit)")

left, right = st.columns([1, 2])

with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])
    # Clip to analyse. The upload tab is processed last, so an uploaded file
    # takes precedence over a recording when both are present.
    wav_bytes = None
    source_hint = None

    with tab_rec:
        st.caption("If the recorder component fails on your browser, use Upload.")
        try:
            # light, zero-config recorder component
            from audio_recorder_streamlit import audio_recorder

            audio = audio_recorder(
                text="Record",
                recording_color="#ff6a00",
                neutral_color="#2b2b2b",
                icon_size="2x",
            )
            if audio:
                wav_bytes = audio  # component returns WAV bytes
                source_hint = "microphone"
                st.audio(wav_bytes, format="audio/wav")
        except Exception:
            # Best effort: a missing or broken component must not take the page down.
            st.info("Recorder component not available; please use the Upload tab.")

    with tab_up:
        f = st.file_uploader(
            "Upload an audio file (wav/mp3/m4a/aac)",
            type=["wav", "mp3", "m4a", "aac"],
        )
        if f is not None:
            # getvalue() returns the whole buffer regardless of the read position.
            wav_bytes = f.getvalue()
            source_hint = "upload"
            # Use the browser-reported MIME type so mp3/m4a/aac are not served as WAV.
            st.audio(wav_bytes, format=f.type or "audio/wav")

    st.markdown("---")
    # Disabled for both "nothing selected" and an empty file, matching the guard below.
    run = st.button(
        "🔍 Analyze",
        use_container_width=True,
        type="primary",
        disabled=not wav_bytes,
    )

with right:
    st.subheader("Results")
    if run and wav_bytes:
        with st.spinner("Analyzing…"):
            proba, exp = analyze(wav_bytes, source_hint or "auto")
        ph = proba["human"]
        pa = proba["ai"]
        label = proba["label"].upper()
        # Optional diagnostics; fall back to defaults when the backend omits them.
        thr = proba.get("threshold", 0.5)
        rule = proba.get("decision", "threshold")
        rscore = proba.get("replay_score", None)
        thr_src = proba.get("threshold_source", "—")
        replay_txt = "-" if rscore is None else f"{rscore:.2f}"

        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Human", f"{ph*100:.1f} %")
        with col2:
            st.metric("AI", f"{pa*100:.1f} %")
        with col3:
            color = "#22c55e" if label == "HUMAN" else "#fb7185"
            st.markdown(
                f"**Final Label:** <span style='color:{color}'>{label}</span>",
                unsafe_allow_html=True,
            )
            st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={replay_txt}")

        st.markdown("##### Explanation Heatmap")
        cam = np.array(exp["cam"], dtype=np.float32)
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the pinned version allows.
        st.image(cam_to_png_bytes(cam), caption="Spectrogram importance", use_column_width=True)

        st.markdown("---")
        with st.expander("Raw JSON (debug)"):
            st.json({"proba": proba, "explain": {"cam_shape": list(cam.shape)}})

st.caption("Tip: If the mic recorder fails, upload a short 3–7s clip instead.")