Spaces: Runtime error
| import time | |
| import plotly.graph_objects as go | |
| from datetime import datetime, timedelta | |
# Audio samples per second. Sample offsets stored in the behaviour chunks are
# divided by this value to convert them into seconds for plotting.
SAMPLING_RATE = 16000

# Plotly colour string for each (Italian) emotion label. Insertion order is
# kept identical to the original literal for any caller that iterates it.
COLOR_MAP = dict(
    [
        ("Neutralità", "rgb(178, 178, 178)"),  # neutral
        ("Rabbia", "rgb(160, 61, 62)"),  # anger
        ("Paura", "rgb(91, 57, 136)"),  # fear
        ("Gioia", "rgb(255, 255, 0)"),  # joy
        ("Sorpresa", "rgb(60, 175, 175)"),  # surprise
        ("Tristezza", "rgb(64, 106, 173)"),  # sadness
        ("Disgusto", "rgb(100, 153, 65)"),  # disgust
    ]
)
def _format_hms(seconds):
    """Format a non-negative number of seconds as ``HH:MM:SS``.

    Fractional seconds are truncated. Unlike
    ``time.strftime('%H:%M:%S', time.gmtime(seconds))`` the hour field is not
    wrapped at 24, so long recordings/durations are shown correctly.
    """
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def _build_hover_text(start, end, label, height, emotion):
    """Return the HTML hover tooltip for one ``(start, end, label, height, emotion)`` chunk.

    ``start``/``end`` are sample offsets, converted to seconds with
    ``SAMPLING_RATE``; ``height`` is displayed multiplied by 100 as a percentage.
    """
    start_fmt = _format_hms(start / SAMPLING_RATE)
    end_fmt = _format_hms(end / SAMPLING_RATE)
    duration_str = _format_hms((end - start) / SAMPLING_RATE)
    return (
        f"Inizio: {start_fmt}<br>Fine: {end_fmt}<br>Durata: {duration_str}"
        f"<br>Testo: {label}<br>Attendibilità: {height*100:.2f}%<br>Emozione: {emotion}"
    )


