import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    BertForTokenClassification,
    AutoModelForTokenClassification,
    pipeline,
)
import torch
import os
import seaborn as sns
from matplotlib.colors import to_hex
import html

# Pin GPU selection before torch initializes CUDA.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Pastel palette (ColorBrewer Set3) shared by the highlighted view and the legend.
SPAN_COLORS = {
    'positive feedback': '#8dd3c7',
    'compliment': '#ffffb3',
    'affection declaration': '#bebada',
    'encouragement': '#fb8072',
    'gratitude': '#80b1d3',
    'agreement': '#fdb462',
    'ambiguous': '#d9d9d9',
    'implicit': '#fccde5',
    'group membership': '#b3de69',
    'sympathy': '#bc80bd',
}


class SpanClassifierWithStrictF1:
    """Token-classification span extractor evaluated with strict (exact-match) F1.

    Spans are encoded with a BIO scheme over ten span types; strict F1 counts a
    prediction as correct only when type, start and end all match a gold span.
    """

    def __init__(self, model_name="deepset/gbert-base"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.labels = [
            "O",
            "B-positive feedback", "B-compliment", "B-affection declaration",
            "B-encouragement", "B-gratitude", "B-agreement", "B-ambiguous",
            "B-implicit", "B-group membership", "B-sympathy",
            "I-positive feedback", "I-compliment", "I-affection declaration",
            "I-encouragement", "I-gratitude", "I-agreement", "I-ambiguous",
            "I-implicit", "I-group membership", "I-sympathy",
        ]
        self.label2id = {label: i for i, label in enumerate(self.labels)}
        self.id2label = {i: label for i, label in enumerate(self.labels)}

    def create_dataset(self, comments_df, spans_df):
        """Build a token-classification dataset with BIO labels.

        Args:
            comments_df: DataFrame with 'comment', 'document', 'comment_id'.
            spans_df: DataFrame with 'document', 'comment_id', 'type',
                'start', 'end' (character offsets).

        Returns:
            (examples, eval_data): `examples` holds tokenized inputs plus
            per-token label ids; `eval_data` keeps text, offset mappings and
            gold spans so strict F1 can be computed later.
        """
        examples = []
        eval_data = []
        spans_grouped = spans_df.groupby(['document', 'comment_id'])

        for _, row in comments_df.iterrows():
            text = row['comment']
            document = row['document']
            comment_id = row['comment_id']
            key = (document, comment_id)

            # Fetch the group once (the original called get_group twice).
            if key in spans_grouped.groups:
                group = spans_grouped.get_group(key)
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end
                              in group[['type', 'start', 'end']].values]
                span_rows = group[['start', 'end', 'type']].values
            else:
                true_spans = []
                span_rows = []

            tokenized = self.tokenizer(text, truncation=True, max_length=512,
                                       return_offsets_mapping=True)
            labels = self._create_bio_labels(tokenized['offset_mapping'], span_rows)

            examples.append({
                'input_ids': tokenized['input_ids'],
                'attention_mask': tokenized['attention_mask'],
                'labels': labels,
            })
            # Kept separately for strict-F1 evaluation after prediction.
            eval_data.append({
                'text': text,
                'offset_mapping': tokenized['offset_mapping'],
                'true_spans': true_spans,
                'document': document,
                'comment_id': comment_id,
            })

        return examples, eval_data

    def _create_bio_labels(self, offset_mapping, spans):
        """Create BIO label ids per token from character-level spans."""
        labels = [0] * len(offset_mapping)  # 0 = "O"

        for start, end, type_label in spans:
            for i, (token_start, token_end) in enumerate(offset_mapping):
                # Fast tokenizers mark special tokens with zero-width (0, 0)
                # offsets, never None — the original `is None` test was dead.
                if token_start is None or token_start == token_end:
                    continue
                # Token overlaps the span: first overlapping token gets B-,
                # the rest get I-.
                if token_start < end and token_end > start:
                    if token_start <= start:
                        labels[i] = self.label2id[f'B-{type_label}']
                    else:
                        labels[i] = self.label2id[f'I-{type_label}']

        return labels

    def compute_metrics(self, eval_pred):
        """Strict-F1 metrics callback for the HF Trainer.

        Relies on `self.current_eval_data` (set by the caller before
        evaluation) to map token predictions back to character offsets.
        """
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=2)

        # Guard: without eval data we cannot map tokens back to characters.
        eval_items = getattr(self, 'current_eval_data', [])

        batch_pred_spans = []
        batch_true_spans = []

        for i, (pred_seq, label_seq) in enumerate(zip(predictions, labels)):
            if i >= len(eval_items):
                continue
            eval_item = eval_items[i]
            text = eval_item['text']
            offset_mapping = eval_item['offset_mapping']
            true_spans = eval_item['true_spans']

            # Drop padding positions (label == -100) before decoding spans.
            valid_predictions = []
            valid_offsets = []
            for j, (pred_label, true_label) in enumerate(zip(pred_seq, label_seq)):
                if true_label != -100 and j < len(offset_mapping):
                    valid_predictions.append(pred_label)
                    valid_offsets.append(offset_mapping[j])

            pred_spans = self._predictions_to_spans(valid_predictions,
                                                    valid_offsets, text)
            batch_pred_spans.append([(span['type'], span['start'], span['end'])
                                     for span in pred_spans])
            batch_true_spans.append(true_spans)

        strict_f1, strict_precision, strict_recall, tp, fp, fn = \
            self._calculate_strict_f1(batch_true_spans, batch_pred_spans)

        # Fix: canonical empty_cache, and only when CUDA exists (the original
        # unconditional torch.cuda.memory.empty_cache() fails on CPU hosts).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Fix: the Trainer expects plain JSON-serializable scalars, not
        # torch tensors, as metric values.
        return {
            "strict_f1": float(strict_f1),
            "strict_precision": float(strict_precision),
            "strict_recall": float(strict_recall),
            "true_positives": int(tp),
            "false_positives": int(fp),
            "false_negatives": int(fn),
        }

    def _calculate_strict_f1(self, true_spans_list, pred_spans_list):
        """Compute strict F1 over all comments (exact type+offset match)."""
        tp, fp, fn = 0, 0, 0

        for true_spans, pred_spans in zip(true_spans_list, pred_spans_list):
            matches = self._find_exact_matches(true_spans, pred_spans)
            tp += len(matches)
            fp += len(pred_spans) - len(matches)
            fn += len(true_spans) - len(matches)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if (precision + recall) > 0 else 0.0)

        return f1, precision, recall, tp, fp, fn

    def _find_exact_matches(self, true_spans, pred_spans):
        """Greedily pair gold and predicted spans that match exactly.

        Each predicted span is used at most once (one-to-one matching).
        """
        matches = []
        used_pred = set()

        for true_span in true_spans:
            for i, pred_span in enumerate(pred_spans):
                if i not in used_pred and true_span == pred_span:
                    matches.append((true_span, pred_span))
                    used_pred.add(i)
                    break

        return matches

    def _predictions_to_spans(self, predicted_labels, offset_mapping, text):
        """Decode per-token label ids into character-level span dicts."""
        spans = []
        current_span = None

        for i, label_id in enumerate(predicted_labels):
            if i >= len(offset_mapping):
                break

            label = self.id2label[int(label_id)]
            token_start, token_end = offset_mapping[i]

            # Fix: special tokens have zero-width (0, 0) offsets (never None);
            # skipping them prevents zero-width ghost spans from a stray B-
            # prediction on [CLS]/[SEP].
            if token_start is None or token_start == token_end:
                continue

            if label.startswith('B-'):
                if current_span:
                    spans.append(current_span)
                current_span = {
                    'type': label[2:],
                    'start': token_start,
                    'end': token_end,
                    'text': text[token_start:token_end],
                }
            elif label.startswith('I-') and current_span:
                # Extend the open span through this token.
                current_span['end'] = token_end
                current_span['text'] = text[current_span['start']:current_span['end']]
            else:
                if current_span:
                    spans.append(current_span)
                    current_span = None

        if current_span:
            spans.append(current_span)

        return spans

    def predict(self, texts):
        """Predict spans for a list of raw texts.

        Raises:
            ValueError: if the model has not been trained/attached yet.
        """
        if not hasattr(self, 'model'):
            raise ValueError("Modell muss erst trainiert werden!")

        predictions = []
        device = next(self.model.parameters()).device

        for text in texts:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                    max_length=512, return_offsets_mapping=True)
            # Fix: convert offsets to plain ints so span start/end are not
            # tensor scalars downstream.
            offset_mapping = inputs.pop('offset_mapping')[0].tolist()
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                predicted_labels = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()

            spans = self._predictions_to_spans(predicted_labels, offset_mapping, text)
            predictions.append({'text': text, 'spans': spans})

        return predictions

    def evaluate_strict_f1(self, comments_df, spans_df):
        """Evaluate strict F1 on held-out data and print a summary."""
        if not hasattr(self, 'model'):
            raise ValueError("Modell muss erst trainiert werden!")

        print("Evaluiere Strict F1...")

        texts = comments_df['comment'].tolist()
        predictions = self.predict(texts)

        spans_grouped = spans_df.groupby(['document', 'comment_id'])
        true_spans_dict = {}
        pred_spans_dict = {}

        for i, (_, row) in enumerate(comments_df.iterrows()):
            key = (row['document'], row['comment_id'])

            if key in spans_grouped.groups:
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end
                              in spans_grouped.get_group(key)[['type', 'start', 'end']].values]
            else:
                true_spans = []

            pred_spans = [(span['type'], span['start'], span['end'])
                          for span in predictions[i]['spans']]

            true_spans_dict[key] = true_spans
            pred_spans_dict[key] = pred_spans

        all_true_spans = list(true_spans_dict.values())
        all_pred_spans = list(pred_spans_dict.values())
        f1, precision, recall, tp, fp, fn = self._calculate_strict_f1(
            all_true_spans, all_pred_spans)

        print(f"\nStrict F1 Ergebnisse:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print(f"True Positives: {tp}, False Positives: {fp}, False Negatives: {fn}")

        return {
            'strict_f1': f1,
            'strict_precision': precision,
            'strict_recall': recall,
            'true_positives': tp,
            'false_positives': fp,
            'false_negatives': fn,
        }


def convert_spans(row):
    """Flatten a row's predicted spans into plain dicts keyed by document/comment."""
    spans = row['predicted_spans']
    document = row['document']
    comment_id = row['comment_id']
    return [{'document': document, 'comment_id': comment_id, 'type': span['type'],
             'start': span['start'], 'end': span['end']} for span in spans]


def pred_to_spans(row):
    """Decode one row's token predictions into spans.

    NOTE(review): depends on a module-level `classifier` instance defined
    elsewhere in the app — confirm it exists before this is called.
    """
    predicted_labels, offset_mapping, text = (row['predicted_labels'],
                                              row['offset_mapping'],
                                              row['comment'])
    return [classifier._predictions_to_spans(predicted_labels, offset_mapping, text)]


def create_highlighted_html(text, spans):
    """Render `text` as HTML with each span highlighted in its type colour."""
    if not spans:
        return html.escape(text)

    # Emit the text left to right, so sort spans by start offset.
    sorted_spans = sorted(spans, key=lambda x: x['start'])

    html_parts = []
    last_end = 0
    for span in sorted_spans:
        # Plain text before the span.
        if span['start'] > last_end:
            html_parts.append(html.escape(text[last_end:span['start']]))

        color = SPAN_COLORS.get(span['type'], '#EEEEEE')
        span_text = html.escape(text[span['start']:span['end']])
        # Fix: the original appended the bare text and never used `color`,
        # so nothing was highlighted; wrap the span in styled markup.
        html_parts.append(
            f'<span style="background-color: {color}; padding: 1px 3px; '
            f'border-radius: 3px;" title="{html.escape(span["type"])}">'
            f'{span_text}</span>'
        )
        last_end = span['end']

    # Remaining text after the last span.
    if last_end < len(text):
        html_parts.append(html.escape(text[last_end:]))

    return ''.join(html_parts)


def create_legend():
    """Build an HTML legend mapping each span type to its colour.

    NOTE(review): the original legend HTML was truncated in this view;
    rebuilt minimally from the shared palette — confirm against the
    original UI layout.
    """
    items = ''.join(
        f'<span style="background-color: {color}; padding: 2px 8px; '
        f'margin: 2px; border-radius: 3px; display: inline-block;">'
        f'{html.escape(span_type)}</span>'
        for span_type, color in SPAN_COLORS.items()
    )
    return f'<div><b>Legende:</b><br>{items}</div>'
Analysieren Sie Texte und identifizieren Sie verschiedene Arten positiver Kommunikation.