Commit 52cf85e5 authored by Yipeng Hu's avatar Yipeng Hu

ref #5 variable mechanism added

parent 65b05d1a
......@@ -10,40 +10,83 @@ import matplotlib.pyplot as plt
# https://colab.research.google.com/drive/1i-7Vn_9hGdOvMjkYNedK5nsonNizhe0o#scrollTo=z3w4D1V0SkZ8&forceEdit=true&sandboxMode=true
### Define a few functions for network layers
def conv3d(input, filters):
    """3-D convolution followed by ReLU.

    The diff residue left both the removed stride-3 body and the added
    stride-2 body in place, making the second return unreachable; only the
    newer stride-2 version is kept here.

    Args:
        input: 5-D input tensor (batch, d, h, w, channels) — assumed, TODO confirm.
        filters: 5-D convolution kernel compatible with tf.nn.conv3d.

    Returns:
        ReLU-activated convolution output.
    """
    y = tf.nn.conv3d(input, filters, strides=[1, 2, 2, 2, 1], padding='SAME')
    return tf.nn.relu(y)  # where bn can be added
def resnet_block(input, filters):
    """Two-convolution residual block (ref: https://arxiv.org/abs/1512.03385).

    `filters[..., 0]` feeds the first conv (via conv3d, with ReLU) and
    `filters[..., 1]` the second; the input is added back before the final
    ReLU.

    NOTE(review): both convolutions run with stride 2, which halves the
    spatial dimensions, so `y + input` would fail a shape check for typical
    inputs — residual blocks normally use stride 1; confirm intent.
    """
    shortcut = input
    y = conv3d(input, filters[..., 0])
    y = tf.nn.conv3d(y, filters[..., 1], strides=[1, 2, 2, 2, 1], padding='SAME')
    return tf.nn.relu(y + shortcut)  # where bn can be added
def max_pool(input):
    """3x3x3 max-pooling with stride 2 and VALID padding."""
    return tf.nn.max_pool3d(
        input,
        ksize=[1, 3, 3, 3, 1],
        strides=[1, 2, 2, 2, 1],
        padding='VALID',
    )
def downsample_maxpool(input, filters):
    """Convolution (+ReLU) followed by stride-2 max-pooling.

    Bug fix: the original pooled `input`, silently discarding the
    convolution result `y`; the pooled tensor must come from `y`.

    Args:
        input: 5-D input tensor — assumed (batch, d, h, w, c), TODO confirm.
        filters: convolution kernel passed through to conv3d.

    Returns:
        Max-pooled, ReLU-activated convolution output.
    """
    y = conv3d(input, filters)
    return tf.nn.max_pool3d(y, ksize=[1, 3, 3, 3, 1], padding='SAME', strides=[1, 2, 2, 2, 1])
def deconv3d(input, filters):
    """Transposed 3-D convolution (upsampling) followed by ReLU.

    The diff residue left two return statements; only the newer annotated
    one is kept.

    NOTE(review): tf.nn.conv3d_transpose normally also requires an
    `output_shape` argument — confirm this call against the TensorFlow
    version in use.
    """
    out = tf.nn.conv3d_transpose(input, filters, strides=[1, 2, 2, 2, 1], padding='SAME')
    return tf.nn.relu(out)  # where bn can be added
def add_variable(var_shape, var_list, var_name=None, initialiser=None):
    """Create a trainable tf.Variable, append it to var_list, and return the list.

    Defaults: Glorot-normal initialisation, and an auto-generated name
    'var<index>' derived from the list's current length.
    """
    init_fn = tf.initializers.glorot_normal() if initialiser is None else initialiser
    name = var_name if var_name is not None else 'var{}'.format(len(var_list))
    var_list.append(tf.Variable(init_fn(var_shape), name=name, trainable=True))
    return var_list
### Define the "weights" - trainable variables
# weights for encoder
num_channels = 32
n0 = num_channels
var_list = []
# One kernel shape per encoder stage (x0..x3), appended in order.
for kernel_shape in (
        [3, 3, 1, n0],       # x0
        [3, 3, n0, n0, 2],   # x1
        [3, 3, n0, n0, 2],   # x2
        [3, 3, n0, n0],      # x3
):
    var_list = add_variable(kernel_shape, var_list)
### Define a model (the 3D U-Net) with residual layers
# ref: https://arxiv.org/abs/1512.03385 & https://arxiv.org/abs/1505.04597
# NOTE(review): this is a diff view mixing removed and added lines, so the
# body below is not runnable as shown; issues are flagged inline.
def model(x):
# Encoder path: downsample stages interleaved with residual blocks.
layer_list = []
layer_list.append(downsample_maxpool(x, var_list[0]))
# NOTE(review): x0 is never assigned — presumably the append above replaced
# `x0 = downsample_maxpool(...)`; confirm against the full file.
x1 = resnet_block(x0, var_list[1])
x2 = resnet_block(x1, var_list[2])
x3 = downsample_maxpool(x2, var_list[3])
x4 = resnet_block(x3, var_list[4])
x5 = resnet_block(x4, var_list[5])
x6 = downsample_maxpool(x5, var_list[6])
x7 = resnet_block(x6, var_list[7])
x8 = resnet_block(x7, var_list[8])
x9 = resnet_block(x8, var_list[9])
# Decoder path: transposed convs with encoder skip connections (U-Net).
x10 = deconv3d(x9, var_list[10])
x11 = resnet_block(x10+x5, var_list[11]) # skip
x12 = resnet_block(x11, var_list[12])
x13 = deconv3d(x12, var_list[13])
# NOTE(review): the three lines below rebind x11/x12/x13 and reuse weights
# var_list[11..13]; the second decoder stage presumably should use fresh
# names (x14..x16) and new weight indices — confirm. No return statement is
# visible; the function appears truncated by the diff.
x11 = resnet_block(x13+x2, var_list[11]) # skip
x12 = resnet_block(x11, var_list[12])
x13 = deconv3d(x12, var_list[13])
'''
def dense(input, weights):
x = tf.nn.relu(tf.matmul(input, weights))
return tf.nn.dropout(x, rate=0.5)
def resnet_block(input, filters1, filters2):
y = tf.nn.conv3d(input, filters1, strides=[1,2,2,2,1], padding='SAME')
y = tf.nn.relu(y) # where bn can be added
y = tf.nn.conv3d(y, filters2, strides=[1,2,2,2,1], padding='SAME')
return tf.nn.relu(y + input)
def up_sampling_deconv(input, filters):
out = tf.nn.conv3d_transpose(input, filters, strides=[1,2,2,2,1], padding='SAME')
return tf.nn.relu(out)
# https://r2rt.com/implementing-batch-normalization-in-tensorflow.html
'''
batch_mean2, batch_var2 = tf.nn.moments(z2_BN,[0])
scale2 = tf.Variable(tf.ones([100]))
beta2 = tf.Variable(tf.zeros([100]))
BN2 = tf.nn.batch_normalization(z2_BN,batch_mean2,batch_var2,beta2,scale2,epsilon)
'''
def batch_norm(inputs, is_training, decay = 0.999):
scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
......@@ -62,11 +105,4 @@ def batch_norm(inputs, is_training, decay = 0.999):
else:
return tf.nn.batch_normalization(inputs,
pop_mean, pop_var, beta, scale, epsilon)
### Define the "weights" - trainable variables
initialiser = tf.initializers.glorot_normal()
### Define a model (the 3D U-Net) with these layers
def model(x):
'''
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment