Spaces:
Sleeping
Sleeping
File size: 4,866 Bytes
6ecef58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# streamlit_app.py
import os, io, base64, urllib.request, pathlib
import numpy as np
import streamlit as st
from PIL import Image
from matplotlib import cm
# ------- wiring to your detector -------
# Backends are tried in order: the wav2vec2 detector is preferred and the CNN
# detector is the fallback. The first module that imports and exposes a
# `Detector` attribute wins.
BACKENDS_TRY = ["app.inference_wav2vec", "app.inference"]
Detector = None
err = None  # last import failure (kept as a module-level name, as before)
import_errors = {}  # backend module name -> exception raised while importing it
for mod in BACKENDS_TRY:
    try:
        Detector = __import__(mod, fromlist=["Detector"]).Detector
        break
    except Exception as e:
        # Deliberately broad: any failure while importing a backend simply
        # means "try the next one". The error is kept so it can be reported.
        err = e
        import_errors[mod] = e
if Detector is None:
    st.error("Could not import Detector from app/. Make sure your repo contains app/inference_wav2vec.py (or app/inference.py).")
    # Show the reason each backend failed so the problem can be diagnosed
    # from the page itself instead of being silently discarded.
    for failed_mod, failure in import_errors.items():
        st.text(f"{failed_mod}: {type(failure).__name__}: {failure}")
    st.stop()
# ------- config / weights -------
def _setting(name: str, default: str) -> str:
    """Return setting *name* from the environment, else Streamlit secrets, else *default*."""
    value = os.environ.get(name)
    if value is not None:
        # The environment wins; secrets are not consulted at all in this case.
        return value
    try:
        return st.secrets.get(name, default)
    except FileNotFoundError:
        # NOTE(review): Streamlit is expected to raise a FileNotFoundError-derived
        # error when no secrets file is configured — confirm for the deployed
        # version. Falling back keeps the app usable without a secrets file.
        return default


def ensure_weights():
    """Return the model weights path, downloading the file first if needed.

    The path comes from ``MODEL_WEIGHTS_PATH`` and the optional download URL
    from ``MODEL_WEIGHTS_URL`` (environment first, then Streamlit secrets).
    A download happens only when a URL is configured and the path does not
    exist yet.

    Returns:
        The weights path (whether or not the file exists).

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when the download fails.
    """
    wp = _setting("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth")
    url = _setting("MODEL_WEIGHTS_URL", "")
    if url and not os.path.exists(wp):
        pathlib.Path(wp).parent.mkdir(parents=True, exist_ok=True)
        # Download to a side file and move it into place only on success, so
        # an interrupted download never leaves a truncated file at `wp` that
        # later runs would treat as valid weights.
        tmp = f"{wp}.part"
        with st.spinner("Downloading model weights…"):
            try:
                urllib.request.urlretrieve(url, tmp)
                os.replace(tmp, wp)
            finally:
                # After a successful replace the side file is gone; this only
                # cleans up after a failed download.
                if os.path.exists(tmp):
                    os.remove(tmp)
    return wp
@st.cache_resource
def load_detector():
    """Resolve the weights path and build a ``Detector`` from it.

    Wrapped in ``st.cache_resource`` so the constructed detector can be
    reused instead of being rebuilt on every call.
    """
    return Detector(weights_path=ensure_weights())


