import time

import torch
import librosa
import numpy as np
import gradio as gr

from .generate_graph import create_behaviour_gantt_plot
from transformers import Wav2Vec2Processor

SAMPLING_RATE = 16_000


class AudioProcessor:
    """Segments an audio file and predicts per-chunk emotion (and optionally behaviour) labels."""

    def __init__(
        self,
        emotion_model,
        segmentation_model,
        device,
        behaviour_model=None,
    ):
        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.emotion_model = emotion_model
        self.behaviour_model = behaviour_model
        self.device = device
        self.audio_emotion_labels = {
            0: "Neutralità",
            1: "Rabbia",
            2: "Paura",
            3: "Gioia",
            4: "Sorpresa",
            5: "Tristezza",
            6: "Disgusto",
        }
        self.emotion_translation = {
            "neutrality": "Neutralità",
            "anger": "Rabbia",
            "fear": "Paura",
            "joy": "Gioia",
            "surprise": "Sorpresa",
            "sadness": "Tristezza",
            "disgust": "Disgusto",
        }
        self.behaviour_labels = {
            0: "frustrated",
            1: "delighted",
            2: "dysregulated",
        }
        self.behaviour_translation = {
            "frustrated": "frustrazione",
            "delighted": "incantato",
            "dysregulated": "disregolazione",
        }
        self.segmentation_model = segmentation_model

        self._set_emotion_model()
        if self.behaviour_model:
            self._set_behaviour_model()

        self.behaviour_confidence = 0.6
        self.chart_generator = None

    def _set_emotion_model(self):
        self.emotion_model.to(self.device)
        self.emotion_model.eval()

    def _set_behaviour_model(self):
        self.behaviour_model.to(self.device)
        self.behaviour_model.eval()

    def _prepare_transcribed_text(self, chunks):
        formatted_timestamps = []
        predictions = []
        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formatted_start = time.strftime("%H:%M:%S", time.gmtime(start))
            formatted_end = time.strftime("%H:%M:%S", time.gmtime(end))
            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")
            predictions.append(f"**[{chunk[2]}]**")
        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "<br/>".join(
            [
                f"{formatted_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}"
                for i in range(len(transcribed_texts))
            ]
        )
        print(f"Transcribed text:\n{transcribed_text}")
        return transcribed_text

    def __call__(self, audio_path: str):
        """
        Segments the audio file and predicts emotion and behaviour labels for each chunk.

        Args:
            audio_path (str): Path of the audio file to process.

        Returns:
            tuple: A tuple containing the behaviour Gantt plot.
        """
        try:
            input_frames, _ = librosa.load(
                audio_path,
                sr=SAMPLING_RATE,
            )
        except Exception as e:
            raise gr.Error(f"Error loading audio file: {e}.")
| print("Segmenting audio...") | |
| out = self.segmentation_model( | |
| inputs={ | |
| "raw": input_frames, | |
| "sampling_rate": SAMPLING_RATE, | |
| }, | |
| chunk_length_s=30, | |
| stride_length_s=5, | |
| return_timestamps=True, | |
| ) | |

        emotion_chunks = []
        behaviour_chunks = []
        timestamps = []
        predicted_labels = []
        all_probabilities = []

        print("Analyzing chunks...")
        for chunk in out["chunks"]:
            # Trim the audio to the chunk boundaries (timestamps converted to sample indices).
            start = int(chunk["timestamp"][0] * SAMPLING_RATE)
            end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))
            audio = input_frames[start:end]

            # Joint audio + text inputs for the emotion model.
            inputs = self.emotion_processor(
                audio, text=chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE
            )
            print(f"Inputs: {inputs}")
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs["input_features"] = inputs["input_features"].to(self.device)
            inputs["input_ids"] = inputs["input_ids"].to(self.device)
            inputs["text_attention_mask"] = inputs["text_attention_mask"].to(self.device)

            print("Predicting emotion for chunk...")
            logits = self.emotion_model(**inputs).logits
            logits = logits.detach().cpu()
            softmax = torch.nn.Softmax(dim=1)
            probabilities = softmax(logits).squeeze(0)
            prediction = probabilities.argmax().item()
            predicted_label = self.emotion_model.config.id2label[prediction]
            label_translation = self.emotion_translation[predicted_label]
            emotion_chunks.append(
                (
                    start,
                    end,
                    label_translation,
                    chunk["text"],
                    np.round(probabilities[prediction].item(), 2),
                )
            )
            timestamps.append((start, end))
            predicted_labels.append(label_translation)
            all_probabilities.append(probabilities[prediction].item())

            # Audio-only inputs for the behaviour model.
            inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs = inputs.input_features.to(self.device)

            print("Predicting behaviour for chunk...")
            logits = self.behaviour_model(inputs).logits
            probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
            behaviour_chunks.append(
                (
                    start,
                    end,
                    chunk["text"],
                    np.round(probabilities[2].item(), 2),  # probability of class 2 ("dysregulated")
                    label_translation,
                )
            )

        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)
        # transcribed_text = self._prepare_transcribed_text(emotion_chunks)

        return (
            behaviour_gantt,
            # transcribed_text,
        )
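

# Usage sketch: a minimal, hypothetical example of how AudioProcessor could be wired up.
# `load_emotion_model` and `load_behaviour_model` are assumed helpers (not part of this
# module), and the Whisper checkpoint is an assumption; any Hugging Face ASR pipeline
# that returns timestamped chunks should work the same way.
#
#   import torch
#   from transformers import pipeline
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   segmentation_model = pipeline(
#       "automatic-speech-recognition",
#       model="openai/whisper-small",  # assumed checkpoint
#       device=device,
#   )
#   audio_processor = AudioProcessor(
#       emotion_model=load_emotion_model(),        # hypothetical loader
#       segmentation_model=segmentation_model,
#       device=device,
#       behaviour_model=load_behaviour_model(),    # hypothetical loader
#   )
#   (behaviour_gantt,) = audio_processor("example.wav")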