varunkul commited on
Commit
09f5d1d
·
verified ·
1 Parent(s): 0f4a497

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +187 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,190 @@
1
- import altair as alt
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # streamlit_app.py
2
+ # ------------------------------------------------------------
3
+ # Voice Guard (Streamlit) - env-only config (no st.secrets required)
4
+ # - Tries app/ then src/ for the Detector
5
+ # - Accepts mic (best-effort) or upload
6
+ # - Shows probabilities, decision details, and CAM heatmap
7
+ # - If MODEL_WEIGHTS_URL is set, downloads weights on boot when missing
8
+ # ------------------------------------------------------------
9
+
10
+ import os
11
+ import io
12
+ import pathlib
13
+ import urllib.request
14
  import numpy as np
 
15
  import streamlit as st
16
+ from PIL import Image
17
+ from matplotlib import cm
18
+
19
+ # --------------------------- Import Detector ---------------------------
20
+ Detector = None
21
+ _last_err = None
22
+ for mod in [
23
+ "app.inference_wav2vec",
24
+ "app.inference",
25
+ "src.inference_wav2vec",
26
+ "src.inference",
27
+ ]:
28
+ try:
29
+ Detector = __import__(mod, fromlist=["Detector"]).Detector
30
+ break
31
+ except Exception as e:
32
+ _last_err = e
33
+
34
+ if Detector is None:
35
+ st.error(
36
+ "Could not import Detector from app/ or src/. "
37
+ "Please include app/inference_wav2vec.py (preferred) or app/inference.py. "
38
+ f"Last import error: {_last_err}"
39
+ )
40
+ st.stop()
41
+
42
+ # ----------------------- Weights: ensure on disk -----------------------
43
+ def cfg(name: str, default: str = "") -> str:
44
+ """Read from environment only (HF Variables & Secrets are env)."""
45
+ val = os.getenv(name)
46
+ return val if val not in (None, "") else default
47
+
48
+ def ensure_weights() -> str:
49
+ """
50
+ Ensure model weights exist at MODEL_WEIGHTS_PATH.
51
+ If missing and MODEL_WEIGHTS_URL is set, download them.
52
+ """
53
+ default_path = "app/models/weights/wav2vec2_classifier.pth"
54
+ wp = cfg("MODEL_WEIGHTS_PATH", default_path)
55
+ url = cfg("MODEL_WEIGHTS_URL", "")
56
+
57
+ dest = pathlib.Path(wp)
58
+ if not dest.exists():
59
+ if url:
60
+ dest.parent.mkdir(parents=True, exist_ok=True)
61
+ with st.spinner(f"Downloading model weights to {dest} …"):
62
+ urllib.request.urlretrieve(url, str(dest))
63
+ st.toast("Weights downloaded", icon="✅")
64
+ else:
65
+ st.warning(
66
+ f"Model weights not found at '{wp}'. "
67
+ "Upload the .pth file to that path in the repo OR set MODEL_WEIGHTS_URL in "
68
+ "Settings → Variables & secrets so the app can download them."
69
+ )
70
+ return str(dest)
71
+
72
+ @st.cache_resource(show_spinner=True)
73
+ def load_detector() -> "Detector":
74
+ weights_path = ensure_weights()
75
+ det = Detector(weights_path=weights_path)
76
+ return det
77
+
78
+ det = load_detector()
79
+
80
+ # ----------------------------- Utilities -------------------------------
81
+ def cam_to_png_bytes(cam: np.ndarray) -> bytes:
82
+ """Map [H,W] float array (0..1) to magma RGB PNG bytes."""
83
+ cam = np.asarray(cam, dtype=np.float32)
84
+ cam = np.nan_to_num(cam, nan=0.0)
85
+ cam = np.clip(cam, 0.0, 1.0)
86
+ rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
87
+ img = Image.fromarray(rgb)
88
+ bio = io.BytesIO()
89
+ img.save(bio, format="PNG")
90
+ return bio.getvalue()
91
+
92
+ def analyze(wav_bytes: bytes, source_hint: str):
93
+ """Call detector predict + explain; returns (proba_dict, explain_dict)."""
94
+ proba = det.predict_proba(wav_bytes, source_hint=source_hint)
95
+ exp = det.explain(wav_bytes, source_hint=source_hint)
96
+ return proba, exp
97
+
98
+ # ------------------------------- UI -----------------------------------
99
+ st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
100
+ st.title("🛡️ Voice Guard — Human vs AI Speech")
101
+
102
+ left, right = st.columns([1, 2], gap="large")
103
+
104
+ with left:
105
+ st.subheader("Input")
106
+ tabs = st.tabs(["🎙️ Microphone", "📁 Upload"])
107
+
108
+ wav_bytes = None
109
+ source_hint = None
110
+
111
+ # Microphone tab (best effort; if not supported, use Upload)
112
+ with tabs[0]:
113
+ st.caption("Record ~3–7 seconds. If mic fails in your browser, use Upload.")
114
+ try:
115
+ from audio_recorder_streamlit import audio_recorder
116
+ audio = audio_recorder(
117
+ text="Record",
118
+ recording_color="#ff6a00",
119
+ neutral_color="#2b2b2b",
120
+ icon_size="2x",
121
+ )
122
+ if audio:
123
+ wav_bytes = audio # component returns WAV bytes
124
+ source_hint = "microphone"
125
+ st.audio(wav_bytes, format="audio/wav")
126
+ except Exception:
127
+ st.info("Recorder component not available here—please use the Upload tab.")
128
+
129
+ # Upload tab (most reliable across platforms)
130
+ with tabs[1]:
131
+ f = st.file_uploader(
132
+ "Upload an audio file (wav/mp3/m4a/aac)",
133
+ type=["wav", "mp3", "m4a", "aac"],
134
+ )
135
+ if f is not None:
136
+ wav_bytes = f.read()
137
+ source_hint = "upload"
138
+ st.audio(wav_bytes)
139
+
140
+ st.markdown("---")
141
+ run = st.button(
142
+ "🔍 Analyze", type="primary", use_container_width=True, disabled=wav_bytes is None
143
+ )
144
+
145
+ with right:
146
+ st.subheader("Results")
147
+
148
+ if run and wav_bytes:
149
+ try:
150
+ with st.spinner("Analyzing…"):
151
+ proba, exp = analyze(wav_bytes, source_hint or "auto")
152
+
153
+ ph = float(proba.get("human", 0.0))
154
+ pa = float(proba.get("ai", 0.0))
155
+ label = (proba.get("label", "human") or "human").upper()
156
+ thr = float(proba.get("threshold", 0.5))
157
+ rule = proba.get("decision", "threshold")
158
+ thr_src = proba.get("threshold_source", "—")
159
+ rscore = proba.get("replay_score", None)
160
+
161
+ c1, c2, c3 = st.columns(3)
162
+ with c1:
163
+ st.metric("Human", f"{ph*100:.1f}%")
164
+ with c2:
165
+ st.metric("AI", f"{pa*100:.1f}%")
166
+ with c3:
167
+ color = "#22c55e" if label == "HUMAN" else "#fb7185"
168
+ st.markdown(
169
+ f"**Final Label:** <span style='color:{color}'>{label}</span>",
170
+ unsafe_allow_html=True,
171
+ )
172
+ st.caption(
173
+ f"thr({thr_src})={thr:.2f} • rule={rule} • replay={'—' if rscore is None else f'{float(rscore):.2f}'}"
174
+ )
175
+
176
+ st.markdown("##### Explanation Heatmap")
177
+ cam = np.asarray(exp.get("cam"), dtype=np.float32)
178
+ st.image(
179
+ cam_to_png_bytes(cam),
180
+ caption="Spectrogram importance",
181
+ use_column_width=True,
182
+ )
183
+
184
+ with st.expander("Raw JSON (debug)"):
185
+ st.json({"proba": proba, "explain": {"cam_shape": list(cam.shape)}})
186
+
187
+ except Exception as e:
188
+ st.error(f"Analyze failed: {e}")
189
 
190
+ st.caption("Tip: Uploading a short 3–7s clip is the most reliable across browsers.")