Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| from collections import defaultdict | |
| import numpy as np | |
| import pickle | |
| import tensorflow as tf | |
| from pprint import pformat | |
| from .utils import visualize, plot_functions, plot_img_functions | |
| class Runner(object): | |
| def __init__(self, args, model): | |
| self.args = args | |
| self.sess = model.sess | |
| self.model = model | |
| def set_dataset(self, trainset, validset, testset): | |
| self.trainset = trainset | |
| self.validset = validset | |
| self.testset = testset | |
| def train(self): | |
| train_metrics = [] | |
| num_batches = self.trainset.num_batches | |
| self.trainset.initialize() | |
| for i in range(num_batches): | |
| batch = self.trainset.next_batch() | |
| metric, summ, step, _ = self.model.execute( | |
| [self.model.metric, self.model.summ_op, | |
| self.model.global_step, self.model.train_op], | |
| batch) | |
| if (self.args.summ_freq > 0) and (i % self.args.summ_freq == 0): | |
| self.model.writer.add_summary(summ, step) | |
| train_metrics.append(metric) | |
| train_metrics = np.concatenate(train_metrics, axis=0) | |
| return np.mean(train_metrics) | |
| def valid(self): | |
| valid_metrics = [] | |
| num_batches = self.validset.num_batches | |
| self.validset.initialize() | |
| for i in range(num_batches): | |
| batch = self.validset.next_batch() | |
| metric = self.model.execute(self.model.metric, batch) | |
| valid_metrics.append(metric) | |
| valid_metrics = np.concatenate(valid_metrics, axis=0) | |
| return np.mean(valid_metrics) | |
| def valid_mse(self): | |
| valid_mse = [] | |
| num_batches = self.validset.num_batches | |
| self.validset.initialize() | |
| for i in range(num_batches): | |
| batch = self.validset.next_batch() | |
| sample = self.model.execute(self.model.sample, batch) | |
| mse = np.mean(np.sum(np.square(sample-batch['x']), axis=tuple(range(2,sample.ndim))), axis=1) | |
| valid_mse.append(mse) | |
| valid_mse = np.concatenate(valid_mse, axis=0) | |
| return np.mean(valid_mse) | |
| def valid_chd(self): | |
| pass | |
| def valid_emd(self): | |
| pass | |
| def test(self): | |
| test_metrics = [] | |
| num_batches = self.testset.num_batches | |
| self.testset.initialize() | |
| for i in range(num_batches): | |
| batch = self.testset.next_batch() | |
| metric = self.model.execute(self.model.metric, batch) | |
| test_metrics.append(metric) | |
| test_metrics = np.concatenate(test_metrics) | |
| return np.mean(test_metrics) | |
| def test_mse(self): | |
| test_mse = [] | |
| num_batches = self.testset.num_batches | |
| self.testset.initialize() | |
| for i in range(num_batches): | |
| batch = self.testset.next_batch() | |
| sample = self.model.execute(self.model.sample, batch) | |
| mse = np.mean(np.sum(np.square(sample-batch['x']), axis=tuple(range(2,sample.ndim))), axis=1) | |
| test_mse.append(mse) | |
| test_mse = np.concatenate(test_mse, axis=0) | |
| return np.mean(test_mse) | |
| def test_chd(self): | |
| pass | |
| def test_emd(self): | |
| pass | |
| def run(self): | |
| logging.info('==== start training ====') | |
| best_train_metric = -np.inf | |
| best_valid_metric = -np.inf | |
| best_test_metric = -np.inf | |
| for epoch in range(self.args.epochs): | |
| train_metric = self.train() | |
| valid_metric = self.valid() | |
| test_metric = self.test() | |
| # save | |
| if train_metric > best_train_metric: | |
| best_train_metric = train_metric | |
| if valid_metric > best_valid_metric: | |
| best_valid_metric = valid_metric | |
| self.model.save() | |
| if test_metric > best_test_metric: | |
| best_test_metric = test_metric | |
| logging.info("Epoch %d, train: %.4f/%.4f, valid: %.4f/%.4f test: %.4f/%.4f" % | |
| (epoch, train_metric, best_train_metric, | |
| valid_metric, best_valid_metric, | |
| test_metric, best_test_metric)) | |
| # evaluate | |
| if epoch % 100 == 0: | |
| logging.info('==== start evaluating ====') | |
| self.evaluate(folder=f'{epoch}', load=False) | |
| self.model.save('last') | |
| # finish | |
| logging.info('==== start evaluating ====') | |
| self.evaluate(load=True) | |
| def evaluate(self, folder='test', load=True): | |
| save_dir = f'{self.args.exp_dir}/evaluate/{folder}/' | |
| os.makedirs(save_dir, exist_ok=True) | |
| if load: self.model.load() | |
| # # likelihood | |
| if 'likel' in self.args.eval_metrics: | |
| valid_likel = self.valid() | |
| test_likel = self.test() | |
| logging.info(f"likelihood => valid: {valid_likel} test: {test_likel}") | |
| # # mse | |
| if 'mse' in self.args.eval_metrics: | |
| valid_mse = self.valid_mse() | |
| test_mse = self.test_mse() | |
| logging.info(f"mse => valid: {valid_mse} test: {test_mse}") | |
| if 'chd' in self.args.eval_metrics: | |
| valid_chd = self.valid_chd() | |
| test_chd = self.test_chd() | |
| logging.info(f"chd => valid: {valid_chd} test: {test_chd}") | |
| if 'emd' in self.args.eval_metrics: | |
| valid_emd = self.valid_emd() | |
| test_emd = self.test_emd() | |
| logging.info(f"emd => valid: {valid_emd} test: {test_emd}") | |
| if 'sam' in self.args.eval_metrics: | |
| # train set | |
| self.trainset.initialize() | |
| batch = self.trainset.next_batch() | |
| train_sample = self.model.execute(self.model.sample, batch) | |
| visualize(train_sample, batch, f'{save_dir}/train_sam') | |
| # valid set | |
| self.validset.initialize() | |
| batch = self.validset.next_batch() | |
| valid_sample = self.model.execute(self.model.sample, batch) | |
| visualize(valid_sample, batch, f'{save_dir}/valid_sam') | |
| # test set | |
| self.testset.initialize() | |
| batch = self.testset.next_batch() | |
| test_sample = self.model.execute(self.model.sample, batch) | |
| visualize(test_sample, batch, f'{save_dir}/test_sam') | |
| if 'fns' in self.args.eval_metrics: | |
| # train set | |
| self.trainset.initialize() | |
| batch = self.trainset.next_batch() | |
| train_mean, train_std = self.model.execute([self.model.mean, self.model.std], batch) | |
| plot_functions(train_mean, train_std, batch, f'{save_dir}/train_fn') | |
| # valid set | |
| self.validset.initialize() | |
| batch = self.validset.next_batch() | |
| valid_mean, valid_std = self.model.execute([self.model.mean, self.model.std], batch) | |
| plot_functions(valid_mean, valid_std, batch, f'{save_dir}/valid_fn') | |
| # test set | |
| self.testset.initialize() | |
| batch = self.testset.next_batch() | |
| test_mean, test_std = self.model.execute([self.model.mean, self.model.std], batch) | |
| plot_functions(test_mean, test_std, batch, f'{save_dir}/test_fn') | |
| if 'imfns' in self.args.eval_metrics: | |
| # train set | |
| self.trainset.initialize() | |
| batch = self.trainset.next_batch() | |
| train_mean, train_std = self.model.execute([self.model.mean, self.model.std], batch) | |
| plot_img_functions(train_mean, train_std, batch, f'{save_dir}/train_fn') | |
| # valid set | |
| self.validset.initialize() | |
| batch = self.validset.next_batch() | |
| valid_mean, valid_std = self.model.execute([self.model.mean, self.model.std], batch) | |
| plot_img_functions(valid_mean, valid_std, batch, f'{save_dir}/valid_fn') | |
| # test set | |
| self.testset.initialize() | |
| batch = self.testset.next_batch() | |
| test_mean, test_std = self.model.execute([self.model.mean, self.model.std], batch) | |
| plot_img_functions(test_mean, test_std, batch, f'{save_dir}/test_fn') | |