import logging
import os
from pprint import pformat

import numpy as np
import tensorflow as tf
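
# Note: this module uses the TF1-style graph/session API via tf.compat.v1.
# Building inside Graph().as_default() (below) keeps it usable even when
# eager execution is the TF2 default.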


class BaseModel(object):
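    """Base model owning its tf.Graph and tf.compat.v1.Session.

    __init__ opens the session, calls build_net()/build_ops() to build
    the graph, initializes variables, creates a Saver and a summary
    FileWriter under hps.exp_dir, and logs the trainable parameter count.
    """
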
    def __init__(self, hps):
        super(BaseModel, self).__init__()
        self.hps = hps

        g = tf.Graph()
        with g.as_default():
            # open a session
            config = tf.compat.v1.ConfigProto()
            config.log_device_placement = True
            config.allow_soft_placement = True
            config.gpu_options.allow_growth = True
            self.sess = tf.compat.v1.Session(config=config, graph=g)
            # build model
            self.build_net()
            self.build_ops()
            # initialize
            self.sess.run(tf.compat.v1.global_variables_initializer())
            self.saver = tf.compat.v1.train.Saver()
            self.writer = tf.compat.v1.summary.FileWriter(
                self.hps.exp_dir + '/summary')
            # logging
            total_params = 0
            trainable_variables = tf.compat.v1.trainable_variables()
            logging.info('=' * 20)
            logging.info("Variables:")
            logging.info(pformat(trainable_variables))
            for v in trainable_variables:
                num_params = np.prod(v.get_shape().as_list())
                total_params += num_params
            logging.info("TOTAL TENSORS: %d TOTAL PARAMS: %f[M]" % (
                len(trainable_variables), total_params / 1e6))
            logging.info('=' * 20)

    def save(self, filename='params'):
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        # Saver.save() does not create missing parent directories
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        self.saver.save(self.sess, fname)

    def load(self, filename='params'):
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        self.saver.restore(self.sess, fname)

    def build_net(self):
        # Subclasses define the graph here: the input placeholders
        # self.x, self.b, self.m (fed by execute) and the scalar
        # objective self.loss consumed by build_ops.
        raise NotImplementedError()

    def build_ops(self):
        # optimizer with inverse-time decay; for the first 1000 steps a
        # 1000x smaller "warm-up" schedule is used instead
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        learning_rate = tf.compat.v1.train.inverse_time_decay(
            self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        warmup_lr = tf.compat.v1.train.inverse_time_decay(
            0.001 * self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        learning_rate = tf.cond(
            pred=tf.less(self.global_step, 1000),
            true_fn=lambda: warmup_lr,
            false_fn=lambda: learning_rate)
        tf.compat.v1.summary.scalar('lr', learning_rate)
        if self.hps.optimizer == 'adam':
            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.9, beta2=0.999, epsilon=1e-08,
                use_locking=False, name="Adam")
        elif self.hps.optimizer == 'rmsprop':
            optimizer = tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=learning_rate)
        elif self.hps.optimizer == 'mom':
            optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9)
        else:
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        # L2 regularization, applied only to variables whose names contain
        # "magnitude" or "rescaling_scale"
        l2_reg = sum(
            [tf.reduce_sum(input_tensor=tf.square(v))
             for v in tf.compat.v1.trainable_variables()
             if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
        reg_loss = 0.00005 * l2_reg
        # train: optionally clip gradients by global norm
        grads_and_vars = optimizer.compute_gradients(
            self.loss + reg_loss, tf.compat.v1.trainable_variables())
        grads, vars_ = zip(*grads_and_vars)
        if self.hps.clip_gradient > 0:
            grads, gradient_norm = tf.clip_by_global_norm(
                grads, clip_norm=self.hps.clip_gradient)
            gradient_norm = tf.debugging.check_numerics(
                gradient_norm, "Gradient norm is NaN or Inf.")
            tf.compat.v1.summary.scalar('gradient_norm', gradient_norm)
        capped_grads_and_vars = zip(grads, vars_)
        self.train_op = optimizer.apply_gradients(
            capped_grads_and_vars, global_step=self.global_step)
        # summary
        self.summ_op = tf.compat.v1.summary.merge_all()

    def execute(self, cmd, batch):
        return self.sess.run(
            cmd, {self.x: batch['x'], self.b: batch['b'], self.m: batch['m']})
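

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original class). ToyModel below is
# hypothetical: it assumes `hps` also carries an `x_dim` field, and it only
# illustrates the contract build_net must satisfy -- define self.x, self.b,
# self.m and a scalar self.loss. Any real subclass will look different.
# ---------------------------------------------------------------------------
class ToyModel(BaseModel):
    def build_net(self):
        v1 = tf.compat.v1
        self.x = v1.placeholder(tf.float32, [None, self.hps.x_dim], name='x')
        self.b = v1.placeholder(tf.float32, [None, self.hps.x_dim], name='b')
        self.m = v1.placeholder(tf.float32, [None, self.hps.x_dim], name='m')
        w = v1.get_variable('w', [self.hps.x_dim, self.hps.x_dim])
        # toy objective: reconstruct x from the observed part x * b,
        # measuring the error only on the masked entries m
        recon = tf.matmul(self.x * self.b, w)
        self.loss = tf.reduce_mean(tf.square((recon - self.x) * self.m))


if __name__ == '__main__':
    # hypothetical hyper-parameters, just enough for BaseModel/build_ops
    from argparse import Namespace
    hps = Namespace(exp_dir='/tmp/toy_exp', lr=1e-3, decay_steps=1000,
                    decay_rate=0.5, optimizer='adam', clip_gradient=5.0,
                    x_dim=8)
    model = ToyModel(hps)
    batch = {k: np.ones((4, 8), np.float32) for k in ('x', 'b', 'm')}
    loss, _ = model.execute([model.loss, model.train_op], batch)
    print('one training step, loss =', loss)
    model.save()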