def create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60):
    """Build a Plotly timeline of per-chunk confidence scores and emotions.

    Each element of ``behaviour_chunks`` is unpacked as a 5-tuple
    ``(start, end, label, height, emotion)``:

    * ``start`` / ``end`` -- chunk boundaries as sample offsets; divided by
      ``SAMPLING_RATE`` to obtain seconds.
    * ``label`` -- text shown in the tooltip ("Testo").
    * ``height`` -- confidence score, multiplied by 100 for plotting
      (presumably in [0, 1] -- confirm with the producer of the chunks).
    * ``emotion`` -- emotion name; looked up in ``COLOR_MAP`` for its colour.

    The figure contains:

    * a shaded band from ``confidence_threshold`` to 100 spanning the chunks'
      time range, plus a dashed horizontal line at the threshold;
    * a smoothed orange line through every chunk midpoint;
    * one marker trace per emotion in a fixed legend order, holding only the
      chunks whose percentage is >= ``confidence_threshold``. Emotions with no
      such chunk still get an empty trace so the legend is always complete.

    Args:
        behaviour_chunks: Iterable of ``(start, end, label, height, emotion)``
            tuples. May be empty (or a one-shot generator); an empty input
            yields a legend-only figure instead of raising ``IndexError``.
        confidence_threshold: Percentage (0-100) at or above which a chunk is
            drawn as an emotion marker.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    print("Creating behaviour Gantt plot...")
    # Legend order of the emotion marker traces.
    emotion_order = [
        "Gioia",
        "Sorpresa",
        "Disgusto",
        "Tristezza",
        "Paura",
        "Rabbia",
        "Neutralità",
    ]

    # Materialise once: the chunks are traversed several times below, which
    # would silently produce empty lists for a generator argument.
    chunks = list(behaviour_chunks)
    fig = go.Figure()

    # Plotly's date axis renders HH:MM:SS ticks, so offsets in seconds are
    # anchored to an arbitrary midnight; only the time-of-day is displayed.
    base_time = datetime(2000, 1, 1)
    chunk_starts = [start / SAMPLING_RATE for start, _, _, _, _ in chunks]
    chunk_ends = [end / SAMPLING_RATE for _, end, _, _, _ in chunks]
    # Chunk midpoints are the x positions of both the trend line and markers.
    mid_times = [
        base_time + timedelta(seconds=(s + e) / 2)
        for s, e in zip(chunk_starts, chunk_ends)
    ]
    heights = [height * 100 for _, _, _, height, _ in chunks]
    emotions = [emotion for _, _, _, _, emotion in chunks]
    hover_texts = [_build_hover_text(*chunk) for chunk in chunks]

    if chunks:
        # Shade the "confident" region across the whole recorded span.
        # min/max (not first/last) keeps the band correct for unsorted chunks.
        fig.add_shape(
            type="rect",
            x0=base_time + timedelta(seconds=min(chunk_starts)),
            x1=base_time + timedelta(seconds=max(chunk_ends)),
            y0=confidence_threshold,
            y1=100,
            fillcolor="rgba(188,223,241,0.8)",
            opacity=0.8,
            layer="below",
            line_width=0,
        )
    fig.add_hline(y=confidence_threshold, line_dash="dash", line_color="black", line_width=1)

    fig.add_trace(
        go.Scatter(
            x=mid_times,
            y=heights,
            mode='lines',
            name='Disregolazione',
            line=dict(
                color='orange',
                width=2,
                shape='spline',  # smoothed curve through the midpoints
                smoothing=1.0,
            ),
            text=hover_texts,
            hoverinfo='text',
            showlegend=False,
        )
    )

    # Group the confident chunks by emotion, preserving input order.
    emotion_data = {}
    for mid_time, height, emotion, hover_text in zip(mid_times, heights, emotions, hover_texts):
        if height < confidence_threshold:
            continue
        data = emotion_data.setdefault(
            emotion, {'times': [], 'heights': [], 'hover_texts': []}
        )
        data['times'].append(mid_time)
        data['heights'].append(height)
        data['hover_texts'].append(hover_text)

    # NOTE(review): only emotions listed in emotion_order are drawn; chunks
    # with any other emotion name never get a marker trace.
    for emotion in emotion_order:
        marker = dict(
            size=15,
            color=COLOR_MAP.get(emotion, '#000000'),
            symbol='circle',
        )
        data = emotion_data.get(emotion)
        if data is None:
            # Placeholder trace so the emotion still appears in the legend.
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode='markers',
                    name=emotion.capitalize(),
                    marker=marker,
                    showlegend=True,
                )
            )
        else:
            fig.add_trace(
                go.Scatter(
                    x=data['times'],
                    y=data['heights'],
                    mode='markers',
                    name=emotion.capitalize(),
                    marker=marker,
                    text=data['hover_texts'],
                    hoverinfo='text',
                    showlegend=True,
                )
            )

    fig.update_layout(
        title='Distribuzione della disregolazione',
        xaxis_title='Tempo',
        yaxis_title='Attendibilità',
        xaxis=dict(
            type='date',
            tickformat='%H:%M:%S',
            showline=True,
            zeroline=False,
            side='bottom',
            showgrid=False,
        ),
        yaxis=dict(
            range=[0, 100],
            tickvals=[0, 20, 40, 60, 80, 100],
            ticktext=['0%', '20%', '40%', '60%', '80%', '100%'],
            tickmode='array',
            showgrid=False,
        ),
        legend_title=None,
        legend=dict(
            yanchor="top"
        ),
        hoverlabel=dict(
            font_size=12,
            font_family="Arial"
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )
    # Clear any template so the explicit ``text`` + hoverinfo='text' is used.
    fig.update_traces(hovertemplate=None)
    return fig