# Module-level detector instance used by the analysis helper.
det = load_detector()
# ------- helpers -------
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Colour *cam* with the magma colormap and return the image as PNG bytes.

    Values are converted to float32 and clipped to [0, 1] before colouring;
    the alpha channel produced by the colormap is dropped.
    """
    values = np.clip(np.array(cam, dtype=np.float32), 0.0, 1.0)
    rgba = cm.magma(values)
    pixels = (rgba[..., :3] * 255).astype(np.uint8)
    out = io.BytesIO()
    Image.fromarray(pixels).save(out, format="PNG")
    return out.getvalue()
def analyze(wav_bytes: bytes, source_hint: str):
    """Run the module-level detector on *wav_bytes*.

    Calls ``predict_proba`` first and ``explain`` second, both with the same
    ``source_hint``, and returns their results as a ``(proba, exp)`` pair.
    """
    hint = {"source_hint": source_hint}
    scores = det.predict_proba(wav_bytes, **hint)
    explanation = det.explain(wav_bytes, **hint)
    return scores, explanation
# ------- UI -------
# Two-column page: audio input (microphone or upload) on the left, analysis
# results on the right.
st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
st.title("🛡️ Voice Guard — Human vs AI Speech (Streamlit)")
left, right = st.columns([1, 2])

with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])
    wav_bytes = None
    source_hint = None

    with tab_rec:
        st.caption("If the recorder component fails on your browser, use Upload.")
        try:
            # light, zero-config recorder component
            from audio_recorder_streamlit import audio_recorder
            audio = audio_recorder(
                text="Record",
                recording_color="#ff6a00",
                neutral_color="#2b2b2b",
                icon_size="2x",
            )
            if audio:
                wav_bytes = audio  # component returns WAV bytes
                source_hint = "microphone"
                st.audio(wav_bytes, format="audio/wav")
        except Exception:
            # Best-effort: the recorder is optional, so any failure here just
            # points the user at the Upload tab.
            st.info("Recorder component not available; please use the Upload tab.")

    with tab_up:
        # The label lists every accepted type (aac was accepted but not mentioned).
        f = st.file_uploader("Upload an audio file (wav/mp3/m4a/aac)", type=["wav", "mp3", "m4a", "aac"])
        if f is not None:
            # getvalue() returns the whole buffer regardless of the current
            # stream position, unlike read().
            uploaded = f.getvalue()
            if uploaded:
                # A non-empty upload takes precedence over a microphone recording.
                wav_bytes = uploaded
                source_hint = "upload"
                st.audio(wav_bytes)
            else:
                # An empty file must not replace a valid recording or enable
                # an Analyze button that would then do nothing.
                st.warning("The uploaded file is empty.")

    st.markdown("---")
    run = st.button("🔍 Analyze", use_container_width=True, type="primary", disabled=not wav_bytes)

with right:
    st.subheader("Results")
    placeholder = st.empty()

    if run and wav_bytes:
        result = None
        try:
            with st.spinner("Analyzing…"):
                result = analyze(wav_bytes, source_hint or "auto")
        except Exception as exc:
            # UI boundary: report the failure in the results column and keep
            # rendering the rest of the page instead of aborting the script.
            placeholder.error("Analysis failed. The audio could not be processed.")
            with st.expander("Error details"):
                st.exception(exc)

        if result is not None:
            proba, exp = result
            ph = proba["human"]
            pa = proba["ai"]
            label = proba["label"].upper()
            # Optional fields: fall back to display defaults when absent.
            thr = proba.get("threshold", 0.5)
            rule = proba.get("decision", "threshold")
            rscore = proba.get("replay_score", None)
            thr_src = proba.get("threshold_source", "—")

            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("Human", f"{ph*100:.1f} %")
            with col2:
                st.metric("AI", f"{pa*100:.1f} %")
            with col3:
                # Green for HUMAN, red for any other label.
                color = "#22c55e" if label == "HUMAN" else "#fb7185"
                st.markdown(f"**Final Label:** <span style='color:{color}'>{label}</span>", unsafe_allow_html=True)
                st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={('-' if rscore is None else f'{rscore:.2f}')}")

            st.markdown("##### Explanation Heatmap")
            cam = np.array(exp["cam"], dtype=np.float32)
            # NOTE(review): `use_column_width` may be deprecated in newer
            # Streamlit releases — confirm against the pinned version.
            st.image(cam_to_png_bytes(cam), caption="Spectrogram importance", use_column_width=True)
            st.markdown("---")
            with st.expander("Raw JSON (debug)"):
                st.json({"proba": proba, "explain": {"cam_shape": list(cam.shape)}})

st.caption("Tip: If the mic recorder fails, upload a short 3–7s clip instead.")
|