Commit 65b05d1a authored by Yipeng Hu's avatar Yipeng Hu
Browse files

ref #5 initial commit of the script

parent 344b5191
# This is a tutorial using TensorFlow 2.x, in particular, low-level TF APIs without high-level Keras
import tensorflow as tf
# import tensorflow_addons as tfa
import matplotlib.pyplot as plt
### Define a few functions for network layers
def conv3d(input, filters):
    """Apply a 3D convolution (stride 3, SAME padding) followed by ReLU."""
    return tf.nn.relu(
        tf.nn.conv3d(input, filters, strides=[1, 3, 3, 3, 1], padding='SAME'))
def max_pool(input):
    """Down-sample with 3D max-pooling: 3x3x3 window, stride 2, VALID padding."""
    pooled = tf.nn.max_pool3d(
        input, ksize=[1, 3, 3, 3, 1], strides=[1, 2, 2, 2, 1], padding='VALID')
    return pooled
def deconv3d(input, filters, output_shape=None):
    """Transposed 3D convolution (stride 2, SAME padding) followed by ReLU.

    Fix: ``tf.nn.conv3d_transpose`` requires an ``output_shape`` argument,
    which the original call omitted, so this raised a TypeError when called.
    A backward-compatible ``output_shape=None`` parameter is added; when not
    supplied, the shape is derived from the input: stride-2 SAME upsampling
    doubles every spatial dimension, and the output channel count is taken
    from the filter tensor (shape [d, h, w, out_channels, in_channels]).
    """
    if output_shape is None:
        in_shape = tf.shape(input)
        out_channels = tf.shape(filters)[3]
        output_shape = tf.stack(
            [in_shape[0], in_shape[1] * 2, in_shape[2] * 2, in_shape[3] * 2,
             out_channels])
    out = tf.nn.conv3d_transpose(
        input, filters, output_shape, strides=[1, 2, 2, 2, 1], padding='SAME')
    return tf.nn.relu(out)
def dense(input, weights):
    """Fully-connected layer: matmul + ReLU, then 50% dropout.

    NOTE(review): dropout has no training flag here, so it is applied
    unconditionally — it will also be active at inference time.
    """
    activated = tf.nn.relu(tf.matmul(input, weights))
    return tf.nn.dropout(activated, rate=0.5)
def resnet_block(input, filters1, filters2):
    """Two-convolution residual block with an identity skip connection.

    Fix: the original used stride 2 in both convolutions, so ``y`` ended up
    4x smaller than ``input`` in every spatial dimension and the residual
    addition ``y + input`` could never broadcast.  A residual block with an
    identity shortcut must keep stride 1 throughout (and ``filters1`` /
    ``filters2`` must preserve the channel count of ``input``).
    """
    y = tf.nn.conv3d(input, filters1, strides=[1, 1, 1, 1, 1], padding='SAME')
    y = tf.nn.relu(y)  # where bn can be added
    y = tf.nn.conv3d(y, filters2, strides=[1, 1, 1, 1, 1], padding='SAME')
    # Identity skip connection, then the final non-linearity.
    return tf.nn.relu(y + input)
def up_sampling_deconv(input, filters, output_shape=None):
    """Upsample by 2x via transposed 3D convolution, followed by ReLU.

    Fix: ``tf.nn.conv3d_transpose`` requires an ``output_shape`` argument,
    which the original call omitted, so this raised a TypeError when called.
    A backward-compatible ``output_shape=None`` parameter is added; when not
    supplied, the shape is derived from the input: stride-2 SAME upsampling
    doubles every spatial dimension, and the output channel count is taken
    from the filter tensor (shape [d, h, w, out_channels, in_channels]).
    """
    if output_shape is None:
        in_shape = tf.shape(input)
        out_channels = tf.shape(filters)[3]
        output_shape = tf.stack(
            [in_shape[0], in_shape[1] * 2, in_shape[2] * 2, in_shape[3] * 2,
             out_channels])
    out = tf.nn.conv3d_transpose(
        input, filters, output_shape, strides=[1, 2, 2, 2, 1], padding='SAME')
    return tf.nn.relu(out)
def manual_batch_norm_example(z2_BN, epsilon=1e-3):
    """Worked example of batch normalisation built from raw TF ops.

    Fix: these statements originally ran at module import time and referenced
    names that are never defined in this file (``z2_BN`` and ``epsilon``), so
    importing the module raised NameError.  They are wrapped in a function
    with the missing names as parameters.

    NOTE(review): the hard-coded size 100 assumes ``z2_BN`` has exactly 100
    features in its last dimension — TODO confirm against the caller.
    """
    batch_mean2, batch_var2 = tf.nn.moments(z2_BN, [0])
    scale2 = tf.Variable(tf.ones([100]))
    beta2 = tf.Variable(tf.zeros([100]))
    return tf.nn.batch_normalization(
        z2_BN, batch_mean2, batch_var2, beta2, scale2, epsilon)
def batch_norm(inputs, is_training, decay=0.999, epsilon=1e-3):
    """Batch normalisation with an exponential moving average of statistics.

    During training, normalises with the current batch's mean/variance and
    updates the population statistics with decay ``decay``; at inference,
    normalises with the stored population statistics.

    Fixes:
    - ``epsilon`` was a free (undefined) name; it is now a parameter with a
      conventional default, keeping the original call signature valid.
    - ``tf.assign`` is TF 1.x API and does not exist in TF 2.x (which this
      file targets); replaced with ``tf.Variable.assign``.

    Args:
        inputs: activation tensor; statistics are taken over axis 0.
        is_training: Python bool selecting the training vs. inference path.
        decay: moving-average decay for the population statistics.
        epsilon: small constant added to the variance for numerical stability.
    """
    num_features = inputs.get_shape()[-1]
    scale = tf.Variable(tf.ones([num_features]))
    beta = tf.Variable(tf.zeros([num_features]))
    # Population statistics are updated via the moving average, not by the
    # optimiser, so they are excluded from the trainable variables.
    pop_mean = tf.Variable(tf.zeros([num_features]), trainable=False)
    pop_var = tf.Variable(tf.ones([num_features]), trainable=False)
    if is_training:
        batch_mean, batch_var = tf.nn.moments(inputs, [0])
        train_mean = pop_mean.assign(
            pop_mean * decay + batch_mean * (1 - decay))
        train_var = pop_var.assign(
            pop_var * decay + batch_var * (1 - decay))
        # Make sure the moving-average updates run before returning the
        # normalised output.
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(
                inputs, batch_mean, batch_var, beta, scale, epsilon)
    return tf.nn.batch_normalization(
        inputs, pop_mean, pop_var, beta, scale, epsilon)
### Define the "weights" - trainable variables
initialiser = tf.initializers.glorot_normal()
### Define a model (the 3D U-Net) with these layers
def model(x):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment