File size: 4,948 Bytes
4564b05
 
5f98d9a
 
09f5d1d
 
 
4564b05
 
 
 
09f5d1d
 
 
 
 
 
4564b05
09f5d1d
 
4564b05
09f5d1d
4564b05
 
09f5d1d
 
4564b05
09f5d1d
 
4564b05
 
 
 
 
 
 
 
 
 
09f5d1d
 
 
4564b05
 
09f5d1d
 
 
4564b05
09f5d1d
 
4564b05
09f5d1d
4564b05
 
09f5d1d
 
 
4564b05
09f5d1d
 
4564b05
09f5d1d
 
 
4564b05
09f5d1d
 
4564b05
 
09f5d1d
4564b05
 
09f5d1d
 
4564b05
 
09f5d1d
4564b05
09f5d1d
 
4564b05
09f5d1d
4564b05
 
 
 
09f5d1d
 
 
4564b05
 
09f5d1d
 
 
 
 
 
 
4564b05
 
 
 
 
09f5d1d
 
4564b05
 
 
09f5d1d
4564b05
 
 
09f5d1d
 
 
4564b05
09f5d1d
 
 
 
 
5f98d9a
4564b05
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# streamlit_app.py  — ENV-ONLY CONFIG (no st.secrets)
import os, io, pathlib, urllib.request
import numpy as np
import streamlit as st
from PIL import Image
from matplotlib import cm

# ---- import Detector from app/ or src/ ----
# Probe the known module locations in priority order and keep the first
# `Detector` class that imports cleanly.
Detector, _last_err = None, None
# Every failed attempt is recorded: the last candidate usually fails with a
# plain "No module named ...", which would otherwise hide the real cause
# (e.g. a missing dependency inside an earlier candidate).
_import_errors = {}
for mod in ["app.inference_wav2vec", "app.inference",
            "src.inference_wav2vec", "src.inference"]:
    try:
        Detector = __import__(mod, fromlist=["Detector"]).Detector
        break
    except Exception as e:  # best effort: any import-time failure moves on to the next candidate
        _last_err = e
        _import_errors[mod] = e
if Detector is None:
    _attempts = "; ".join(f"{name}: {err!r}" for name, err in _import_errors.items())
    st.error(
        f"Could not import Detector from app/ or src/. Last error: {_last_err}. "
        f"All attempts: {_attempts}"
    )
    st.stop()

# ---- weights handling (ENV ONLY) ----
def cfg(name: str, default: str = "") -> str:
    """Read a configuration value from the process environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset or set to "".

    Returns:
        The variable's value, or ``default`` when it is missing or empty.
    """
    # An unset variable (None) and an empty string are both falsy, so a single
    # `or` covers the two "not configured" cases.
    return os.environ.get(name) or default

def ensure_weights() -> str:
    """Make sure the model weights file is present and return its path.

    The location comes from ``MODEL_WEIGHTS_PATH`` (with a built-in default).
    When the file is missing and ``MODEL_WEIGHTS_URL`` is set, the weights are
    downloaded. When the file is missing and no URL is configured, a warning
    is shown in the UI.

    Returns:
        The weights path as a string. The path is returned even when the file
        could not be provided.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or the filesystem calls raise
        if the download fails; no partial file is left at the final path.
    """
    wp = cfg("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth")
    url = cfg("MODEL_WEIGHTS_URL", "")
    dest = pathlib.Path(wp)

    if dest.exists():
        return str(dest)

    if not url:
        st.warning(
            f"Model weights not found at '{wp}'. "
            "Upload the .pth there OR set MODEL_WEIGHTS_URL in Settings → Variables & secrets."
        )
        return str(dest)

    dest.parent.mkdir(parents=True, exist_ok=True)
    # Download next to the target and move it into place only once complete:
    # an interrupted transfer must not leave a truncated file at `dest`, or
    # the exists() check above would accept it as valid weights on the next run.
    partial = dest.with_name(dest.name + ".part")
    with st.spinner(f"Downloading model weights to {dest} …"):
        try:
            urllib.request.urlretrieve(url, str(partial))
            os.replace(partial, dest)
        finally:
            # After a successful replace the partial file is gone; this only
            # cleans up after a failed or interrupted download.
            if partial.exists():
                partial.unlink()
        st.toast("Weights downloaded", icon="✅")
    return str(dest)

@st.cache_resource(show_spinner=True)
def load_detector():
    """Construct the Detector from the resolved weights path.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    object instead of rebuilding it on every script run.
    """
    weights_file = ensure_weights()
    return Detector(weights_path=weights_file)


# Module-level detector instance used by the analysis helper below.
det = load_detector()

