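# Training loop for the anomaly-detection model: MSE forecasting loss, Adam optimizer,
# checkpointing on the best validation loss, and patience-based early stopping
# (default dataset name: 'swat').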
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from util.time import *
from util.env import *
from sklearn.metrics import mean_squared_error
from pipeline.test import *
import torch.nn.functional as F
from pipeline.evaluate import get_best_performance_data, get_val_performance_data, get_full_err_scores
from sklearn.metrics import precision_score, recall_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader, random_split, Subset
from scipy.stats import iqr
def loss_func(y_pred, y_true):
    # Mean squared error between predicted and observed sensor values.
    return F.mse_loss(y_pred, y_true, reduction='mean')
def train(model=None, save_path='', config={}, train_dataloader=None, val_dataloader=None,
          feature_map={}, test_dataloader=None, test_dataset=None, dataset_name='swat',
          train_dataset=None):

    seed = config['seed']

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=config['decay'])

    now = time.time()

    train_loss_list = []
    cmp_loss_list = []

    device = get_device()

    acu_loss = 0
    min_loss = 1e+8
    min_f1 = 0
    min_pre = 0
    best_prec = 0

    i = 0
    epoch = config['epoch']
    early_stop_win = 15  # patience (in epochs) without validation improvement before stopping

    model.train()

    log_interval = 1000
    stop_improve_count = 0

    dataloader = train_dataloader
    for i_epoch in range(epoch):

        acu_loss = 0
        model.train()

        for x, labels, attack_labels, edge_index in dataloader:
            _start = time.time()

            x, labels, edge_index = [item.float().to(device) for item in [x, labels, edge_index]]

            optimizer.zero_grad()
            out = model(x, edge_index).float().to(device)
            loss = loss_func(out, labels)

            loss.backward()
            optimizer.step()

            train_loss_list.append(loss.item())
            acu_loss += loss.item()

            i += 1

        # report the average and accumulated training loss for this epoch
        print('epoch ({} / {}) (Loss:{:.8f}, ACU_loss:{:.8f})'.format(
            i_epoch, epoch,
            acu_loss / len(dataloader), acu_loss), flush=True)
        # use the validation set to select the checkpoint and drive early stopping
        if val_dataloader is not None:

            val_loss, val_result = test(model, val_dataloader)

            if val_loss < min_loss:
                torch.save(model.state_dict(), save_path)

                min_loss = val_loss
                stop_improve_count = 0
            else:
                stop_improve_count += 1

            if stop_improve_count >= early_stop_win:
                break

        else:
            # no validation set: keep the checkpoint with the lowest accumulated training loss
            if acu_loss < min_loss:
                torch.save(model.state_dict(), save_path)
                min_loss = acu_loss

    return train_loss_list
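

# Usage sketch (assumption: this module is normally driven by the project's own main
# script, which builds the model and dataloaders; the names below are hypothetical and
# only illustrate the expected call):
#
#   config = {'seed': 0, 'epoch': 30, 'decay': 0.0}
#   train_log = train(model=model,
#                     save_path='./checkpoints/best.pt',
#                     config=config,
#                     train_dataloader=train_loader,
#                     val_dataloader=val_loader,
#                     feature_map=feature_map,
#                     dataset_name='swat')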