import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from easydict import EasyDict as edict

from .base import BaseModel
from .pc_encoder import *
from .flow.transforms import Transform
from .cVAE import CondVAE


class ACSetVAE(BaseModel):
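    """Set-level VAE with a flow-transformed latent prior.

    The prior over the set-level latent is conditioned on `x * b` and the
    posterior on `x * m` (`b` and `m` are mask-like tensors with the shape of
    `x`); a normalizing-flow `Transform` connects the two latent spaces. Each
    set element is then modeled by a per-element conditional VAE (`CondVAE`)
    conditioned on the set-level latent and, optionally, a
    permutation-equivariant embedding (`SetXformer`).
    """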
    def __init__(self, hps):
        self.prior_net = LatentEncoder(hps, name='prior')
        self.posterior_net = LatentEncoder(hps, name='posterior')
        self.cvae = CondVAE(hps.vae_params, name='cvae')
        if hps.use_peq_embed:
            self.peq_embed = SetXformer(hps)
        super(ACSetVAE, self).__init__(hps)

    def build_net(self):
        with tf.compat.v1.variable_scope('acset_vae', reuse=tf.compat.v1.AUTO_REUSE):
            # B = batch size, N = set_size, d = dimension
            self.x = tf.compat.v1.placeholder(tf.float32, [None, self.hps.set_size, self.hps.dimension])  # [B, N, d]
            self.b = tf.compat.v1.placeholder(tf.float32, [None, self.hps.set_size, self.hps.dimension])  # mask-like input for the prior path, [B, N, d]
            self.m = tf.compat.v1.placeholder(tf.float32, [None, self.hps.set_size, self.hps.dimension])  # mask-like input for the posterior path, [B, N, d]
            # build transform: bijection between prior and posterior latent spaces
            self.transform = Transform(edict(self.hps.trans_params))
            # prior over the set-level latent
            prior_inputs = tf.concat([self.x * self.b, self.b], axis=-1)  # [B, N, 2d]
            prior = self.prior_net(prior_inputs)  # distribution over [B, latent_dim]
            prior_sample = prior.sample()  # [B, latent_dim]
            prior_sample, _ = self.transform.inverse(prior_sample)  # [B, latent_dim]
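            # NOTE: the prior sample is mapped through the inverse flow so that it
            # lives in the same space as posterior samples; it is reused as the
            # set-level conditioning vector in the sampling branch below.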
            # peq embedding (permutation-equivariant, optional)
            cm = peq_embed = None
            if self.hps.use_peq_embed:
                peq_embed = self.peq_embed(prior_inputs)  # [B, N, C]
                C = peq_embed.get_shape().as_list()[-1]
                cm = tf.reshape(peq_embed, [-1, C])  # [B*N, C]
            # posterior over the set-level latent
            posterior_inputs = tf.concat([self.x * self.m, self.m], axis=-1)  # [B, N, 2d]
            posterior = self.posterior_net(posterior_inputs)  # distribution over [B, latent_dim]
            posterior_sample = posterior.sample()  # [B, latent_dim]
            # kl term (single-sample estimate through the flow)
            z_sample, logdet = self.transform.forward(posterior_sample)  # [B, latent_dim], [B]
            logp = tf.reduce_sum(input_tensor=prior.log_prob(z_sample), axis=1) + logdet  # [B]
            kl = tf.reduce_sum(input_tensor=posterior.entropy(), axis=1) + logp  # [B]
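            # NOTE: `kl` above is a single-sample Monte Carlo estimate of the
            # *negative* KL divergence, -KL(q || p) = H(q) + E_q[log p(z)],
            # with log p(z) evaluated through the flow by change of variables:
            # log p(z) = log p_base(f(z)) + log|det df/dz|. Because it is the
            # negative KL, it is added (not subtracted) in the ELBO below.
            # (A runnable sketch of this estimate is at the bottom of the file.)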
            # generator: flatten the set and condition each element on the latent
            x = tf.reshape(self.x, [-1, self.hps.dimension])  # [B*N, d]
            b = tf.reshape(self.b, [-1, self.hps.dimension])  # [B*N, d]
            m = tf.reshape(self.m, [-1, self.hps.dimension])  # [B*N, d]
            # broadcast the set-level latent to every element
            cv = tf.reshape(
                tf.tile(tf.expand_dims(posterior_sample, axis=1), [1, self.hps.set_size, 1]),
                [-1, self.hps.latent_dim])  # [B, latent_dim] -> [B, N, latent_dim] -> [B*N, latent_dim]
            if cm is not None:
                c = tf.concat([cv, cm], axis=-1)  # [B*N, latent_dim + C]
            else:
                c = cv
            # vector-wise (per-element) posterior and reconstruction
            vec_kl, vec_post_sample = self.cvae.enc(tf.concat([x, c], axis=-1), c)  # [B*N], [B*N, vae_latent]
            vec_kl = tf.reshape(vec_kl, [-1, self.hps.set_size])  # [B, N]
            recon_dist = self.cvae.dec(tf.concat([vec_post_sample, c], axis=-1), c)
            log_likel = recon_dist.log_prob(x)  # [B*N]
            log_likel = tf.reshape(log_likel, [-1, self.hps.set_size])  # [B, N]
            self.set_metric = self.set_elbo = log_likel + tf.expand_dims(kl, axis=1) / self.hps.set_size  # [B, N]
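            # NOTE: `set_metric` holds per-element ELBO contributions, with the
            # set-level KL term spread evenly over the set_size elements; unlike
            # `elbo` below, it does not include `vec_kl`.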
            log_likel = tf.reduce_mean(input_tensor=log_likel, axis=1)  # [B]
            tf.compat.v1.summary.scalar('log_likel', tf.reduce_mean(input_tensor=log_likel))
            # elbo
            self.elbo = log_likel + kl / self.hps.set_size + tf.reduce_mean(input_tensor=vec_kl, axis=1)  # [B]
            self.metric = self.elbo
            self.loss = tf.reduce_mean(input_tensor=-self.elbo)
            tf.compat.v1.summary.scalar('loss', self.loss)
            # sample: decode per-element prior draws conditioned on the prior latent
            # (x, b, m are already flattened above)
            cv = tf.reshape(
                tf.tile(tf.expand_dims(prior_sample, axis=1), [1, self.hps.set_size, 1]),
                [-1, self.hps.latent_dim])  # [B*N, latent_dim]
            if cm is not None:
                c = tf.concat([cv, cm], axis=-1)
            else:
                c = cv
            vec_prior_sample = tf.random.normal(shape=tf.shape(input=vec_post_sample))  # [B*N, vae_latent]
            sample_dist = self.cvae.dec(tf.concat([vec_prior_sample, c], axis=-1), c)
            log_likel = sample_dist.log_prob(x)  # [B*N]
            sample = sample_dist.sample()
            self.sample = tf.reshape(sample, [-1, self.hps.set_size, self.hps.dimension])  # [B, N, d]
            # compress
            self.log_likel = tf.reshape(log_likel, [-1, self.hps.set_size]) + vec_kl  # [B, N]
            # total_params = 0
            # for variable in tf.compat.v1.trainable_variables():
            #     shape = variable.get_shape()
            #     params = 1
            #     for dim in shape:
            #         params *= dim
            #     total_params += params
            # print("Total number of parameters: ", total_params)
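

if __name__ == '__main__':
    # Illustration only (not part of the model): a minimal, self-contained
    # sketch of the single-sample -KL(q || p) estimate used in build_net,
    # with a hypothetical affine bijector standing in for Transform and
    # hypothetical shapes. Assumes eager execution (the TF2 default); run
    # via `python -m <package>.<module>` because of the relative imports.
    tfb = tfp.bijectors
    latent_dim = 4  # hypothetical latent size
    posterior = tfd.Normal(loc=tf.zeros([1, latent_dim]),
                           scale=0.5 * tf.ones([1, latent_dim]))
    prior = tfd.Normal(loc=tf.zeros([1, latent_dim]),
                       scale=tf.ones([1, latent_dim]))
    flow = tfb.Chain([tfb.Shift(0.1), tfb.Scale(2.0)])  # stand-in for self.transform
    s = posterior.sample()                                     # sample in posterior space
    z = flow.forward(s)                                        # mapped into the prior's space
    logdet = flow.forward_log_det_jacobian(s, event_ndims=1)   # [1]
    logp = tf.reduce_sum(prior.log_prob(z), axis=1) + logdet   # [1]
    neg_kl = tf.reduce_sum(posterior.entropy(), axis=1) + logp # [1]
    print('single-sample -KL estimate:', float(neg_kl[0]))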