File size: 5,025 Bytes
6ad4764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os, io, pathlib, urllib.request
import numpy as np
import streamlit as st
from PIL import Image
from matplotlib import cm

# ---- page setup ----
# set_page_config() is issued before any other Streamlit call: Streamlit
# releases without "call it anywhere" support raise StreamlitAPIException
# when it runs after other st.* output (the banner below, the Detector
# import error, the weights spinner/warning emitted while loading, ...).
st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")

# Build marker so it is visible which version of the app is being served.
st.write("### ✅ Voice Guard Streamlit — env-only v4 (no st.secrets)")

# ---- import Detector from app/ or src/ ----
# Candidate modules are tried in order; the first one exposing `Detector` wins.
Detector, _last_err = None, None
_import_errors = []  # (module name, exception) for every failed candidate
for mod in ["app.inference_wav2vec", "app.inference",
            "src.inference_wav2vec", "src.inference"]:
    try:
        Detector = __import__(mod, fromlist=["Detector"]).Detector
        break
    except Exception as e:  # missing module, missing attribute, or module-level failure
        _last_err = e
        _import_errors.append((mod, e))
if Detector is None:
    # List every candidate's failure: the last error alone is usually just
    # "No module named 'src'", which hides the real cause in an app/ module.
    _details = "\n".join(f"- `{m}`: {err!r}" for m, err in _import_errors)
    st.error(f"Could not import Detector from app/ or src/. Last error: {_last_err}\n\n{_details}")
    st.stop()


# ---- ENV config only ----
def cfg(name: str, default: str = "") -> str:
    """Return environment variable *name*, or *default* when it is unset or empty."""
    # os.getenv yields a str or None, so `or` falls back on both None and "".
    return os.getenv(name) or default


def ensure_weights() -> str:
    """Make sure the model weights file is present locally and return its path.

    Reads ``MODEL_WEIGHTS_PATH`` (default
    ``app/models/weights/wav2vec2_classifier.pth``) and ``MODEL_WEIGHTS_URL``.
    If the file is missing and a URL is configured, it is downloaded to that
    path. If both are missing, a warning is shown and the path is returned
    anyway, leaving it to the Detector to deal with the absent file.

    Raises:
        Any exception from ``urllib.request.urlretrieve`` / ``os.replace``
        when the download fails; no partial file is left at the path.
    """
    wp = cfg("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth")
    url = cfg("MODEL_WEIGHTS_URL", "")
    dest = pathlib.Path(wp)
    if not dest.exists() and url:
        dest.parent.mkdir(parents=True, exist_ok=True)
        # Download to a sibling temp file and rename only on success. Writing
        # straight to `dest` would leave a truncated file after an interrupted
        # download, and the exists() check above would then never retry.
        tmp = dest.with_name(dest.name + ".part")
        with st.spinner(f"Downloading model weights to {dest} …"):
            try:
                urllib.request.urlretrieve(url, str(tmp))
                os.replace(tmp, dest)
            except BaseException:
                tmp.unlink(missing_ok=True)
                raise
            st.toast("Weights downloaded", icon="✅")
    if not dest.exists() and not url:
        st.warning(
            f"Model weights not found at '{wp}'. "
            "Upload the .pth there OR set MODEL_WEIGHTS_URL in Settings → Variables & secrets."
        )
    return str(dest)


@st.cache_resource(show_spinner=True)
def load_detector():
    """Construct the Detector; st.cache_resource reuses the instance on later calls."""
    return Detector(weights_path=ensure_weights())


det = load_detector()


