import numpy as np
import tensorflow as tf
def layer_norm(x):
    # Normalize each example over the set and feature axes; a stand-in for
    # tf.contrib.layers.layer_norm, which is unavailable in TF2.
    mean = tf.reduce_mean(input_tensor=x, axis=[1, 2], keepdims=True)
    std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
    return (x - mean) / (std + 1e-8)  # small epsilon guards against zero std
def set_attention(Q, K, dim, num_heads, name='set_attention'):
    # Multihead attention block with residual connections: Q attends to K,
    # producing an output of shape [B, Nq, dim].
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        q = tf.compat.v1.layers.dense(Q, dim, name='query')
        k = tf.compat.v1.layers.dense(K, dim, name='key')
        v = tf.compat.v1.layers.dense(K, dim, name='value')
        # Split the feature axis into heads and fold the heads into the batch.
        q_ = tf.concat(tf.split(q, num_heads, axis=-1), axis=0)
        k_ = tf.concat(tf.split(k, num_heads, axis=-1), axis=0)
        v_ = tf.concat(tf.split(v, num_heads, axis=-1), axis=0)
        logits = tf.matmul(q_, k_, transpose_b=True) / np.sqrt(dim)  # [B*Nh, Nq, Nk]
        A = tf.nn.softmax(logits, axis=-1)
        o = q_ + tf.matmul(A, v_)  # residual around the attention read-out
        # Unfold the heads back onto the feature axis.
        o = tf.concat(tf.split(o, num_heads, axis=0), axis=-1)
        # o = layer_norm(o)  # optional; local helper replaces tf.contrib.layers.layer_norm
        o = o + tf.compat.v1.layers.dense(o, dim, activation=tf.nn.relu, name='output')
        # o = layer_norm(o)
        return o
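# --- Illustration (added; not part of the original file) ---------------------
# set_attention corresponds to the multihead attention block (MAB) of the
# Set Transformer: Q:[B, Nq, *], K:[B, Nk, *] -> [B, Nq, dim]. A minimal
# graph-mode shape check, with hypothetical sizes, might look like:
#
#   tf.compat.v1.disable_eager_execution()
#   Q = tf.zeros([2, 5, 32])   # batch of 2, query set of 5
#   K = tf.zeros([2, 12, 32])  # key/value set of 12
#   print(set_attention(Q, K, dim=64, num_heads=4).shape)  # -> (2, 5, 64)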
def set_transformer(inputs, layer_sizes, name, num_heads=4, num_inds=16):
    # Stack of induced set-attention blocks: each layer attends through a
    # small set of learned inducing points instead of full N-by-N attention.
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        out = inputs
        for i, size in enumerate(layer_sizes):
            inds = tf.compat.v1.get_variable(
                f'inds_{i}', shape=[1, num_inds, size], dtype=tf.float32, trainable=True,
                initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                    scale=1.0, mode="fan_avg", distribution="uniform"))
            inds = tf.tile(inds, [tf.shape(input=out)[0], 1, 1])
            # Inducing points attend to the set, then the set attends back.
            tmp = set_attention(inds, out, size, num_heads, name=f'self_attn_{i}_pre')
            out = set_attention(out, tmp, size, num_heads, name=f'self_attn_{i}_post')
        return out
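# --- Note (added; not part of the original file) -----------------------------
# Each loop iteration above is an induced set-attention block (ISAB): with
# m = num_inds inducing points, per-layer attention cost drops from O(N^2) to
# O(N*m) while the output keeps one row per input element. For example:
#
#   x = tf.compat.v1.placeholder(tf.float32, [None, None, 8])     # [B, N, 8]
#   h = set_transformer(x, layer_sizes=[64, 64], name='encoder')  # [B, N, 64]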
def set_pooling(inputs, name, num_heads=4):
    # Attention pooling: a single learned seed vector attends over the set,
    # collapsing [B, N, d] to a permutation-invariant [B, d] summary.
    B = tf.shape(input=inputs)[0]
    d = inputs.get_shape().as_list()[-1]
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        seed = tf.compat.v1.get_variable(
            'pool_seed', shape=[1, 1, d], dtype=tf.float32, trainable=True,
            initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"))
        seed = tf.tile(seed, [B, 1, 1])
        out = set_attention(seed, inputs, d, num_heads, name='pool_attn')
        out = tf.squeeze(out, axis=1)  # drop the singleton query axis
        return out
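# Minimal end-to-end sketch (added for illustration; the sizes and names below
# are assumptions, not part of the original file). Builds the encoder-plus-
# pooling pipeline in TF1 graph mode and runs a single forward pass.
if __name__ == '__main__':
    tf.compat.v1.disable_eager_execution()
    x = tf.compat.v1.placeholder(tf.float32, [None, None, 8])  # [B, N, d_in]
    enc = set_transformer(x, layer_sizes=[64, 64], name='encoder')
    pooled = set_pooling(enc, name='pool')  # [B, 64]
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        batch = np.random.randn(4, 10, 8).astype(np.float32)  # 4 sets of 10 items
        print(sess.run(pooled, {x: batch}).shape)  # expected: (4, 64)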