# ---- helpers ----
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Render an importance map as a magma-coloured PNG.

    Args:
        cam: Array-like of importance values. NaNs are treated as 0 and all
            values are clipped to [0, 1] (no rescaling is applied). Scalar
            and 1-D inputs are promoted to 2-D so they still render as an
            RGB image.

    Returns:
        The encoded PNG file contents.
    """
    # atleast_2d: a 0-d input would make the colormap return a plain tuple
    # (breaking the channel slice below), and a 1-D input would produce an
    # (N, 3) array that PIL interprets as a grayscale image.
    heat = np.atleast_2d(np.asarray(cam, dtype=np.float32))
    heat = np.clip(np.nan_to_num(heat, nan=0.0), 0.0, 1.0)
    # The colormap yields RGBA floats; keep RGB and scale to 8-bit.
    rgb = (cm.magma(heat)[..., :3] * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(rgb).save(buf, "PNG")
    return buf.getvalue()

def analyze(wav_bytes: bytes, source_hint: str):
    """Run the module-level detector on one audio clip.

    Args:
        wav_bytes: Raw audio bytes handed to the detector.
        source_hint: Forwarded unchanged as the detector's ``source_hint``.

    Returns:
        A ``(proba, exp)`` tuple: the result of ``det.predict_proba`` followed
        by the result of ``det.explain`` for the same input.
    """
    hint_kwargs = {"source_hint": source_hint}
    # Prediction runs first, explanation second.
    probabilities = det.predict_proba(wav_bytes, **hint_kwargs)
    explanation = det.explain(wav_bytes, **hint_kwargs)
    return probabilities, explanation

# ---- UI ----
# NOTE(review): older Streamlit releases require set_page_config to be the
# first Streamlit command of the script, yet the detector is loaded (and may
# show a spinner/warning) before this line — confirm against the deployed version.
st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
st.title("🛡️ Voice Guard — Human vs AI Speech")

left, right = st.columns([1, 2], gap="large")
with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])
    # Audio payload and its origin; both stay None until the user records or
    # uploads something. An upload takes precedence over a recording because
    # the upload tab is evaluated last.
    wav_bytes, source_hint = None, None

    with tab_rec:
        st.caption("Record ~3–7 s. If mic fails, use Upload.")
        # Deliberate best effort: the recorder is an optional component, so any
        # failure (package missing, component error) falls back to the Upload tab.
        try:
            from audio_recorder_streamlit import audio_recorder
            audio = audio_recorder(text="Record", recording_color="#ff6a00",
                                   neutral_color="#2b2b2b", icon_size="2x")
            if audio:
                wav_bytes, source_hint = audio, "microphone"
                st.audio(wav_bytes, format="audio/wav")
        except Exception:
            st.info("Recorder not available—use Upload tab.")

    with tab_up:
        f = st.file_uploader("Upload wav/mp3/m4a/aac", type=["wav", "mp3", "m4a", "aac"])
        if f:
            # getvalue() returns the full buffer regardless of the read cursor.
            # read() can return b"" on a rerun (e.g. right after pressing
            # Analyze) if the stream was already consumed, which would make the
            # click silently do nothing.
            wav_bytes, source_hint = f.getvalue(), "upload"
            st.audio(wav_bytes)

    st.markdown("---")
    # Disabled for both "no input" (None) and "empty input" (b""), matching the
    # truthiness check applied to wav_bytes before the analysis runs.
    run = st.button("🔍 Analyze", type="primary", use_container_width=True,
                    disabled=not wav_bytes)

with right:
    st.subheader("Results")
    if run and wav_bytes:
        # Any failure in inference or rendering is reported in the UI rather
        # than surfacing as an uncaught script exception.
        try:
            with st.spinner("Analyzing…"):
                proba, exp = analyze(wav_bytes, source_hint or "auto")

            # Read the prediction fields, with defaults for any missing key.
            ph = float(proba.get("human", 0.0))
            pa = float(proba.get("ai", 0.0))
            label = (proba.get("label", "human") or "human").upper()
            thr = float(proba.get("threshold", 0.5))
            rule = proba.get("decision", "threshold")
            thr_src = proba.get("threshold_source", "—")
            rscore = proba.get("replay_score", None)
            replay_txt = "—" if rscore is None else f"{float(rscore):.2f}"

            c1, c2, c3 = st.columns(3)
            with c1:
                st.metric("Human", f"{ph*100:.1f}%")
            with c2:
                st.metric("AI", f"{pa*100:.1f}%")
            with c3:
                color = "#22c55e" if label == "HUMAN" else "#fb7185"
                st.markdown(f"**Final Label:** <span style='color:{color}'>{label}</span>", unsafe_allow_html=True)
                st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={replay_txt}")

            st.markdown("##### Explanation Heatmap")
            cam_raw = exp.get("cam")
            if cam_raw is None:
                # Without this guard a missing map becomes a 0-d NaN array and
                # the renderer fails with an unrelated-looking error after the
                # metrics have already been drawn.
                cam_shape = None
                st.info("No explanation heatmap was returned for this clip.")
            else:
                cam = np.asarray(cam_raw, dtype=np.float32)
                cam_shape = list(cam.shape)
                st.image(cam_to_png_bytes(cam), caption="Spectrogram importance", use_column_width=True)

            with st.expander("Raw JSON (debug)"):
                st.json({"proba": proba, "explain": {"cam_shape": cam_shape}})
        except Exception as e:
            st.error(f"Analyze failed: {e}")

st.caption("Upload 3–7s clips for the most reliable experience across browsers.")