# ---- helpers ----
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Render an importance map as a magma-coloured PNG and return its bytes.

    NaNs become 0 and values are clipped to [0, 1], so the map is expected to
    be pre-normalised. A 0-d input becomes a 1x1 image. Inputs with more than
    two axes have singleton axes squeezed out, e.g. an assumed (1, F, T)
    becomes (F, T) (layout not confirmed here).
    """
    cam = np.nan_to_num(np.asarray(cam, dtype=np.float32), nan=0.0)
    if cam.ndim > 2:
        cam = np.squeeze(cam)
    # atleast_2d: the colormap returns a plain tuple for 0-d input, and PIL
    # needs an (H, W, 3) array.
    cam = np.atleast_2d(np.clip(cam, 0.0, 1.0))
    rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(rgb).save(buf, "PNG")
    return buf.getvalue()


def analyze(wav_bytes: bytes, source_hint: str):
    """Run the cached detector on raw audio bytes; return (probabilities, explanation)."""
    proba = det.predict_proba(wav_bytes, source_hint=source_hint)
    exp = det.explain(wav_bytes, source_hint=source_hint)
    return proba, exp


# ---- UI ----
st.title("🛡️ Voice Guard — Human vs AI Speech")

left, right = st.columns([1, 2], gap="large")
with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])
    # Both tab bodies execute on every run. The upload tab is processed last,
    # so an uploaded file takes precedence over a microphone recording.
    wav_bytes, source_hint = None, None

    with tab_rec:
        st.caption("Record ~3–7 s. If mic fails, use Upload.")
        try:
            from audio_recorder_streamlit import audio_recorder
            audio = audio_recorder(text="Record",
                                   recording_color="#ff6a00",
                                   neutral_color="#2b2b2b",
                                   icon_size="2x")
            if audio:
                wav_bytes, source_hint = audio, "microphone"
                st.audio(wav_bytes, format="audio/wav")
        except Exception:
            # Optional component missing or failing: degrade to upload-only.
            st.info("Recorder not available—use Upload tab.")

    with tab_up:
        f = st.file_uploader("Upload wav/mp3/m4a/aac", type=["wav","mp3","m4a","aac"])
        if f:
            # getvalue() returns the whole buffer regardless of the read position.
            wav_bytes, source_hint = f.getvalue(), "upload"
            # Use the browser-reported MIME type so mp3/m4a/aac are not served as WAV.
            st.audio(wav_bytes, format=f.type or "audio/wav")

    st.markdown("---")
    run = st.button("🔍 Analyze", type="primary", use_container_width=True,
                    disabled=wav_bytes is None)

with right:
    st.subheader("Results")
    if run and wav_bytes:
        try:
            with st.spinner("Analyzing…"):
                proba, exp = analyze(wav_bytes, source_hint or "auto")
            ph = float(proba.get("human", 0.0))
            pa = float(proba.get("ai", 0.0))
            label = (proba.get("label", "human") or "human").upper()
            thr = float(proba.get("threshold", 0.5))
            rule = proba.get("decision", "threshold")
            thr_src = proba.get("threshold_source", "—")
            rscore = proba.get("replay_score", None)

            c1, c2, c3 = st.columns(3)
            with c1:
                st.metric("Human", f"{ph*100:.1f}%")
            with c2:
                st.metric("AI", f"{pa*100:.1f}%")
            with c3:
                color = "#22c55e" if label == "HUMAN" else "#fb7185"
                st.markdown(f"**Final Label:** <span style='color:{color}'>{label}</span>", unsafe_allow_html=True)
                st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={'—' if rscore is None else f'{float(rscore):.2f}'}")

            st.markdown("##### Explanation Heatmap")
            cam = np.asarray(exp.get("cam"), dtype=np.float32)
            # NOTE(review): newer Streamlit deprecates use_column_width in favour
            # of use_container_width; switch once the minimum supported version allows.
            st.image(cam_to_png_bytes(cam), caption="Spectrogram importance", use_column_width=True)

            with st.expander("Raw JSON (debug)"):
                st.json({"proba": proba, "explain": {"cam_shape": list(cam.shape)}})
        except Exception as e:
            # Top-level UI boundary: surface any analysis failure instead of crashing the page.
            st.error(f"Analyze failed: {e}")

st.caption("Upload 3–7s clips for the most reliable experience across browsers.")