import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW, lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import classification_report, confusion_matrix
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
from PIL import Image
from torchvision.transforms import ToTensor
from .text_cnn import DynamicTextCNN
from tqdm import tqdm
import io
import os
class ToxicTextClassifier(nn.Module):
    def __init__(self,
                 bert_name='hfl/chinese-roberta-wwm-ext',
                 num_filters=1536,
                 filter_sizes=(1, 2, 3, 4),
                 K=4,
                 fc_dim=128,
                 num_classes=2,
                 dropout=0.1,
                 name='lited_best'):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(bert_name)
        self.bert = BertModel.from_pretrained(bert_name)
        self.name = name
        self.unfrozen_layers = 0
        hidden_size = self.bert.config.hidden_size * 2
        os.makedirs(f'data/{name}', exist_ok=True)
        self.text_cnn = DynamicTextCNN(hidden_size, num_filters, filter_sizes, K, dropout)
        input_dim = len(filter_sizes) * num_filters
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, fc_dim),
            nn.ReLU(),
            nn.LayerNorm(fc_dim),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, fc_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim // 2, num_classes)
        )
        self.criterion = nn.CrossEntropyLoss()
        self._rebuild_optimizer()
        self.warmup_scheduler = None
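    # Linear warmup: the LambdaLR multiplier ramps from 0 to 1 over `warmup_steps`
    # optimizer updates (e.g. at update 250 of 1000 the scale is 0.25) and then
    # stays at 1.0, leaving further adjustment to ReduceLROnPlateau.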
    def _get_warmup_scheduler(self, warmup_steps=1000):
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return 1.0
        return lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
    def _rebuild_optimizer(self):
        # The CNN and classifier heads train at 1e-4; any unfrozen BERT layers get
        # a smaller 2e-5 learning rate.
        param_groups = [
            {'params': self.text_cnn.parameters(), 'lr': 1e-4},
            {'params': self.classifier.parameters(), 'lr': 1e-4},
        ]
        if self.unfrozen_layers > 0:
            layers = self.bert.encoder.layer[-self.unfrozen_layers:]
            bert_params = []
            for layer in layers:
                for p in layer.parameters():
                    p.requires_grad = True
                    bert_params.append(p)
            param_groups.append({'params': bert_params, 'lr': 2e-5})
        self.optimizer = AdamW(param_groups, weight_decay=0.01)
        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=2,
        )
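    # Feature extraction: the last two BERT hidden states are concatenated along the
    # feature dimension (shape [batch, seq_len, 2 * hidden_size]), pooled by the
    # TextCNN, and classified by the MLP head.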
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        bert_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        hidden = torch.cat(bert_out.hidden_states[-2:], dim=-1)
        feat = self.text_cnn(hidden)
        return self.classifier(feat)
    def validate(self, val_loader, device):
        self.eval()
        val_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            pbar = tqdm(val_loader, desc='Validating')
            for batch in pbar:
                ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                types = batch['token_type_ids'].to(device)
                labels = batch['label'].to(device)
                logits = self(ids, mask, types)
                loss = self.criterion(logits, labels)
                val_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        epoch_acc = correct / total if total > 0 else 0
        metrics = {
            'loss': val_loss / len(val_loader),
            'acc': epoch_acc,
            'report': classification_report(all_labels, all_preds, target_names=['non-toxic', 'toxic']),
            'confusion_matrix': confusion_matrix(all_labels, all_preds)
        }
        torch.cuda.empty_cache()
        return metrics
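    # Training recipe implemented below: BERT starts fully frozen and only the
    # TextCNN/classifier heads are updated; at epoch 2 the top two encoder layers
    # are unfrozen and the optimizer is rebuilt with a smaller LR for them. A linear
    # warmup ramps the learning rate after each optimizer (re)build, validation runs
    # every `validate_every` steps with TensorBoard logging and a confusion-matrix
    # figure, and training stops early once validation loss fails to improve for
    # `early_stop_patience` consecutive checks.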
    def train_model(self, train_loader, val_loader,
                    num_epochs=3, device='cpu',
                    save_path=None,
                    logdir=None,
                    validate_every=100,
                    warmup_steps=1000,
                    early_stop_patience=3):
        self.to(device)
        # Freeze the whole encoder up front; _rebuild_optimizer() re-enables
        # gradients for the layers that get unfrozen later.
        for param in self.bert.parameters():
            param.requires_grad = False
        best_val_loss = float('inf')
        global_step = 0
        checks_no_improve = 0
        best_model_state = None
        if save_path is None:
            save_path = f'output/{self.name}.pth'
        if logdir is None:
            logdir = f'runs/{self.name}'
        writer = SummaryWriter(logdir)
        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            total_loss = 0
            correct = 0
            total = 0
            if epoch == 2:
                print("Unfreezing the last 2 BERT layers")
                self.unfrozen_layers = 2
                self._rebuild_optimizer()
            if epoch <= 2:
                # (Re)create the warmup scheduler only when the optimizer has just been
                # (re)built, so it always steps the live optimizer rather than a stale one.
                self.warmup_scheduler = self._get_warmup_scheduler(warmup_steps)
            pbar = tqdm(train_loader, desc='Training')
            for batch in pbar:
                ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                types = batch['token_type_ids'].to(device)
                labels = batch['label'].to(device)
                logits = self(ids, mask, types)
                loss = self.criterion(logits, labels)
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                self.optimizer.step()
                # Step the warmup until its own counter reaches warmup_steps, so each
                # (re)built optimizer gets a full ramp instead of being left at zero LR.
                if self.warmup_scheduler.last_epoch < warmup_steps:
                    self.warmup_scheduler.step()
                for i, group in enumerate(self.optimizer.param_groups):
                    writer.add_scalar(f'LR/group_{i}', group['lr'], global_step)
                for name, param in self.named_parameters():
                    if "convs" in name and param.grad is not None:
                        grad_norm = param.grad.norm().item()
                        writer.add_scalar(f'Gradients/{name}', grad_norm, global_step)
                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                acc = correct / total
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Acc/train', acc, global_step)
                pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{acc:.4f}'})
                global_step += 1
                if global_step % validate_every == 0:
                    torch.cuda.empty_cache()
                    metrics = self.validate(val_loader, device)
                    val_loss, val_acc = metrics['loss'], metrics['acc']
                    self.scheduler.step(val_loss)
                    report_text = metrics['report']
                    conf_mat = metrics['confusion_matrix']
                    print(f"\n[Step {global_step}] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
                    print(report_text)
                    writer.add_text('Classification Report', report_text, global_step)
                    writer.add_scalar('Loss/vali', val_loss, global_step)
                    writer.add_scalar('Acc/vali', val_acc, global_step)
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        # Keep a detached CPU copy so later updates don't overwrite the best weights.
                        best_model_state = {k: v.detach().cpu().clone() for k, v in self.state_dict().items()}
                        checks_no_improve = 0
                        torch.save(best_model_state, save_path)
                        print(f"Saved best model (step {global_step}) with loss {best_val_loss:.4f}")
                    else:
                        checks_no_improve += 1
                        print(f"No improvement for {checks_no_improve} checks")
                        if checks_no_improve >= early_stop_patience:
                            print(f"Early stopping triggered at step {global_step}!")
                            self.load_state_dict(best_model_state)
                            writer.close()
                            return
                    # Render the confusion matrix as a heatmap and log it to TensorBoard.
                    flame_colors = ['#ffffcc', '#ffeda0', '#feb24c', '#fd8d3c', '#f03b20', '#bd0026']
                    flame_cmap = LinearSegmentedColormap.from_list("flame", flame_colors, N=256)
                    fig, ax = plt.subplots(figsize=(8, 6))
                    sns.set_theme(font_scale=1.4)
                    sns.heatmap(
                        conf_mat,
                        annot=True,
                        fmt='d',
                        cmap=flame_cmap,
                        linewidths=0.5,
                        linecolor='gray',
                        square=True,
                        cbar=True,
                        xticklabels=['non-toxic', 'toxic'],
                        yticklabels=['non-toxic', 'toxic'],
                        annot_kws={"size": 16, "weight": "bold"}
                    )
                    ax.set_xlabel('Predicted', fontsize=14, labelpad=10)
                    ax.set_ylabel('True', fontsize=14, labelpad=10)
                    ax.set_title('Confusion Matrix', fontsize=16, pad=12)
                    ax.xaxis.set_tick_params(labelsize=12)
                    ax.yaxis.set_tick_params(labelsize=12)
                    ax.xaxis.set_major_locator(ticker.FixedLocator([0.5, 1.5]))
                    ax.yaxis.set_major_locator(ticker.FixedLocator([0.5, 1.5]))
                    buf = io.BytesIO()
                    plt.tight_layout()
                    plt.savefig(buf, format='png', dpi=150)
                    plt.savefig(f'data/{self.name}/conf_matrix_step{global_step}.pdf', format='pdf', bbox_inches='tight')
                    buf.seek(0)
                    image = Image.open(buf)
                    image_tensor = ToTensor()(image)
                    writer.add_image('Confusion Matrix', image_tensor, global_step)
                    buf.close()
                    plt.close(fig)
                    self.train()
        writer.close()
    def predict(self, texts, device='cpu'):
        """Run inference and predict the class of the input text(s).
        Args:
            texts (str or list): The input(s) to classify.
                - A single string is classified on its own.
                - A list of strings is classified as one batch.
                - A list of [text, context] pairs is encoded as sentence pairs, with the
                  first element treated as the detected text and the second as its context.
            device (str): Device to run on ('cpu', 'cuda', or 'mps'). Defaults to 'cpu';
                pass None to auto-detect an available device.
        Returns:
            list: A list of dictionaries, one per input, each containing:
                - 'text': The input text.
                - 'prediction': The predicted class (0 or 1).
                - 'probabilities': The probabilities for each class.
        """
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'
        self.eval()
        self.to(device)
        if isinstance(texts, str):
            texts = [texts]
            encoded_inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
        elif isinstance(texts, list) and all(isinstance(item, list) for item in texts):
            encoded_inputs = self.tokenizer(
                [item[0] for item in texts],
                [item[1] for item in texts],
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
        elif isinstance(texts, list) and all(isinstance(item, str) for item in texts):
            encoded_inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
        else:
            raise ValueError("Invalid input type. Expected a str, a list of str, or a list of [text, context] pairs.")
        input_ids = encoded_inputs['input_ids']
        attention_mask = encoded_inputs['attention_mask']
        token_type_ids = encoded_inputs.get('token_type_ids', None)
        with torch.no_grad():
            logits = self(input_ids, attention_mask, token_type_ids)
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(probs, dim=-1)
        results = []
        for i, text in enumerate(texts):
            results.append({
                'text': text,
                'prediction': preds[i].item(),
                'probabilities': probs[i].cpu().tolist()
            })
        return results
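
# Minimal usage sketch (hypothetical checkpoint path and placeholder texts, shown
# only to illustrate the predict() API; adapt to your own setup):
#
#     model = ToxicTextClassifier()
#     model.load_state_dict(torch.load('output/lited_best.pth', map_location='cpu'))
#     # Batch of plain strings
#     print(model.predict(['first comment to score', 'second comment to score']))
#     # [detected_text, context] pairs are encoded as sentence pairs
#     print(model.predict([['reply under review', 'the thread it appeared in']]))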