# lite_DETECTIVE / cold / classifier.py
# Source: Hugging Face repo page (author AlbertCAC, commit 9c6c6a5, "update").
import copy
import io
import os

import torch
import torch.nn as nn
from torch.optim import AdamW, lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

from .text_cnn import DynamicTextCNN
class ToxicTextClassifier(nn.Module):
    """BERT + TextCNN binary classifier for toxic-text detection.

    The last two BERT hidden layers are concatenated along the feature
    dimension, pooled by a DynamicTextCNN, and classified by a small MLP
    head. Training starts with BERT frozen; the top encoder layers are
    unfrozen from epoch 2 with a separate (lower) learning rate.
    """

    def __init__(self,
                 bert_name='hfl/chinese-roberta-wwm-ext',
                 num_filters=1536,
                 filter_sizes=(1, 2, 3, 4),
                 K=4,
                 fc_dim=128,
                 num_classes=2,
                 dropout=0.1,
                 name='lited_best'):
        """Build tokenizer, BERT backbone, TextCNN and classifier head.

        Args:
            bert_name: HuggingFace model id for the backbone and tokenizer.
            num_filters: conv filters per kernel size in the TextCNN.
            filter_sizes: kernel widths used by the TextCNN.
            K: dynamic-pooling parameter forwarded to DynamicTextCNN.
            fc_dim: width of the first fully-connected layer of the head.
            num_classes: number of output classes (2 = non-toxic / toxic).
            dropout: dropout probability used in the TextCNN and the head.
            name: run name; used for the data/<name> artifact directory and
                the default checkpoint / TensorBoard paths.
        """
        super().__init__()
        # NOTE(review): from_tf=True on a *tokenizer* looks suspicious — it is
        # a model-loading flag; confirm it is intentional.
        self.tokenizer = BertTokenizer.from_pretrained(bert_name, from_tf=True)
        self.bert = BertModel.from_pretrained(bert_name)
        self.name = name
        self.unfrozen_layers = 0  # number of top BERT layers currently trainable
        # Features are the concatenation of the last two hidden layers.
        hidden_size = self.bert.config.hidden_size * 2
        os.makedirs(f'data/{name}', exist_ok=True)
        self.text_cnn = DynamicTextCNN(hidden_size, num_filters, filter_sizes, K, dropout)
        input_dim = len(filter_sizes) * num_filters
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, fc_dim),
            nn.ReLU(),
            nn.LayerNorm(fc_dim),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, fc_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim // 2, num_classes)
        )
        self.criterion = nn.CrossEntropyLoss()
        self._rebuild_optimizer()
        self.warmup_scheduler = None

    def _get_warmup_scheduler(self, warmup_steps=1000):
        """Return a LambdaLR that linearly ramps LR from 0 over warmup_steps.

        Note: constructing a LambdaLR immediately applies lambda(0), i.e. it
        sets the LR to 0 until the first .step() call.
        """
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return 1.0
        return lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def _rebuild_optimizer(self):
        """(Re)create the AdamW optimizer and plateau scheduler.

        TextCNN and head train at 1e-4; if self.unfrozen_layers > 0, the top
        BERT encoder layers are unfrozen and trained at 2e-5. Rebuilding also
        resets the ReduceLROnPlateau history.
        """
        param_groups = [
            {'params': self.text_cnn.parameters(), 'lr': 1e-4},
            {'params': self.classifier.parameters(), 'lr': 1e-4},
        ]
        if self.unfrozen_layers > 0:
            layers = self.bert.encoder.layer[-self.unfrozen_layers:]
            bert_params = []
            for layer in layers:
                for p in layer.parameters():
                    p.requires_grad = True
                    bert_params.append(p)
            param_groups.append({'params': bert_params, 'lr': 2e-5})
        self.optimizer = AdamW(param_groups, weight_decay=0.01)
        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=2,
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return class logits of shape (batch, num_classes)."""
        bert_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # Concatenate the last two hidden layers along the feature dim.
        hidden = torch.cat(bert_out.hidden_states[-2:], dim=-1)
        feat = self.text_cnn(hidden)
        return self.classifier(feat)

    def validate(self, val_loader, device):
        """Evaluate on val_loader.

        Returns a dict with 'loss' (mean batch loss), 'acc', 'report'
        (sklearn text report) and 'confusion_matrix'. Leaves the module in
        eval mode; the caller is responsible for switching back to train().
        """
        self.eval()
        val_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            pbar = tqdm(val_loader, desc='Validating')
            for batch in pbar:
                ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                types = batch['token_type_ids'].to(device)
                labels = batch['label'].to(device)
                logits = self(ids, mask, types)
                loss = self.criterion(logits, labels)
                val_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        epoch_acc = correct / total if total > 0 else 0
        # NOTE(review): classification_report / confusion_matrix are
        # sklearn.metrics functions that are never imported in this file —
        # add `from sklearn.metrics import classification_report,
        # confusion_matrix` or this raises NameError at runtime.
        metrics = {
            # max(..., 1) guards against an empty loader.
            'loss': val_loss / max(len(val_loader), 1),
            'acc': epoch_acc,
            'report': classification_report(all_labels, all_preds, target_names=['non-toxic', 'toxic']),
            'confusion_matrix': confusion_matrix(all_labels, all_preds)
        }
        torch.cuda.empty_cache()
        return metrics

    def _log_confusion_matrix(self, writer, conf_mat, global_step):
        """Render conf_mat as a heatmap, save a PDF and log it to TensorBoard."""
        # NOTE(review): LinearSegmentedColormap / plt / sns / ticker
        # (matplotlib, seaborn), Image (PIL) and ToTensor (torchvision) are
        # used here but never imported in this file — add the imports or this
        # raises NameError at the first validation checkpoint.
        flame_colors = ['#ffffcc', '#ffeda0', '#feb24c', '#fd8d3c', '#f03b20', '#bd0026']
        flame_cmap = LinearSegmentedColormap.from_list("flame", flame_colors, N=256)
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.set_theme(font_scale=1.4)
        sns.heatmap(
            conf_mat,
            annot=True,
            fmt='d',
            cmap=flame_cmap,
            linewidths=0.5,
            linecolor='gray',
            square=True,
            cbar=True,
            xticklabels=['non-toxic', 'toxic'],
            yticklabels=['non-toxic', 'toxic'],
            annot_kws={"size": 16, "weight": "bold"}
        )
        ax.set_xlabel('Predicted', fontsize=14, labelpad=10)
        ax.set_ylabel('True', fontsize=14, labelpad=10)
        ax.set_title('Confusion Matrix', fontsize=16, pad=12)
        ax.xaxis.set_tick_params(labelsize=12)
        ax.yaxis.set_tick_params(labelsize=12)
        ax.xaxis.set_major_locator(ticker.FixedLocator([0.5, 1.5]))
        ax.yaxis.set_major_locator(ticker.FixedLocator([0.5, 1.5]))
        plt.tight_layout()
        buf = io.BytesIO()
        try:
            plt.savefig(buf, format='png', dpi=150)
            plt.savefig(f'data/{self.name}/conf_matrix_step{global_step}.pdf', format='pdf', bbox_inches='tight')
            buf.seek(0)
            image_tensor = ToTensor()(Image.open(buf))
            writer.add_image('Confusion Matrix', image_tensor, global_step)
        finally:
            # Always release the buffer and figure, even if plotting fails.
            buf.close()
            plt.close(fig)

    def train_model(self, train_loader, val_loader,
                    num_epochs=3, device='cpu',
                    save_path=None,
                    logdir=None,
                    validate_every=100,
                    warmup_steps=1000,
                    early_stop_patience=3):
        """Train the classifier with periodic validation and early stopping.

        Epoch 1 trains only the TextCNN and head (BERT frozen); from epoch 2
        the top BERT layers are unfrozen. Validates every `validate_every`
        steps, logs to TensorBoard, checkpoints the best model by validation
        loss, and stops after `early_stop_patience` validation checks with no
        improvement (restoring the best weights).

        Args:
            train_loader / val_loader: DataLoaders yielding dicts with
                'input_ids', 'attention_mask', 'token_type_ids', 'label'.
            num_epochs: number of training epochs.
            device: torch device string.
            save_path: checkpoint path (default 'output/<name>.pth').
            logdir: TensorBoard log dir (default 'runs/<name>').
            validate_every: run validation every N optimizer steps.
            warmup_steps: linear LR warmup length in steps.
            early_stop_patience: validation checks without improvement
                before stopping.
        """
        self.to(device)
        # Start with BERT fully frozen; only TextCNN + head train in epoch 1.
        for param in self.bert.parameters():
            param.requires_grad = False
        best_val_loss = float('inf')
        global_step = 0
        checks_no_improve = 0  # counts validation checks, not epochs
        best_model_state = None
        if save_path is None:
            save_path = f'output/{self.name}.pth'
        # Make sure the checkpoint directory exists before the first save.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        if logdir is None:
            logdir = f'runs/{self.name}'
        writer = SummaryWriter(logdir)
        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            total_loss = 0
            correct = 0
            total = 0
            if epoch == 2:
                # Unfreeze the top BERT layers; message derived from the
                # value so the two can't disagree (was hard-coded "4").
                self.unfrozen_layers = 2
                print(f"Unfreezing the top {self.unfrozen_layers} layers of BERT")
                self._rebuild_optimizer()
            # Build the warmup scheduler AFTER any optimizer rebuild so it
            # drives the optimizer actually in use. Constructing a LambdaLR
            # resets the LR to step 0 of the ramp, so only (re)create it
            # while still inside the warmup window — the original recreated
            # it every epoch, zeroing the LR for good once warmup had passed.
            if global_step < warmup_steps:
                self.warmup_scheduler = self._get_warmup_scheduler(warmup_steps)
            else:
                self.warmup_scheduler = None
            pbar = tqdm(train_loader, desc='Training')
            for batch in pbar:
                ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                types = batch['token_type_ids'].to(device)
                labels = batch['label'].to(device)
                logits = self(ids, mask, types)
                loss = self.criterion(logits, labels)
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                self.optimizer.step()
                if self.warmup_scheduler is not None and global_step < warmup_steps:
                    self.warmup_scheduler.step()
                for i, group in enumerate(self.optimizer.param_groups):
                    writer.add_scalar(f'LR/group_{i}', group['lr'], global_step)
                for pname, param in self.named_parameters():
                    # param.grad can be None (e.g. frozen params); skip those.
                    if "convs" in pname and param.grad is not None:
                        writer.add_scalar(f'Gradients/{pname}', param.grad.norm().item(), global_step)
                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                acc = correct / total
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Acc/train', acc, global_step)
                pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{acc:.4f}'})
                global_step += 1
                if global_step % validate_every == 0:
                    torch.cuda.empty_cache()
                    # validate() handles eval mode and no_grad itself.
                    metrics = self.validate(val_loader, device)
                    val_loss, val_acc = metrics['loss'], metrics['acc']
                    self.scheduler.step(val_loss)
                    print(f"\n[Step {global_step}] Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
                    print(metrics['report'])  # was printed twice in the original
                    writer.add_text('Classification Report', metrics['report'], global_step)
                    writer.add_scalar('Loss/vali', val_loss, global_step)
                    writer.add_scalar('Acc/vali', val_acc, global_step)
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        # Deep-copy: state_dict() returns live references
                        # that keep changing as training continues, so the
                        # "best" snapshot would otherwise drift.
                        best_model_state = copy.deepcopy(self.state_dict())
                        checks_no_improve = 0
                        torch.save(best_model_state, save_path)
                        print(f"Saved best model (step {global_step}) with loss {best_val_loss:.4f}")
                    else:
                        checks_no_improve += 1
                        print(f"No improvement for {checks_no_improve} checks")
                        if checks_no_improve >= early_stop_patience:
                            print(f"Early stopping triggered at step {global_step}!")
                            self.load_state_dict(best_model_state)
                            writer.close()
                            return
                    self._log_confusion_matrix(writer, metrics['confusion_matrix'], global_step)
                    self.train()
        writer.close()

    def predict(self, texts, device='cpu'):
        """Inference helper: classify text(s) and return predictions.

        Args:
            texts: one of
                - str: a single text;
                - list of str: a batch of texts;
                - list of [text, context] lists: encoded as sentence pairs,
                  with the detected text first and its context second.
            device: 'cpu', 'cuda' or 'mps'; None picks the best available.

        Returns:
            list of dicts, one per input, with keys:
                'text' — the input item,
                'prediction' — argmax class id (0 or 1),
                'probabilities' — softmax probabilities per class.

        Raises:
            ValueError: if texts is not one of the accepted shapes.
        """
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'
        self.eval()
        self.to(device)
        if isinstance(texts, str):
            texts = [texts]
        if not isinstance(texts, list):
            raise ValueError("Invalid input type. Expected str, list of str, or list of [text, context] pairs.")
        if texts and all(isinstance(item, list) for item in texts):
            # Sentence-pair mode: element 0 is the detected text, element 1
            # its context.
            encoded_inputs = self.tokenizer(
                [item[0] for item in texts],
                [item[1] for item in texts],
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
        elif all(isinstance(item, str) for item in texts):
            encoded_inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
        else:
            raise ValueError("Invalid input type. Expected str, list of str, or list of [text, context] pairs.")
        input_ids = encoded_inputs['input_ids']
        attention_mask = encoded_inputs['attention_mask']
        token_type_ids = encoded_inputs.get('token_type_ids', None)
        with torch.no_grad():
            logits = self(input_ids, attention_mask, token_type_ids)
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(probs, dim=-1)
        results = []
        for i, text in enumerate(texts):
            results.append({
                'text': text,
                'prediction': preds[i].item(),
                'probabilities': probs[i].cpu().tolist()
            })
        return results