Commit 39d7413b authored by Yipeng Hu

ref #5 autograph examples added

parent f8e2933e
@@ -49,7 +49,6 @@ def add_variable(var_shape, var_list, var_name=None, initialiser=None):
### Define a model (the 3D U-Net) with residual layers
### ref: https://arxiv.org/abs/1512.03385 & https://arxiv.org/abs/1505.04597
## define all the trainable weights
num_channels = 32
nc = [num_channels*(2**i) for i in range(4)]
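# with num_channels=32 this gives nc = [32, 64, 128, 256], the number of feature
# channels at each of the four resolution levels of the network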
@@ -91,8 +90,8 @@ var_list = add_variable([3,3,3,nc[0],nc[0],2], var_list)
var_list = add_variable([3,3,3,nc[0],1], var_list)
## model with corresponding layers
@tf.function
def residual_unet(input):
    # initial-layer
    skip_layers = []
    layer = conv3d(input, var_list[0])
@@ -133,7 +132,6 @@ def residual_unet(input):
    # output-layer
    layer = tf.sigmoid(conv3d(layer, var_list[25], activation=False))
    return layer
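# a quick eager-mode sanity check one could run (the 32x64x64 volume size below is
# only an illustrative guess, not the actual PROMISE12 image size):
#   pred = residual_unet(tf.random.uniform([1, 32, 64, 64, 1]))
#   print(pred.shape)  # expected to match the input spatial shape, with 1 sigmoid output channel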
'''
@@ -165,10 +163,6 @@ def batch_norm(inputs, is_training, decay = 0.999):
'''
### training
learning_rate = 1e-5
optimizer = tf.optimizers.Adam(learning_rate)
def loss_crossentropy(pred, target):
    return tf.losses.BinaryCrossentropy()(target, pred)
@@ -177,14 +171,6 @@ def loss_dice(pred, target, eps=1e-6):
    dice_denominator = eps + tf.reduce_sum(pred, axis=[1,2,3,4]) + tf.reduce_sum(target, axis=[1,2,3,4])
    return 1 - tf.reduce_mean(dice_numerator/dice_denominator)
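# per minibatch image, the soft Dice score computed here is
#   dice = dice_numerator / (eps + sum(pred) + sum(target))
# with the sums taken over axes [1,2,3,4]; the (elided) numerator is presumably
# the usual 2*sum(pred*target), and the returned loss is 1 - mean(dice) over the minibatch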
def train_step(model, input, labels):
    with tf.GradientTape() as tape:
        # tape.watch(var_list) is not needed: trainable variables are automatically "watched".
        current_loss = loss_dice(model(input), labels)
    gradients = tape.gradient(current_loss, var_list)
    optimizer.apply_gradients(zip(gradients, var_list))
    print(tf.reduce_mean(current_loss))
### a simple npy image reading class
class DataReader:
@@ -200,6 +186,17 @@ class DataReader:
        images = [np.float32(np.load(os.path.join(self.folder_name, fn))) for fn in file_names]
        return np.expand_dims(np.stack(images, axis=0), axis=4)
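# (the loaders above return arrays of shape [batch, depth, height, width, 1]:
# np.stack adds the leading minibatch axis and expand_dims appends the channel axis)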
### training
@tf.function
def train_step(model, weights, optimizer, x, y):
    with tf.GradientTape() as tape:
        # tape.watch(weights) is not needed: trainable variables are automatically "watched".
        loss = loss_dice(model(x), y)
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss
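# @tf.function traces the Python body of train_step into a TensorFlow graph on its
# first call; the autograph-generated code can be inspected, if desired, with e.g.
#   print(tf.autograph.to_code(train_step.python_function))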
learning_rate = 1e-5
total_iter = int(1e6)
n = 50 # 50 training image-label pairs
size_minibatch = 4
@@ -207,10 +204,9 @@ path_to_data = '../../../promise12'
num_minibatch = int(n/size_minibatch) # how many minibatches in each epoch
indices_train = [i for i in range(n)]
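# e.g. with n=50 and size_minibatch=4, num_minibatch=12, so 48 of the 50 shuffled
# image-label pairs are used in each epoch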
# data reader
DataFeeder = DataReader(path_to_data)
# start the iterations
DataFeeder = DataReader(path_to_data)
optimizer = tf.optimizers.Adam(learning_rate)
for step in range(total_iter):
    # shuffle the data every time a new set of minibatches starts
@@ -224,22 +220,18 @@ for step in range(total_iter):
    input_mb = DataFeeder.load_images_train(indices_mb)[:, ::2, ::2, ::2, :]
    label_mb = DataFeeder.load_labels_train(indices_mb)[:, ::2, ::2, ::2, :]
    # update the variables
    with tf.GradientTape() as tape:
        # tape.watch(var_list) is not needed: trainable variables are automatically "watched".
        loss_train = loss_dice(residual_unet(input_mb), label_mb)
    gradients = tape.gradient(loss_train, var_list)
    optimizer.apply_gradients(zip(gradients, var_list))
    loss_train = train_step(residual_unet, var_list, optimizer, input_mb, label_mb)
    # print training information
    if (step % 100) == 0:
        print('Step %d: training-loss=%f' % (step, loss_train))
    if (step % 1) == 0:
        tf.print('Step', step, ': training-loss=', loss_train)
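    # tf.print prints tensor values directly and, unlike Python print, still executes
    # when called from inside a tf.function-compiled graph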
    # --- simple tests during training ---
    if (step % 1000) == 0:
    if (step % 100) == 0:
        indices_test = [random.randrange(30) for i in range(size_minibatch)]  # select size_minibatch test data
        input_test = DataFeeder.load_images_test(indices_test)[:, ::2, ::2, ::2, :]
        pred_test = residual_unet(input_test)
        # save the segmentation
        for idx in range(size_minibatch):
            np.save("./label_test%02d_step%06d.npy" % (indices_test[idx], step), pred_test[idx, ...])
        print('Test results saved.')
        tf.print('Test results saved.')
\ No newline at end of file