Spaces:
Running
Running
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow_probability as tfp | |
| tfd = tfp.distributions | |
| from .networks import dense_nn, cond_dense_nn | |
class CondVAE(object):
    """Conditional VAE with diagonal-Gaussian posterior and output distributions.

    The encoder and decoder networks are built by the project helpers
    ``dense_nn`` (unconditional) and ``cond_dense_nn`` (conditional). Both
    methods open the variable scope ``name`` with ``AUTO_REUSE``.

    Keys read from ``hps``: 'hid_dimensions', 'dimension',
    'enc_dense_hids', 'dec_dense_hids'.
    """

    def __init__(self, hps, name="cvae"):
        """Store the hyper-parameter mapping and the variable-scope name."""
        self.hps = hps
        self.name = name

    def enc(self, x, cond=None):
        '''
        Encode x into a diagonal-Gaussian posterior and draw one sample from it.

        x: [B, C]
        cond: [B, C], or None to use the unconditional encoder network

        Returns a tuple (kl, posterior_sample):
          kl: the NEGATED divergence, -KL(posterior || prior), where the prior
              is a standard normal over hps['hid_dimensions'] dimensions.
          posterior_sample: m + eps * s with eps drawn from the prior,
              i.e. a reparameterized sample of the posterior.
        '''
        latent_dim = self.hps['hid_dimensions']
        # Dynamic batch size: needed to draw one noise vector per input row.
        batch_size = tf.shape(input=x)[0]
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            prior_dist = tfd.MultivariateNormalDiag(
                loc=tf.zeros(latent_dim), scale_diag=tf.ones(latent_dim))
            # The network emits 2 * latent_dim values per row: the first half
            # is the mean, the second half the pre-softplus scale.
            if cond is None:
                x = dense_nn(x, self.hps['enc_dense_hids'], 2 * latent_dim, False, "enc")
            else:
                x = cond_dense_nn(x, cond, self.hps['enc_dense_hids'], 2 * latent_dim, False, "enc")
            m = x[:, :latent_dim]
            # softplus keeps the scale non-negative.
            s = tf.nn.softplus(x[:, latent_dim:])
            posterior_dist = tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
            # Sign is flipped on purpose: callers receive -KL.
            # NOTE(review): presumably so it can be added directly to a
            # log-likelihood term -- confirm against the caller.
            kl = -tfd.kl_divergence(posterior_dist, prior_dist)
            # Reparameterization: sample standard-normal noise, then shift/scale.
            eps = prior_dist.sample(batch_size)
            posterior_sample = m + eps * s
        return kl, posterior_sample

    def dec(self, x, cond=None):
        '''
        Decode x into a diagonal-Gaussian distribution over the output space.

        x: [B, C]
        cond: optional conditioning input; None selects the unconditional
              decoder network

        Returns a tfd.MultivariateNormalDiag whose loc and scale_diag each
        have hps['dimension'] entries per row.
        '''
        out_dim = self.hps['dimension']
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            # The network emits 2 * out_dim values per row: mean, then
            # pre-softplus scale.
            if cond is None:
                x = dense_nn(x, self.hps['dec_dense_hids'], 2 * out_dim, False, "dec")
            else:
                x = cond_dense_nn(x, cond, self.hps['dec_dense_hids'], 2 * out_dim, False, "dec")
            m = x[:, :out_dim]
            # softplus keeps the scale non-negative.
            s = tf.nn.softplus(x[:, out_dim:])
            sample_dist = tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
        return sample_dist