```python
import importlib
from types import SimpleNamespace

import gradio as gr
import pandas as pd
# import spaces  # Hugging Face ZeroGPU helper (left disabled in this demo)
import torch

from utmosv2.utils import get_dataset, get_model

description = (
    "# π UTMOSv2 demo\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "It only accepts `.wav` files and works best at a 16 kHz sampling rate."
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the `fusion_stage3` configuration module and expose its attributes
# as a mutable namespace so they can be overridden for inference.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{attr: getattr(config, attr) for attr in config.__dict__ if not attr.startswith("__")}
)
cfg.reproduce = False
cfg.config = "fusion_stage3"
cfg.print_config = False
cfg.data_config = None
cfg.phase = "inference"
cfg.weight = None
cfg.num_workers = 1


# @spaces.GPU
def predict_mos(audio_path: str, domain: str) -> float:
    """Predict the MOS of one audio file, averaged over 5 folds x 5 repeated passes."""
    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    data["mos"] = 0  # dummy label; only the prediction is used

    preds = 0.0
    for fold in range(5):
        cfg.now_fold = fold
        model = get_model(cfg, device)
        for _ in range(5):
            # Re-create the test dataset for each repeated pass.
            test_dataset = get_dataset(cfg, data, "test")
            p = model(*[torch.tensor(t).unsqueeze(0) for t in test_dataset[0][:-1]])
            preds += p[0]
    preds /= 25.0  # average over 5 folds x 5 repetitions
    return float(preds)  # plain float for display in the Textbox


with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])

demo.queue().launch()
```
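
For quick programmatic checks, `predict_mos` can also be called without the Gradio interface, for example in a separate session that imports it, or in place of the `demo.queue().launch()` call. The snippet below is a minimal sketch, assuming the same environment as this Space (the `utmosv2` package and its pretrained weights available) and a hypothetical 16 kHz file `sample.wav`; the domain label must be one of the dropdown choices above.

```python
# Minimal sketch: score one file directly, without the Gradio UI.
# Assumes predict_mos() and cfg from the script above are in scope, and that
# "sample.wav" (a hypothetical path) points to a 16 kHz mono .wav file.
score = predict_mos("sample.wav", "sarulab")  # "sarulab" is one of the domain IDs listed above
print(f"Predicted MOS: {score:.3f}")
```

Note that a single call rebuilds the model for each of the 5 folds and runs 5 passes per fold (25 forward passes in total), so it can take noticeably longer on CPU than a single inference would.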