Commit 12f23d37 authored by Yipeng Hu

ref #5 running training steps

parent ad2f9cb5
@@ -31,8 +31,8 @@ def downsample_maxpool(input, filters):
     y = conv3d(input, filters)
     return tf.nn.max_pool3d(y, ksize=[1,3,3,3,1], padding='SAME', strides=[1,2,2,2,1])
-def deconv3d(input, filters, batch_norm=False):
-    y = tf.nn.conv3d_transpose(input, filters, strides=[1,2,2,2,1], padding='SAME')
+def deconv3d(input, filters, out_shape, batch_norm=False):
+    y = tf.nn.conv3d_transpose(input, filters, output_shape=out_shape, strides=[1,2,2,2,1], padding='SAME')
     if batch_norm: y = batch_norm(y)
     return tf.nn.relu(y)  # where bn can be added
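
Note on this hunk: unlike tf.nn.conv3d, tf.nn.conv3d_transpose cannot infer the spatial size of its result (several input sizes map to the same strided-convolution output), so the caller must pass the target shape explicitly; the decoder hunk below passes the matching encoder skip layer's shape. A minimal shape-only sketch, with made-up sizes:

import tensorflow as tf

low_res = tf.zeros([4, 8, 8, 8, 64])   # decoder feature map [batch, D, H, W, C]
skip = tf.zeros([4, 16, 16, 16, 32])   # encoder feature map to be matched
# filter layout for conv3d_transpose: [depth, height, width, out_channels, in_channels]
kernel = tf.Variable(tf.random.normal([3, 3, 3, 32, 64], stddev=0.05))
up = tf.nn.conv3d_transpose(low_res, kernel, output_shape=skip.shape,
                            strides=[1, 2, 2, 2, 1], padding='SAME')
print(up.shape)  # (4, 16, 16, 16, 32) -- now up + skip is a valid residual add
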
@@ -56,40 +56,40 @@ num_channels = 32
 nc = [num_channels*(2**i) for i in range(4)]
 var_list=[]
 # initial-layer
-var_list = add_variable([5,5,1,nc[0]], var_list)
+var_list = add_variable([5,5,5,1,nc[0]], var_list)
 # encoder-s0
-var_list = add_variable([3,3,nc[0],nc[0]], var_list)
-var_list = add_variable([3,3,nc[0],nc[0],2], var_list)
-var_list = add_variable([3,3,nc[0],nc[0]], var_list)
-var_list = add_variable([3,3,nc[0],nc[1]], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[0],2], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[0],2], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[0]], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[1]], var_list)
 # encoder-s1
-var_list = add_variable([3,3,nc[1],nc[1],2], var_list)
-var_list = add_variable([3,3,nc[1],nc[1],2], var_list)
-var_list = add_variable([3,3,nc[1],nc[1]], var_list)
-var_list = add_variable([3,3,nc[1],nc[2]], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[1],2], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[1],2], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[1]], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[2]], var_list)
 # encoder-s2
-var_list = add_variable([3,3,nc[2],nc[2],2], var_list)
-var_list = add_variable([3,3,nc[2],nc[2],2], var_list)
-var_list = add_variable([3,3,nc[2],nc[2]], var_list)
-var_list = add_variable([3,3,nc[2],nc[3]], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[2],2], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[2],2], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[2]], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[3]], var_list)
 # deep-layers-s3
-var_list = add_variable([3,3,nc[3],nc[3]], var_list)
-var_list = add_variable([3,3,nc[3],nc[3],2], var_list)
-var_list = add_variable([3,3,nc[3],nc[3]], var_list)
+var_list = add_variable([3,3,3,nc[3],nc[3],2], var_list)
+var_list = add_variable([3,3,3,nc[3],nc[3],2], var_list)
+var_list = add_variable([3,3,3,nc[3],nc[3],2], var_list)
 # decoder-s2
-var_list = add_variable([3,3,nc[2],nc[3]], var_list)
-var_list = add_variable([3,3,nc[2],nc[2],2], var_list)
-var_list = add_variable([3,3,nc[2],nc[2],2], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[3]], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[2],2], var_list)
+var_list = add_variable([3,3,3,nc[2],nc[2],2], var_list)
 # decoder-s1
-var_list = add_variable([3,3,nc[1],nc[2]], var_list)
-var_list = add_variable([3,3,nc[1],nc[1],2], var_list)
-var_list = add_variable([3,3,nc[1],nc[1],2], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[2]], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[1],2], var_list)
+var_list = add_variable([3,3,3,nc[1],nc[1],2], var_list)
 # decoder-s0
-var_list = add_variable([3,3,nc[0],nc[1]], var_list)
-var_list = add_variable([3,3,nc[0],nc[0],2], var_list)
-var_list = add_variable([3,3,nc[0],nc[0],2], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[1]], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[0],2], var_list)
+var_list = add_variable([3,3,3,nc[0],nc[0],2], var_list)
 # output-layer
-var_list = add_variable([3,3,nc[0],1], var_list)
+var_list = add_variable([3,3,3,nc[0],1], var_list)
 ## model with corresponding layers
 def residual_unet(input):
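
Note on this hunk: the change inserts the extra depth dimension so every kernel is a 3D [d,h,w,in,out] shape. add_variable itself is defined outside the diff; from its call sites it evidently creates one trainable variable of the given shape and appends it to the running list, with a trailing 2 apparently packing the two kernels of a two-convolution residual block into one variable. A plausible sketch (the initializer is an assumption, not the repository's code):

def add_variable(var_shape, var_list):
    # one trainable kernel (or kernel pair, when var_shape ends in 2) per call
    var_list.append(tf.Variable(tf.random.normal(var_shape, stddev=0.05), trainable=True))
    return var_list
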
@@ -99,36 +99,36 @@ def residual_unet(input):
     layer = conv3d(input, var_list[0])
     # encoder-s0
     layer = resnet_block(layer, var_list[1])
-    layer = resnet_block(layer, var_list[2])
-    layer = downsample_maxpool(layer, var_list[3])
+    layer = resnet_block(layer, var_list[2])
+    skip_layers.append(layer)
+    layer = downsample_maxpool(layer, var_list[3])
     layer = conv3d(layer, var_list[4])
     # encoder-s1
     layer = resnet_block(layer, var_list[5])
-    layer = resnet_block(layer, var_list[6])
-    layer = downsample_maxpool(layer, var_list[7])
+    layer = resnet_block(layer, var_list[6])
+    skip_layers.append(layer)
+    layer = downsample_maxpool(layer, var_list[7])
     layer = conv3d(layer, var_list[8])
     # encoder-s2
     layer = resnet_block(layer, var_list[9])
-    layer = resnet_block(layer, var_list[10])
-    layer = downsample_maxpool(layer, var_list[11])
+    layer = resnet_block(layer, var_list[10])
+    skip_layers.append(layer)
+    layer = downsample_maxpool(layer, var_list[11])
     layer = conv3d(layer, var_list[12])
     # deep-layers-s3
     layer = resnet_block(layer, var_list[13])
     layer = resnet_block(layer, var_list[14])
     layer = resnet_block(layer, var_list[15])
     # decoder-s2
-    layer = deconv3d(layer, var_list[16]) + skip_layers[2]
+    layer = deconv3d(layer, var_list[16], skip_layers[2].shape) + skip_layers[2]
     layer = resnet_block(layer, var_list[17])
     layer = resnet_block(layer, var_list[18])
     # decoder-s1
-    layer = deconv3d(layer, var_list[19]) + skip_layers[1]
+    layer = deconv3d(layer, var_list[19], skip_layers[1].shape) + skip_layers[1]
     layer = resnet_block(layer, var_list[20])
     layer = resnet_block(layer, var_list[21])
     # decoder-s0
-    layer = deconv3d(layer, var_list[22]) + skip_layers[0]
+    layer = deconv3d(layer, var_list[22], skip_layers[0].shape) + skip_layers[0]
     layer = resnet_block(layer, var_list[23])
     layer = resnet_block(layer, var_list[24])
     # output-layer
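
Note on this hunk: each encoder stage now saves its pre-pooling feature map in skip_layers, and the decoder consumes them in reverse order (s2, s1, s0) as residual additions. resnet_block and conv3d are defined outside the diff; one shape-consistent possibility (an assumption, not the repository's code) is a pair of convolutions indexed from the packed kernel variable, with an identity shortcut:

def resnet_block(input, filters, batch_norm=False):
    # filters[..., 0] and filters[..., 1] are assumed to be the two stacked
    # 3x3x3 kernels created by the [3,3,3,c,c,2]-shaped variables above
    y = conv3d(input, filters[..., 0])
    y = conv3d(y, filters[..., 1])
    return tf.nn.relu(y + input)  # identity shortcut keeps channel count unchanged
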
@@ -149,7 +149,6 @@ scale2 = tf.Variable(tf.ones([100]))
 beta2 = tf.Variable(tf.zeros([100]))
-BN2 = tf.nn.batch_normalization(z2_BN,batch_mean2,batch_var2,beta2,scale2,epsilon)
 def batch_norm(inputs, is_training, decay = 0.999):
     scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=True)
     beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=True)
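
Note on this hunk: the commit drops a leftover line from the worked batch-norm example (z2_BN/scale2/beta2) above the real function. The body of batch_norm is truncated in this view; its signature (is_training, decay) matches the classic moving-average formulation, so the likely remainder would be something like this sketch (hedged as such):

def batch_norm(inputs, is_training, decay=0.999, epsilon=1e-3):
    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=True)
    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=True)
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)
    if is_training:
        # per-channel statistics over the batch and spatial axes of an NDHWC tensor
        batch_mean, batch_var = tf.nn.moments(inputs, axes=[0, 1, 2, 3])
        pop_mean.assign(pop_mean*decay + batch_mean*(1 - decay))
        pop_var.assign(pop_var*decay + batch_var*(1 - decay))
        return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
    else:
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)
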
@@ -180,10 +179,10 @@ def loss_dice(pred, target, eps=1e-6):
     return 1 - tf.reduce_mean(dice_numerator/dice_denominator)
 def train_step(model, input, labels):
-    with tf.GradientTape() as g_tape:
+    with tf.GradientTape() as tape:
         # g_tape.watched(var_list): trainable variables are automatically "watched".
         current_loss = loss_dice(model(input), labels)
-    gradients = g_tape.gradient(current_loss, var_list)
+    gradients = tape.gradient(current_loss, var_list)
     optimizer.apply_gradients(zip(gradients, var_list))
     print(tf.reduce_mean(current_loss))
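
Note on this hunk: only the return line of loss_dice is visible here. The usual soft-Dice form it implies, with the axis choices and eps placement being assumptions:

def loss_dice(pred, target, eps=1e-6):
    # soft Dice per volume: 2|X n Y| / (|X| + |Y|), summed over all voxel axes
    dice_numerator = 2.0 * tf.reduce_sum(pred*target, axis=[1, 2, 3, 4])
    dice_denominator = tf.reduce_sum(pred, axis=[1, 2, 3, 4]) + \
                       tf.reduce_sum(target, axis=[1, 2, 3, 4]) + eps
    return 1 - tf.reduce_mean(dice_numerator/dice_denominator)

The optimizer consumed by apply_gradients is likewise defined elsewhere in the script; something like optimizer = tf.optimizers.Adam(1e-4) (an assumption) would fit these calls.
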
@@ -205,7 +204,7 @@ class DataReader:
 total_iter = int(1e6)
 n = 50 # 50 training image-label pairs
 size_minibatch = 4
-path_to_data = './promise12'
+path_to_data = '../../../promise12'
 num_minibatch = int(n/size_minibatch) # how many minibatches in each epoch
 indices_train = [i for i in range(n)]
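
A quick check of the minibatch arithmetic: with n = 50 and size_minibatch = 4, num_minibatch = int(50/4) = 12, so one epoch visits only the first 48 of the 50 pairs; the last two are skipped unless indices_train is reshuffled between epochs. The slicing in the loop below then behaves as:

size_minibatch, num_minibatch = 4, 12
indices_train = list(range(50))
for minibatch_idx in range(2):  # first two of the 12 minibatches
    print(indices_train[minibatch_idx*size_minibatch:(minibatch_idx+1)*size_minibatch])
# [0, 1, 2, 3]
# [4, 5, 6, 7]
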
@@ -223,7 +222,18 @@ for step in range(total_iter):
     minibatch_idx = step % num_minibatch # minibatch index
     indices_mb = indices_train[minibatch_idx*size_minibatch:(minibatch_idx+1)*size_minibatch]
     # update the variables
-    train_step(residual_unet, DataFeeder.load_images_train(indices_mb) , DataFeeder.load_labels_train(indices_mb) )
+    input_mb = DataFeeder.load_images_train(indices_mb)
+    label_mb = DataFeeder.load_labels_train(indices_mb)
+    with tf.GradientTape() as tape:
+        # tape.watched(var_list): trainable variables are automatically "watched".
+        current_loss = loss_dice(residual_unet(input_mb), label_mb)
+    gradients = tape.gradient(current_loss, var_list)
+    optimizer.apply_gradients(zip(gradients, var_list))
+    print(tf.reduce_mean(current_loss))
+    '''
+    # train_step(residual_unet, DataFeeder.load_images_train(indices_mb) , DataFeeder.load_labels_train(indices_mb) )
     # print training information
     if (step % 10) == 0:
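
Note on this hunk: the gradient step is inlined into the loop (the former train_step call is kept as a comment), and the added triple-quoted string disables everything from here to the end of the file. DataFeeder is presumably an instance of the DataReader class named in the hunk header; its definition is outside the diff. A hypothetical reader matching the two calls above — the .npy naming scheme here is invented purely for illustration:

import numpy as np

class DataReader:
    def __init__(self, path_to_data):
        self.path = path_to_data
    def load_images_train(self, indices):
        # stack the requested volumes along a new batch axis
        return np.stack([np.load('%s/image_train%02d.npy' % (self.path, i)) for i in indices])
    def load_labels_train(self, indices):
        return np.stack([np.load('%s/label_train%02d.npy' % (self.path, i)) for i in indices])

DataFeeder = DataReader(path_to_data)
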
@@ -243,4 +253,4 @@ for step in range(total_iter):
     for idx in range(size_minibatch):
         np.save("./label_test%02d_step%06d.npy" % (indices_test[idx], step), layer1d_test[idx, ...])
-    print('Test results saved.')
\ No newline at end of file
+    print('Test results saved.')
+    '''
\ No newline at end of file