# MoulSot / app.py
# Gradio demo for the Moul-Sout Darija ASR model.
import gradio as gr
import torchaudio
from transformers import pipeline
import soundfile as sf
import torch
# Load the Moul-Sout ASR model
asr_pipeline = pipeline("automatic-speech-recognition", model="01Yassine/moulsot_v0.2_1000")

# Newer transformers versions reject forced_decoder_ids in the generation config;
# move its value aside and clear the deprecated field.
asr_pipeline.model.generation_config.input_ids = asr_pipeline.model.generation_config.forced_decoder_ids
asr_pipeline.model.generation_config.forced_decoder_ids = None
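
# Optional sketch (not in the original app): pin the pipeline to a GPU when one
# is available; `device` is a standard transformers pipeline argument.
#   asr_pipeline = pipeline("automatic-speech-recognition",
#                           model="01Yassine/moulsot_v0.2_1000",
#                           device=0 if torch.cuda.is_available() else -1)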

def load_audio(audio_path):
    """Robustly load any audio file into (waveform, sr)."""
    try:
        waveform, sr = torchaudio.load(audio_path)
    except Exception:
        # Fall back to soundfile for formats the torchaudio backend cannot read
        data, sr = sf.read(audio_path)
        waveform = torch.tensor(data, dtype=torch.float32).T
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    return waveform, sr

def ensure_mono_16k(audio_path):
    """Convert audio to mono, 16 kHz."""
    waveform, sr = load_audio(audio_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        waveform = resampler(waveform)
        sr = 16000
    return waveform, sr

def trim_leading_silence(waveform, sr, keep_ms=100, threshold=0.01):
    """Trim leading silence, keeping at most keep_ms ms before the first non-silent sample."""
    energy = waveform.abs().mean(dim=0)
    non_silence_idx = (energy > threshold).nonzero(as_tuple=True)[0]
    if len(non_silence_idx) == 0:
        return waveform  # all silence
    first_non_silence = non_silence_idx[0].item()
    keep_samples = int(sr * (keep_ms / 1000.0))
    start = max(0, first_non_silence - keep_samples)
    return waveform[:, start:]
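
# A minimal sketch of the trimming behaviour on synthetic data (assumed
# example, not part of the original app):
#   sr = 16000
#   silence = torch.zeros(1, sr)                              # 1 s of silence
#   t = torch.linspace(0, 2 * torch.pi * 440, sr)             # 1 s, 440 Hz
#   tone = torch.sin(t).unsqueeze(0)
#   trimmed = trim_leading_silence(torch.cat([silence, tone], dim=1), sr)
#   assert trimmed.shape[1] <= tone.shape[1] + int(0.1 * sr)  # at most 100 ms of silence kept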

def preprocess_audio(audio_path):
    waveform, sr = ensure_mono_16k(audio_path)
    waveform = trim_leading_silence(waveform, sr, keep_ms=100, threshold=0.01)
    tmp_path = "/tmp/processed_trimmed.wav"
    torchaudio.save(tmp_path, waveform, sr)
    return tmp_path
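
# Note: a fixed /tmp filename can be clobbered by concurrent requests. A safer
# sketch using the standard library (would need `import tempfile` at the top):
#   with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
#       tmp_path = f.name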

def transcribe(audio):
    if audio is None:
        return "Please record or upload an audio file."
    # Preprocess (mono, 16 kHz, leading silence trimmed), then transcribe
    processed_audio = preprocess_audio(audio)
    result = asr_pipeline(processed_audio)["text"]
    return result
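
# If the model is Whisper-based (the forced_decoder_ids workaround above
# suggests it), recordings longer than the 30 s window can be handled with the
# pipeline's chunked inference; a sketch under that assumption:
#   result = asr_pipeline(processed_audio, chunk_length_s=30)["text"]
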
title = "πŸŽ™οΈ Moul-Sout ASR πŸ‡²πŸ‡¦"
description = """
**Moul-Sout** model for Darija (Moroccan Arabic) ASR πŸ‡²πŸ‡¦.
Record or upload an audio sample (it is automatically converted to 16 kHz mono and leading silence is trimmed),
then view the transcription result below.
"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}\n{description}")
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎀 Record or Upload Audio (auto 16 kHz mono)"
        )
    transcribe_btn = gr.Button("πŸš€ Transcribe")
    output_text = gr.Textbox(label="🟩 Transcription Output")

    transcribe_btn.click(
        fn=transcribe,
        inputs=[audio_input],
        outputs=[output_text]
    )

# Local launch
if __name__ == "__main__":
    demo.launch()