Voice-guard / streamlit_app.py
varunkul's picture
Upload 8 files
6ecef58 verified
raw
history blame
4.87 kB
# streamlit_app.py
import os, io, base64, urllib.request, pathlib
import numpy as np
import streamlit as st
from PIL import Image
from matplotlib import cm
# ------- wiring to your detector -------
# We prefer the wav2vec2 detector; fall back to the CNN one if needed.
BACKENDS_TRY = ["app.inference_wav2vec", "app.inference"]
Detector = None
err = None  # last import error seen (kept for backward compatibility)
# Why each backend failed, so a total failure can actually be diagnosed.
import_errors = {}
for mod in BACKENDS_TRY:
    try:
        # Broad catch on purpose: a backend that fails to import for any
        # reason (missing optional dependency, error at import time) should
        # fall through to the next candidate instead of killing the app.
        Detector = __import__(mod, fromlist=["Detector"]).Detector
        break
    except Exception as e:
        err = e
        import_errors[mod] = e
if Detector is None:
    st.error("Could not import Detector from app/. Make sure your repo contains app/inference_wav2vec.py (or app/inference.py).")
    # Without the underlying errors the fallback hides the real cause.
    st.code("\n".join(f"{name}: {type(exc).__name__}: {exc}" for name, exc in import_errors.items()))
    st.stop()
# ------- config / weights -------
def ensure_weights():
wp = os.environ.get("MODEL_WEIGHTS_PATH", st.secrets.get("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth"))
url = os.environ.get("MODEL_WEIGHTS_URL", st.secrets.get("MODEL_WEIGHTS_URL", ""))
if url and not os.path.exists(wp):
pathlib.Path(wp).parent.mkdir(parents=True, exist_ok=True)
with st.spinner("Downloading model weights…"):
urllib.request.urlretrieve(url, wp)
return wp
@st.cache_resource
def load_detector():
    """Construct the Detector once and share it across script reruns.

    ``st.cache_resource`` memoizes the returned object, so the weights are
    resolved (and downloaded when configured) only on the first call.
    """
    weights_file = ensure_weights()
    return Detector(weights_path=weights_file)


# Shared detector instance used by analyze().
det = load_detector()
# ------- helpers -------
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Render an importance map as a PNG colored with matplotlib's *magma* colormap.

    Args:
        cam: array-like of importance scores, expected in ``[0, 1]``; values
            outside that range are clipped (not rescaled). A map carrying
            extra leading axes of size 1 (e.g. ``(1, H, W)``) is reduced to
            its trailing 2-D plane so it can still be encoded as an image.

    Returns:
        The PNG-encoded RGB image as bytes (the colormap's alpha is dropped).
    """
    cam = np.asarray(cam, dtype=np.float32)
    # Leading singleton axes would make the colormapped array 4-D, which
    # Image.fromarray cannot encode; collapse them to a plain (H, W) map.
    if cam.ndim > 2 and all(d == 1 for d in cam.shape[:-2]):
        cam = cam.reshape(cam.shape[-2:])
    cam = np.clip(cam, 0.0, 1.0)
    rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(rgb).save(buf, format="PNG")
    return buf.getvalue()
def analyze(wav_bytes: bytes, source_hint: str):
    """Score a clip and explain the decision with the shared detector.

    Args:
        wav_bytes: raw audio bytes as recorded or uploaded.
        source_hint: origin label forwarded unchanged to the detector
            (the UI passes "microphone", "upload" or "auto").

    Returns:
        ``(proba, exp)``: the results of ``det.predict_proba`` and
        ``det.explain`` for the same audio and hint.
    """
    hint = {"source_hint": source_hint}
    proba = det.predict_proba(wav_bytes, **hint)
    return proba, det.explain(wav_bytes, **hint)
# ------- UI -------
def _render_results(proba, exp):
    """Render the verdict metrics, the explanation heatmap and the debug JSON.

    *proba* is the dict returned by ``analyze`` (keys ``human``, ``ai``,
    ``label`` and optionally ``threshold``, ``decision``, ``replay_score``,
    ``threshold_source``); *exp* must provide a ``cam`` entry.
    """
    ph = proba["human"]
    pa = proba["ai"]
    label = proba["label"].upper()
    thr = proba.get("threshold", 0.5)
    rule = proba.get("decision", "threshold")
    rscore = proba.get("replay_score", None)
    thr_src = proba.get("threshold_source", "—")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Human", f"{ph*100:.1f} %")
    with col2:
        st.metric("AI", f"{pa*100:.1f} %")
    with col3:
        color = "#22c55e" if label == "HUMAN" else "#fb7185"
        st.markdown(f"**Final Label:** <span style='color:{color}'>{label}</span>", unsafe_allow_html=True)
        st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={('-' if rscore is None else f'{rscore:.2f}')}")

    st.markdown("##### Explanation Heatmap")
    cam = np.array(exp["cam"], dtype=np.float32)
    # NOTE(review): use_column_width is deprecated in newer Streamlit releases
    # (use_container_width); switch once the pinned version is confirmed.
    st.image(cam_to_png_bytes(cam), caption="Spectrogram importance", use_column_width=True)
    st.markdown("---")
    with st.expander("Raw JSON (debug)"):
        st.json({"proba": proba, "explain": {"cam_shape": list(cam.shape)}})


# NOTE(review): load_detector() runs before this call; older Streamlit releases
# require set_page_config to be the first Streamlit command — confirm against
# the pinned Streamlit version.
st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
st.title("🛡️ Voice Guard — Human vs AI Speech (Streamlit)")
left, right = st.columns([1, 2])

with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])
    # Audio to analyze and its origin. The Upload tab is evaluated second, so
    # an uploaded file overrides a microphone recording when both exist.
    wav_bytes = None
    source_hint = None

    with tab_rec:
        st.caption("If the recorder component fails on your browser, use Upload.")
        try:
            # light, zero-config recorder component
            from audio_recorder_streamlit import audio_recorder
            audio = audio_recorder(
                text="Record",
                recording_color="#ff6a00",
                neutral_color="#2b2b2b",
                icon_size="2x",
            )
        except Exception:
            # Only the optional component is guarded; errors elsewhere surface.
            audio = None
            st.info("Recorder component not available; please use the Upload tab.")
        if audio:
            wav_bytes = audio  # component returns WAV bytes
            source_hint = "microphone"
            st.audio(wav_bytes, format="audio/wav")

    with tab_up:
        f = st.file_uploader("Upload an audio file (wav/mp3/m4a)", type=["wav", "mp3", "m4a", "aac"])
        if f is not None:
            # getvalue() returns the whole buffer regardless of read position.
            wav_bytes = f.getvalue()
            source_hint = "upload"
            # Use the uploaded file's MIME type so mp3/m4a/aac are not
            # announced to the browser as WAV.
            st.audio(wav_bytes, format=f.type or "audio/wav")

    st.markdown("---")
    run = st.button("🔍 Analyze", use_container_width=True, type="primary", disabled=wav_bytes is None)

with right:
    st.subheader("Results")
    if run and wav_bytes:
        try:
            with st.spinner("Analyzing…"):
                proba, exp = analyze(wav_bytes, source_hint or "auto")
        except Exception as exc:
            # Undecodable audio or a model failure: report it in the page
            # instead of replacing the app with a raw traceback.
            st.error(f"Analysis failed: {exc}")
        else:
            _render_results(proba, exp)

st.caption("Tip: If the mic recorder fails, upload a short 3–7s clip instead.")