| """ | |
| File: submit.py | |
| Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov | |
| Description: Event handler for Gradio app to submit. | |
| License: MIT License | |
| """ | |

import spaces
import torch
import pandas as pd
import cv2
import gradio as gr

# Importing necessary components for the Gradio app
from app.config import config_data
from app.utils import (
    Timer,
    convert_video_to_audio,
    readetect_speech,
    slice_audio,
    find_intersections,
    calculate_mode,
    find_nearest_frames,
    convert_webm_to_mp4,
)
from app.plots import (
    get_evenly_spaced_frame_indices,
    plot_audio,
    display_frame_info,
    plot_images,
    plot_predictions,
)
from app.data_init import (
    read_audio,
    get_speech_timestamps,
    vad_model,
    video_model,
    asr,
    audio_model,
    text_model,
)
from app.load_models import VideoFeatureExtractor
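

# Assumed: on a ZeroGPU Space the handler is decorated with `spaces.GPU` so a GPU is
# attached only for the duration of this call; this matches the `spaces` import above,
# but the decorator itself is an assumption, not confirmed by the extracted source.
@spaces.GPU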
def event_handler_submit(
    video: str,
) -> tuple[
    gr.Textbox,
    gr.Plot,
    gr.Plot,
    gr.Plot,
    gr.Plot,
    gr.Row,
    gr.Textbox,
    gr.Textbox,
]:
    with Timer() as timer:
        if video:
            if video.split(".")[-1] == "webm":
                video = convert_webm_to_mp4(video)

        audio_file_path = convert_video_to_audio(
            file_path=video, sr=config_data.General_SR
        )
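
        # Read the extracted audio and run voice activity detection to locate speech
        # regions; windows without speech are later scored without the audio model.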
        wav, vad_info = readetect_speech(
            file_path=audio_file_path,
            read_audio=read_audio,
            get_speech_timestamps=get_speech_timestamps,
            vad_model=vad_model,
            sr=config_data.General_SR,
        )
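
        # Slice the waveform into overlapping analysis windows; window lengths and the
        # shift are configured in seconds and converted to samples here.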
        audio_windows = slice_audio(
            start_time=config_data.General_START_TIME,
            end_time=int(len(wav)),
            win_max_length=int(
                config_data.General_WIN_MAX_LENGTH * config_data.General_SR
            ),
            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
            win_min_length=int(
                config_data.General_WIN_MIN_LENGTH * config_data.General_SR
            ),
        )

        intersections = find_intersections(
            x=audio_windows,
            y=vad_info,
            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
        )
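
        # Prepare the video branch: preprocess the clip and cache the detected face
        # crops so each window can be scored from the corresponding frames.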
        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
        vfe.preprocess_video()
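
        # Transcribe the recording: per-window transcriptions plus the full recognized text.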
        transcriptions, total_text = asr(wav, audio_windows)
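
        # Fuse the audio, video, and text predictions for every window and broadcast the
        # fused emotion/sentiment labels to the video frames that the window covers.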
        window_frames = []
        preds_emo = []
        preds_sen = []
        for w_idx, window in enumerate(audio_windows):
            a_w = intersections[w_idx]
            if not a_w["speech"]:
                a_pred = None
            else:
                wave = wav[a_w["start"] : a_w["end"]].clone()
                a_pred, _ = audio_model(wave)

            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
            t_pred, _ = text_model(transcriptions[w_idx][0])
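
            # Late fusion: average the modality predictions, dropping the audio term
            # for windows where no speech was detected.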
            if a_pred:
                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
            else:
                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
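
            # Map the window boundaries from audio samples to 1-based video frame indices.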
            frames = list(
                range(
                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
                )
            )

            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
            window_frames.extend(frames)
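
        # Pad any trailing frames not covered by an audio window with the last prediction.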
        if max(window_frames) < vfe.frame_number:
            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
            window_frames.extend(missed_frames)
            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
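
        # Collect per-frame predictions and reduce duplicates (frames covered by several
        # windows) to their mode.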
        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
        df_pred["frames"] = window_frames
        df_pred["pred_emo"] = preds_emo
        df_pred["pred_sent"] = preds_sen

        df_pred = df_pred.groupby("frames").agg(
            {
                "pred_emo": calculate_mode,
                "pred_sent": calculate_mode,
            }
        )
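
        # Build the visualizations: the waveform, a strip of evenly spaced face crops,
        # and the per-frame emotion and sentiment curves.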
        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)

        num_frames = len(wav)
        time_axis = [i / config_data.General_SR for i in range(num_frames)]
        plt_audio = plot_audio(
            time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
        )

        all_idx_faces = list(vfe.faces[1].keys())
        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)

        faces = []
        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
            cur_face = cv2.resize(
                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
            )
            faces.append(
                display_frame_info(
                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
                )
            )

        plt_faces = plot_images(faces)

        plt_emo = plot_predictions(
            df_pred,
            "pred_emo",
            "Emotion",
            list(config_data.General_DICT_EMO),
            (12, 2.5),
            [i + 1 for i in frame_indices],
            3,
        )
        plt_sent = plot_predictions(
            df_pred,
            "pred_sent",
            "Sentiment",
            list(config_data.General_DICT_SENT),
            (12, 1.5),
            [i + 1 for i in frame_indices],
            3,
        )

    return (
        gr.Textbox(
            value=" ".join(total_text).strip(),
            info=config_data.InformationMessages_REC_TEXT,
            container=True,
            elem_classes="noti-results",
        ),
        gr.Plot(value=plt_audio, visible=True),
        gr.Plot(value=plt_faces, visible=True),
        gr.Plot(value=plt_emo, visible=True),
        gr.Plot(value=plt_sent, visible=True),
        gr.Row(visible=True),
        gr.Textbox(
            value=config_data.OtherMessages_SEC.format(vfe.dur),
            info=config_data.InformationMessages_VIDEO_DURATION,
            container=True,
            visible=True,
        ),
        gr.Textbox(
            value=timer.execution_time,
            info=config_data.InformationMessages_INFERENCE_TIME,
            container=True,
            visible=True,
        ),
    )
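

# Minimal wiring sketch (component names such as `video_input`, `submit_btn`, and the
# output components are hypothetical; the real layout lives elsewhere in the app). It
# shows how the eight returned values map onto Gradio output components:
#
#     with gr.Blocks() as demo:
#         video_input = gr.Video()
#         submit_btn = gr.Button("Submit")
#         text_out = gr.Textbox()
#         audio_plot, faces_plot = gr.Plot(visible=False), gr.Plot(visible=False)
#         emo_plot, sent_plot = gr.Plot(visible=False), gr.Plot(visible=False)
#         with gr.Row(visible=False) as stats_row:
#             duration_out = gr.Textbox()
#             infer_time_out = gr.Textbox()
#         submit_btn.click(
#             fn=event_handler_submit,
#             inputs=[video_input],
#             outputs=[text_out, audio_plot, faces_plot, emo_plot, sent_plot,
#                      stats_row, duration_out, infer_time_out],
#         )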