Dinars34's picture
Upload 60 files
89c5d90 verified
from tensorflow.keras import callbacks
from tensorflow.keras import layers, regularizers
from tensorflow.keras import optimizers, metrics, losses
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model
from tensorflow.keras import backend as K
from tensorflow import keras
import tensorflow as tf
l2 = regularizers.l2
w_decay=1e-8 #0.0#2e-4#1e-3, 2e-4 # please define weight decay
K.clear_session()
# weight_init = tf.initializers.RandomNormal(mean=0.,stddev=0.01)
# weight_init = tf.initializers.glorot_normal()
weight_init = tf.initializers.glorot_normal()
class _DenseLayer(layers.Layer):
"""_DenseBlock model.
Arguments:
out_features: number of output features
"""
def __init__(self, out_features,**kwargs):
super(_DenseLayer, self).__init__(**kwargs)
k_reg = None if w_decay is None else l2(w_decay)
self.layers = []
self.layers.append(tf.keras.Sequential(
[
layers.ReLU(),
layers.Conv2D(
filters=out_features, kernel_size=(3,3), strides=(1,1), padding='same',
use_bias=True, kernel_initializer=weight_init,
kernel_regularizer=k_reg),
layers.BatchNormalization(),
layers.ReLU(),
layers.Conv2D(
filters=out_features, kernel_size=(3,3), strides=(1,1), padding='same',
use_bias=True, kernel_initializer=weight_init,
kernel_regularizer=k_reg),
layers.BatchNormalization(),
])) # first relu can be not needed
def call(self, inputs):
x1, x2 = tuple(inputs)
new_features = x1
for layer in self.layers:
new_features = layer(new_features)
return 0.5 * (new_features + x2), x2
class _DenseBlock(layers.Layer):
"""DenseBlock layer.
Arguments:
num_layers: number of _DenseLayer's per block
out_features: number of output features
"""
def __init__(self,
num_layers,
out_features,**kwargs):
super(_DenseBlock, self).__init__(**kwargs)
self.layers = [_DenseLayer(out_features) for i in range(num_layers)]
def call(self, inputs):
for layer in self.layers:
inputs = layer(inputs)
return inputs
class UpConvBlock(layers.Layer):
"""UpConvDeconvBlock layer.
Arguments:
up_scale: int
"""
def __init__(self, up_scale,**kwargs):
super(UpConvBlock, self).__init__(**kwargs)
constant_features = 16
k_reg = None if w_decay is None else l2(w_decay)
features = []
total_up_scale = 2 ** up_scale
for i in range(up_scale):
out_features = 1 if i == up_scale-1 else constant_features
if i==up_scale-1:
features.append(layers.Conv2D(
filters=out_features, kernel_size=(1,1), strides=(1,1), padding='same',
activation='relu', kernel_initializer=tf.initializers.RandomNormal(mean=0.),
kernel_regularizer=k_reg,use_bias=True)) #tf.initializers.TruncatedNormal(mean=0.)
features.append(layers.Conv2DTranspose(
out_features, kernel_size=(total_up_scale,total_up_scale),
strides=(2,2), padding='same',
kernel_initializer=tf.initializers.RandomNormal(stddev=0.1),
kernel_regularizer=k_reg,use_bias=True)) # stddev=0.1
else:
features.append(layers.Conv2D(
filters=out_features, kernel_size=(1,1), strides=(1,1), padding='same',
activation='relu',kernel_initializer=weight_init,
kernel_regularizer=k_reg,use_bias=True))
features.append(layers.Conv2DTranspose(
out_features, kernel_size=(total_up_scale,total_up_scale),
strides=(2,2), padding='same', use_bias=True,
kernel_initializer=weight_init, kernel_regularizer=k_reg))
self.features = keras.Sequential(features)
def call(self, inputs):
return self.features(inputs)
class SingleConvBlock(layers.Layer):
"""SingleConvBlock layer.
Arguments:
out_features: number of output features
stride: stride per convolution
"""
def __init__(self, out_features, k_size=(1,1),stride=(1,1),
use_bs=False, use_act=False,w_init=None,**kwargs): # bias_init=tf.constant_initializer(0.0)
super(SingleConvBlock, self).__init__(**kwargs)
self.use_bn = use_bs
self.use_act = use_act
k_reg = None if w_decay is None else l2(w_decay)
self.conv = layers.Conv2D(
filters=out_features, kernel_size=k_size, strides=stride,
padding='same',kernel_initializer=w_init,
kernel_regularizer=k_reg)#, use_bias=True, bias_initializer=bias_init
if self.use_bn:
self.bn = layers.BatchNormalization()
if self.use_act:
self.relu = layers.ReLU()
def call(self, inputs):
x =self.conv(inputs)
if self.use_bn:
x = self.bn(x)
if self.use_act:
x = self.relu(x)
return x
class DoubleConvBlock(layers.Layer):
"""DoubleConvBlock layer.
Arguments:
mid_features: number of middle features
out_features: number of output features
stride: stride per mid-layer convolution
"""
def __init__(self, mid_features, out_features=None, stride=(1,1),
use_bn=True,use_act=True,**kwargs):
super(DoubleConvBlock, self).__init__(**kwargs)
self.use_bn =use_bn
self.use_act =use_act
out_features = mid_features if out_features is None else out_features
k_reg = None if w_decay is None else l2(w_decay)
self.conv1 = layers.Conv2D(
filters=mid_features, kernel_size=(3, 3), strides=stride, padding='same',
use_bias=True, kernel_initializer=weight_init,
kernel_regularizer=k_reg)
self.bn1 = layers.BatchNormalization()
self.conv2 = layers.Conv2D(
filters=out_features, kernel_size=(3, 3), padding='same',strides=(1,1),
use_bias=True, kernel_initializer=weight_init,
kernel_regularizer=k_reg)
self.bn2 = layers.BatchNormalization()
self.relu = layers.ReLU()
def call(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
if self.use_act:
x = self.relu(x)
return x
class DexiNed(tf.keras.Model):
"""DexiNet model."""
def __init__(self,rgb_mean=None,
**kwargs):
super(DexiNed, self).__init__(**kwargs)
self.rgbn_mean = rgb_mean
self.block_1 = DoubleConvBlock(32, 64, stride=(2,2),use_act=False)
self.block_2 = DoubleConvBlock(128,use_act=False)
self.dblock_3 = _DenseBlock(2, 256)
self.dblock_4 = _DenseBlock(3, 512)
self.dblock_5 = _DenseBlock(3, 512)
self.dblock_6 = _DenseBlock(3, 256)
self.maxpool = layers.MaxPool2D(pool_size=(3, 3), strides=2, padding='same')
# left skip connections, figure in Journal
self.side_1 = SingleConvBlock(128,k_size=(1,1),stride=(2,2),use_bs=True,
w_init=weight_init)
self.side_2 = SingleConvBlock(256,k_size=(1,1),stride=(2,2),use_bs=True,
w_init=weight_init)
self.side_3 = SingleConvBlock(512,k_size=(1,1),stride=(2,2),use_bs=True,
w_init=weight_init)
self.side_4 = SingleConvBlock(512,k_size=(1,1),stride=(1,1),use_bs=True,
w_init=weight_init)
# self.side_5 = SingleConvBlock(256,k_size=(1,1),stride=(1,1),use_bs=True,
# w_init=weight_init)
# right skip connections, figure in Journal paper
self.pre_dense_2 = SingleConvBlock(256,k_size=(1,1),stride=(2,2),
w_init=weight_init) # use_bn=True
self.pre_dense_3 = SingleConvBlock(256,k_size=(1,1),stride=(1,1),use_bs=True,
w_init=weight_init)
self.pre_dense_4 = SingleConvBlock(512,k_size=(1,1),stride=(1,1),use_bs=True,
w_init=weight_init)
# self.pre_dense_5_0 = SingleConvBlock(512, k_size=(1,1),stride=(2,2),
# w_init=weight_init) # use_bn=True
self.pre_dense_5 = SingleConvBlock(512,k_size=(1,1),stride=(1,1),use_bs=True,
w_init=weight_init)
self.pre_dense_6 = SingleConvBlock(256,k_size=(1,1),stride=(1,1),use_bs=True,
w_init=weight_init)
# USNet
self.up_block_1 = UpConvBlock(1)
self.up_block_2 = UpConvBlock(1)
self.up_block_3 = UpConvBlock(2)
self.up_block_4 = UpConvBlock(3)
self.up_block_5 = UpConvBlock(4)
self.up_block_6 = UpConvBlock(4)
self.block_cat = SingleConvBlock(
1,k_size=(1,1),stride=(1,1),
w_init=tf.constant_initializer(1/5))
def slice(self, tensor, slice_shape):
height, width = slice_shape
return tensor[..., :height, :width]
def call(self, x):
# Block 1
x = x-self.rgbn_mean[:-1]
block_1 = self.block_1(x)
block_1_side = self.side_1(block_1)
# Block 2
block_2 = self.block_2(block_1)
block_2_down = self.maxpool(block_2) # the key for the second skip connec...
block_2_add = block_2_down + block_1_side
block_2_side = self.side_2(block_2_add) #
# Block 3
block_3_pre_dense = self.pre_dense_3(block_2_down)
block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
block_3_down = self.maxpool(block_3)
block_3_add = block_3_down + block_2_side
block_3_side = self.side_3(block_3_add)
# Block 4
block_4_pre_dense_256 = self.pre_dense_2(block_2_down)
block_4_pre_dense = self.pre_dense_4(block_4_pre_dense_256 + block_3_down)
block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
block_4_down = self.maxpool(block_4)
block_4_add = block_4_down + block_3_side
block_4_side = self.side_4(block_4_add)
# Block 5
# block_5_pre_dense_512 = self.pre_dense_5_0(block_4_pre_dense_256)
block_5_pre_dense = self.pre_dense_5(block_4_down )
block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
block_5_add = block_5 + block_4_side
# Block 6
block_6_pre_dense = self.pre_dense_6(block_5)
block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
# upsampling blocks
height, width = x.shape[1:3]
slice_shape = (height, width)
out_1 = self.up_block_1(block_1) # self.slice(, slice_shape)
out_2 = self.up_block_2(block_2)
out_3 = self.up_block_3(block_3)
out_4 = self.up_block_4(block_4)
out_5 = self.up_block_5(block_5)
out_6 = self.up_block_6(block_6)
results = [out_1, out_2, out_3, out_4, out_5, out_6]
# concatenate multiscale outputs
block_cat = tf.concat(results, 3) # BxHxWX6
block_cat = self.block_cat(block_cat) # BxHxWX1
results.append(block_cat)
return results
def weighted_cross_entropy_loss(input, label):
y = tf.cast(label,dtype=tf.float32)
negatives = tf.math.reduce_sum(1.-y)
positives = tf.math.reduce_sum(y)
beta = negatives/(negatives + positives)
pos_w = beta/(1-beta)
cost = tf.nn.weighted_cross_entropy_with_logits(
labels=label, logits=input, pos_weight=pos_w, name=None)
cost = tf.reduce_sum(cost*(1-beta))
return tf.where(tf.equal(positives, 0.0), 0.0, cost)
def pre_process_binary_cross_entropy(bc_loss,input, label,arg, use_tf_loss=False):
# preprocess data
y = label
loss = 0
w_loss=1.0
preds = []
for tmp_p in input:
# tmp_p = input[i]
# loss processing
tmp_y = tf.cast(y, dtype=tf.float32)
mask = tf.dtypes.cast(tmp_y > 0., tf.float32)
b,h,w,c=mask.get_shape()
positives = tf.math.reduce_sum(mask, axis=[1, 2, 3], keepdims=True)
negatives = h*w*c-positives
beta2 = (1.*positives) / (negatives + positives) # negatives in hed
beta = (1.1*negatives)/ (positives + negatives) # positives in hed
pos_w = tf.where(tf.equal(y, 0.0), beta2, beta)
logits = tf.sigmoid(tmp_p)
l_cost = bc_loss(y_true=tmp_y, y_pred=logits,
sample_weight=pos_w)
preds.append(logits)
loss += (l_cost*w_loss)
return preds, loss