# streamlit_app.py
import base64
import importlib
import io
import os
import pathlib
import urllib.request

import numpy as np
import streamlit as st
from PIL import Image
from matplotlib import cm

# ------- wiring to your detector -------
# We prefer the wav2vec2 detector; fall back to the CNN one if needed.
BACKENDS_TRY = ["app.inference_wav2vec", "app.inference"]
Detector = None
err = None  # last import failure, surfaced to the user if all backends fail
for mod in BACKENDS_TRY:
    try:
        # importlib.import_module is the idiomatic way to import by dotted name
        # (the previous __import__(..., fromlist=[...]) form did the same thing).
        Detector = importlib.import_module(mod).Detector
        break
    except Exception as e:  # any import failure just means "try the next backend"
        err = e
if Detector is None:
    # Include the last error so the failure is actually diagnosable from the UI.
    st.error(
        "Could not import Detector from app/. Make sure your repo contains "
        f"app/inference_wav2vec.py (or app/inference.py). Last error: {err}"
    )
    st.stop()

# ------- config / weights -------
def _secret(key: str, default: str = "") -> str:
    """Safely read *key* from st.secrets, returning *default* when unavailable.

    Accessing st.secrets raises (no secrets.toml present) on deployments that
    configure everything via environment variables, so swallow that failure
    and fall back to *default* instead of crashing the app.
    """
    try:
        return st.secrets.get(key, default)
    except Exception:
        # No secrets file (or unreadable) — treat the secret as "not set".
        return default


def ensure_weights() -> str:
    """Resolve the model weights path, downloading them if a URL is configured.

    Precedence for each setting: environment variable, then Streamlit secret,
    then the hard-coded repo default. If MODEL_WEIGHTS_URL is set and the
    weights file does not exist locally, it is downloaded to that path.

    Returns:
        Local filesystem path to the weights file (may still be missing if no
        URL was configured — the Detector is expected to report that).
    """
    wp  = os.environ.get(
        "MODEL_WEIGHTS_PATH",
        _secret("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth"),
    )
    url = os.environ.get("MODEL_WEIGHTS_URL", _secret("MODEL_WEIGHTS_URL", ""))

    if url and not os.path.exists(wp):
        pathlib.Path(wp).parent.mkdir(parents=True, exist_ok=True)
        with st.spinner("Downloading model weights…"):
            urllib.request.urlretrieve(url, wp)
    return wp

@st.cache_resource
def load_detector():
    """Construct the Detector once per server process (cached by Streamlit)."""
    # Make sure the weights file exists (downloading if configured) before
    # handing the path to the Detector.
    return Detector(weights_path=ensure_weights())


det = load_detector()

# ------- helpers -------
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Render a class-activation map as PNG bytes using the magma colormap.

    Values are clipped into [0, 1] before colormapping — assumes the CAM is
    already normalized to that range (TODO confirm with the detector's
    explain() output).
    """
    heat = np.clip(np.array(cam, dtype=np.float32), 0.0, 1.0)
    # cm.magma returns RGBA floats in [0, 1]; drop alpha and scale to uint8.
    rgb8 = (cm.magma(heat)[..., :3] * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(rgb8).save(buf, format="PNG")
    return buf.getvalue()

def analyze(wav_bytes: bytes, source_hint: str):
    """Run the detector on raw audio bytes.

    Returns a (proba, exp) pair: the probability/decision dict from
    predict_proba and the explanation dict from explain.
    """
    return (
        det.predict_proba(wav_bytes, source_hint=source_hint),
        det.explain(wav_bytes, source_hint=source_hint),
    )

# ------- UI -------
# NOTE(review): st.set_page_config should be the first Streamlit command in a
# script, but load_detector()/st.error above can emit commands first (e.g. the
# download spinner on a cold cache) — confirm this doesn't warn in deployment.
st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
st.title("🛡️ Voice Guard — Human vs AI Speech (Streamlit)")

# Narrow input column on the left, wide results column on the right.
left, right = st.columns([1,2])

with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])

    # Populated by whichever input tab the user exercises; read by the
    # results column below.
    wav_bytes = None
    source_hint = None

    with tab_rec:
        st.caption("If the recorder component fails on your browser, use Upload.")
        try:
            # light, zero-config recorder component
            from audio_recorder_streamlit import audio_recorder
            audio = audio_recorder(
                text="Record",
                recording_color="#ff6a00",
                neutral_color="#2b2b2b",
                icon_size="2x",
            )
            if audio:
                wav_bytes = audio  # component returns WAV bytes
                source_hint = "microphone"
                st.audio(wav_bytes, format="audio/wav")
        except Exception:
            # Missing/broken optional dependency — degrade to upload-only.
            st.info("Recorder component not available; please use the Upload tab.")

    with tab_up:
        f = st.file_uploader("Upload an audio file (wav/mp3/m4a)", type=["wav","mp3","m4a","aac"])
        if f is not None:
            wav_bytes = f.read()
            source_hint = "upload"
            st.audio(wav_bytes)

    st.markdown("---")
    # Button stays disabled until some audio has been provided.
    run = st.button("🔍 Analyze", use_container_width=True, type="primary", disabled=wav_bytes is None)

with right:
    st.subheader("Results")
    placeholder = st.empty()

    # Only analyze on an explicit button press with audio available.
    if run and wav_bytes:
        with st.spinner("Analyzing…"):
            proba, exp = analyze(wav_bytes, source_hint or "auto")

        # Required keys from the detector: "human", "ai", "label".
        ph = proba["human"]; pa = proba["ai"]
        label = proba["label"].upper()
        # Optional diagnostics with safe fallbacks.
        thr = proba.get("threshold", 0.5)
        rule = proba.get("decision", "threshold")
        rscore = proba.get("replay_score", None)
        thr_src = proba.get("threshold_source", "—")

        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Human", f"{ph*100:.1f} %")
        with col2:
            st.metric("AI", f"{pa*100:.1f} %")
        with col3:
            # Green for HUMAN, red for anything else (AI).
            color = "#22c55e" if label=="HUMAN" else "#fb7185"
            st.markdown(f"**Final Label:** <span style='color:{color}'>{label}</span>", unsafe_allow_html=True)
            st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={('-' if rscore is None else f'{rscore:.2f}')}")

        st.markdown("##### Explanation Heatmap")
        cam = np.array(exp["cam"], dtype=np.float32)
        # use_container_width replaces the deprecated use_column_width and
        # matches the Analyze button's sizing parameter above.
        st.image(cam_to_png_bytes(cam), caption="Spectrogram importance", use_container_width=True)

        st.markdown("---")
        with st.expander("Raw JSON (debug)"):
            st.json({"proba": proba, "explain": {"cam_shape": list(cam.shape)}})

st.caption("Tip: If the mic recorder fails, upload a short 3–7s clip instead.")