Commit c722f7dc authored by Yipeng Hu

Merge branch 'master' of weisslab.cs.ucl.ac.uk:WEISS/machine-learning-journal-club

parents 004dd249 cdb89279
% script_prepData
if ispc
homeFolder = getenv('USERPROFILE');
elseif isunix
homeFolder = getenv('HOME');
end
normFolder = fullfile(homeFolder, 'Scratch/data/protocol/normalised');
mkdir(normFolder);
dataFolder = fullfile(homeFolder, 'Scratch/data/protocol/SPE_data_classes');
ClassNames = {'1_skull'; '2_abdomen'; '3_heart'; '4_other'};
ClassFolders = cellfun(@(x)fullfile(dataFolder,x),ClassNames,'UniformOutput',false);
%% go through all raw data and obtain the frame_info
case_ids = {};
frame_info = {};
frame_counters = zeros(length(ClassFolders),1);
idx_frame_1 = 0;
for idx_class_1 = 1:length(ClassFolders)
    frame_names = dir(ClassFolders{idx_class_1});
    for j = 3:length(frame_names)  % skip '.' and '..'
        fname = frame_names(j).name;
        % skip problematic files that cannot be read
        try
            % debug: fprintf('reading No.%d - [%s]\n',j,fname)
            img = imread(fullfile(frame_names(j).folder,fname)); % figure, imshow(img,[])
        catch
            disp(fullfile(frame_names(j).folder,fname))
            continue
        end
        % deal with the different date formats here
        date_del = strfind(fname,'-');
        if length(date_del)>=6 % case where the date delimiter is repeated
            start0 = date_del(6)+1;
        elseif length(date_del)>=3
            start0 = date_del(3)+1;
        else
            start0 = 1;
        end
        ext0 = regexpi(fname,'\.png');
        newstr = strrep(fname(start0:ext0-1),'fr_','');
        % additional filename format check
        udls = strfind(newstr,'_');
        if length(udls) ~= 2
            warning('Incorrect filename format! %s',fname);
        end
        id = newstr(1:udls(2)-1);
        fr = str2double(newstr(udls(2)+1:end));
        [~, idx_case_1] = ismember(id, case_ids);
        if idx_case_1==0 % new case id - append to the list
            idx_case_1 = length(case_ids)+1;
            case_ids{idx_case_1} = id;
        end
        idx_frame_1 = idx_frame_1+1;
        idx_frame = idx_frame_1 - 1;
        frame_info(idx_frame_1).filename = fname;
        frame_info(idx_frame_1).case_name = id;
        frame_info(idx_frame_1).case_idx = idx_case_1 - 1;  % 0-based
        frame_info(idx_frame_1).class_name = ClassNames{idx_class_1};
        frame_info(idx_frame_1).class_idx = idx_class_1 - 1;  % 0-based
    end
end
save(fullfile(normFolder,'frame_info'),'frame_info','dataFolder');
%% now write into files
% specify the folders
load(fullfile(normFolder,'frame_info')); % reload frame_info and dataFolder
roi_crop = [47,230,33,288]; % [xmin,xmax,ymin,ymax]
frame_size = [roi_crop(2)-roi_crop(1)+1, roi_crop(4)-roi_crop(3)+1];
indices_class = [frame_info(:).class_idx];
num_classes = length(unique(indices_class));
indices_subject = [frame_info(:).case_idx];
num_subjects = length(unique(indices_subject));
% by subject
MAX_num_frames = 50;
RESIZE_scale = 2;
frame_size = round(frame_size/RESIZE_scale);
h5fn_subjects = fullfile(normFolder,'ultrasound_50frames.h5'); delete(h5fn_subjects);
% write global information
GroupName = '/frame_size';
h5create(h5fn_subjects,GroupName,size(frame_size),'DataType','uint32');
h5write(h5fn_subjects,GroupName,uint32(frame_size));
GroupName = '/num_classes';
h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
h5write(h5fn_subjects,GroupName,uint32(num_classes));
GroupName = '/num_subjects';
h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
h5write(h5fn_subjects,GroupName,uint32(num_subjects));
for idx_subject = 0:num_subjects-1 % 0-based indexing
    indices_frame_1_subject = find(indices_subject==idx_subject);
    num_frames_subject = length(indices_frame_1_subject);
    if num_frames_subject>MAX_num_frames
        indices_frame_1_subject = randsample(indices_frame_1_subject,MAX_num_frames);
        num_frames_subject = MAX_num_frames;
    end
    for idx_frame_subject = 0:num_frames_subject-1
        idx_frame = indices_frame_1_subject(idx_frame_subject+1);
        filename = fullfile(dataFolder, frame_info(idx_frame).class_name, frame_info(idx_frame).filename);
        img = imread(filename);
        img = imresize(img(roi_crop(1):roi_crop(2),roi_crop(3):roi_crop(4)),frame_size);
        GroupName = sprintf('/subject%06d_frame%08d',idx_subject,idx_frame_subject);
        h5create(h5fn_subjects,GroupName,size(img),'DataType','uint8');
        h5write(h5fn_subjects,GroupName,img);
        GroupName = sprintf('/subject%06d_label%08d',idx_subject,idx_frame_subject);
        h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
        h5write(h5fn_subjects,GroupName,uint32(indices_class(idx_frame)));
    end
    GroupName = sprintf('/subject%06d_num_frames',idx_subject);
    h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
    h5write(h5fn_subjects,GroupName,uint32(num_frames_subject));
end
%% obsolete
% % by subject
% h5fn_subjects = fullfile(normFolder,'protocol_sweep_class_subjects.h5'); delete(h5fn_subjects);
% num_frames_per_subject = zeros(1,num_subjects,'uint32');
% for idx_subject = (1:num_subjects)-1 % 0-based indexing
% frame_subject = 0;
% indices_frame_1_subject = find(indices_subject==idx_subject);
% num_frames_per_subject(idx_subject+1) = length(indices_frame_1_subject);
% for idx_frame_1 = indices_frame_1_subject
% filename = fullfile(dataFolder,frame_info(idx_frame_1).class_name,frame_info(idx_frame_1).filename);
% img = imread(filename);
% img = img(roi_crop(1):roi_crop(2),roi_crop(3):roi_crop(4));
% GroupName = sprintf('/subject%06d_frame%08d',idx_subject,frame_subject);
% frame_subject = frame_subject+1;
% h5create(h5fn_subjects,GroupName,size(img),'DataType','uint8');
% h5write(h5fn_subjects,GroupName,img);
% end
% GroupName = sprintf('/subject%06d_class',idx_subject);
% h5create(h5fn_subjects,GroupName,size(indices_frame_1_subject),'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(indices_class(indices_frame_1_subject)));
% end
% % extra info
% GroupName = '/num_frames_per_subject';
% h5create(h5fn_subjects,GroupName,size(num_frames_per_subject),'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(num_frames_per_subject));
% GroupName = '/frame_size';
% h5create(h5fn_subjects,GroupName,size(frame_size),'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(frame_size));
% GroupName = '/num_classes';
% h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(num_classes));
% GroupName = '/num_subjects';
% h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(num_subjects));
% % by frames
% h5fn_frames = fullfile(normFolder,'protocol_sweep_class_frames.h5'); delete(h5fn_frames);
% for idx_frame_1 = 1:length(frame_info)
% %% now read in image
% filename = fullfile(dataFolder,frame_info(idx_frame_1).class_name,frame_info(idx_frame_1).filename);
% img = imread(filename);
% img = img(roi_crop(1):roi_crop(2),roi_crop(3):roi_crop(4));
% % figure, imshow(img,[])
% GroupName = sprintf('/frame%08d',idx_frame_1-1);
% h5create(h5fn_frames,GroupName,size(img),'DataType','uint8');
% h5write(h5fn_frames,GroupName,img);
% end
% GroupName = '/class';
% h5create(h5fn_frames,GroupName,size(indices_class),'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(indices_class));
% GroupName = '/subject';
% h5create(h5fn_frames,GroupName,size(indices_subject),'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(indices_subject));
% % extra info
% GroupName = '/frame_size';
% h5create(h5fn_frames,GroupName,size(frame_size),'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(frame_size));
% GroupName = '/num_classes';
% h5create(h5fn_frames,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(num_classes));
% GroupName = '/num_subjects';
% h5create(h5fn_frames,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(num_subjects));
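# (Example) A minimal sketch for checking the layout of the HDF5 file written by the
# MATLAB script above, before it is read by the TensorFlow scripts below. It assumes
# h5py is installed; the path follows the later scripts and should be adjusted to
# wherever the file was actually written. Note that the frames appear transposed when
# read from Python, which is why the scripts below call tf.transpose on each frame.
import h5py

filename = '../../../datasets/ultrasound_50frames.h5'
with h5py.File(filename, 'r') as f:
    print('frame_size:', f['/frame_size'][()].flatten())
    print('num_classes:', int(f['/num_classes'][0][0]))
    print('num_subjects:', int(f['/num_subjects'][0][0]))
    # each subject has /subjectXXXXXX_num_frames plus per-frame image and label datasets
    print('subject 0 frames:', int(f['/subject%06d_num_frames' % 0][0][0]))
    print('first frame shape:', f['/subject%06d_frame%08d' % (0, 0)].shape)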
import random
import tensorflow as tf
from matplotlib import pyplot as plt

nSbj = 6
nFrm = 8
filename = '../../../datasets/ultrasound_50frames.h5'

# randomly select nSbj subjects and nFrm frames per subject
num_subjects = tf.keras.utils.HDF5Matrix(filename, '/num_subjects').data.value[0][0]
idx_subject = random.sample(range(num_subjects), nSbj)
plt.figure()
for iSbj in range(nSbj):
    dataset = '/subject%06d_num_frames' % (idx_subject[iSbj])
    num_frames = tf.keras.utils.HDF5Matrix(filename, dataset)[0][0]
    idx_frame = random.sample(range(num_frames), nFrm)
    for iFrm in range(nFrm):
        dataset = '/subject%06d_frame%08d' % (idx_subject[iSbj], idx_frame[iFrm])
        frame = tf.transpose(tf.keras.utils.HDF5Matrix(filename, dataset))
        dataset = '/subject%06d_label%08d' % (idx_subject[iSbj], idx_frame[iFrm])
        label = tf.keras.utils.HDF5Matrix(filename, dataset)[0][0]
        axs = plt.subplot(nSbj, nFrm, iSbj*nFrm+iFrm+1)
        axs.set_title('S{}, F{}, C{}'.format(idx_subject[iSbj], idx_frame[iFrm], label))
        axs.imshow(frame, cmap='gray')
        axs.axis('off')
plt.show()
import tensorflow as tf
import random
# import numpy as np
filename = '../../../datasets/ultrasound_50frames.h5'
frame_size = tf.keras.utils.HDF5Matrix(filename, '/frame_size').data.value
frame_size = [frame_size[0][0],frame_size[1][0]]
num_classes = tf.keras.utils.HDF5Matrix(filename, '/num_classes').data.value[0][0]
# placeholder for the input image frames
features_input = tf.keras.Input(shape=frame_size+[1])
features = tf.keras.layers.Conv2D(32, 7, activation='relu')(features_input)
features = tf.keras.layers.MaxPool2D(3)(features)
features_block_1 = tf.keras.layers.Conv2D(64, 3, activation='relu')(features)
features = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(features_block_1)
features = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(features)
features_block_2 = features + features_block_1
features = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(features_block_2)
features = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(features)
features = features + features_block_2
features = tf.keras.layers.MaxPool2D(3)(features)
features_block_3 = tf.keras.layers.Conv2D(128, 3, activation='relu')(features)
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features_block_3)
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features)
features_block_4 = features + features_block_3
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features_block_4)
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features)
features_block_5 = features + features_block_4
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features_block_5)
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features)
features_block_6 = features + features_block_5
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features_block_6)
features = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(features)
features = features + features_block_6
features = tf.keras.layers.Conv2D(128, 3, activation='relu')(features)
features = tf.keras.layers.GlobalAveragePooling2D()(features)
features = tf.keras.layers.Dense(units=256, activation='relu')(features)
features = tf.keras.layers.Dropout(0.5)(features)
logits_output = tf.keras.layers.Dense(units=num_classes, activation='softmax')(features)
# now the model
model = tf.keras.Model(inputs=features_input, outputs=logits_output)
model.summary()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss='sparse_categorical_crossentropy',
              metrics=['SparseCategoricalAccuracy'])
# now get the data using a generator
num_subjects = tf.keras.utils.HDF5Matrix(filename, '/num_subjects').data.value[0][0]
subject_indices = range(num_subjects)
num_frames_per_subject = 1
def data_generator():
    for iSbj in subject_indices:
        dataset = '/subject%06d_num_frames' % iSbj
        num_frames = tf.keras.utils.HDF5Matrix(filename, dataset)[0][0]
        idx_frame = random.sample(range(num_frames), num_frames_per_subject)[0]
        dataset = '/subject%06d_frame%08d' % (iSbj, idx_frame)
        frame = tf.transpose(tf.keras.utils.HDF5Matrix(filename, dataset)) / 255
        dataset = '/subject%06d_label%08d' % (iSbj, idx_frame)
        label = tf.keras.utils.HDF5Matrix(filename, dataset)[0][0]
        yield (tf.expand_dims(frame, axis=2), label)

dataset = tf.data.Dataset.from_generator(generator=data_generator,
                                         output_types=(tf.float32, tf.int32),
                                         output_shapes=(frame_size+[1], ()))
# training
dataset_batch = dataset.shuffle(buffer_size=1024).batch(num_subjects)
frame_train, label_train = next(iter(dataset_batch))
model.fit(frame_train, label_train, epochs=int(1e3), validation_split=0.2)
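# (Example) An alternative sketch: instead of materialising a single in-memory batch as
# above, the tf.data.Dataset can be passed to model.fit directly. The batch size of 16
# and the epoch count are illustrative assumptions, not values from the original script;
# validation would then need a separate dataset rather than validation_split.
dataset_train = dataset.shuffle(buffer_size=1024).batch(16).prefetch(tf.data.experimental.AUTOTUNE)
model.fit(dataset_train, epochs=100)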
import tensorflow as tf
### build a data pipeline
# https://www.tensorflow.org/guide/data
# https://www.tensorflow.org/guide/data_performance
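# A minimal sketch of the kind of pipeline the two guides above describe; the arrays
# frames and labels are placeholders (with an illustrative shape) standing in for data
# loaded e.g. from the HDF5 file used elsewhere in this repository.
import numpy as np

frames = np.zeros([100, 92, 128, 1], dtype=np.float32)     # placeholder image stack
labels = np.zeros([100], dtype=np.int32)                    # placeholder class labels
dataset = tf.data.Dataset.from_tensor_slices((frames, labels))
dataset = dataset.shuffle(buffer_size=1024)                 # randomise order each epoch
dataset = dataset.batch(32)                                 # group into mini-batches
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)   # overlap loading and training
for frame_batch, label_batch in dataset:
    pass  # consume batches here, e.g. inside a training loop or via model.fit(dataset)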
# (Provided)
# *** Available as part of UCL MPHY0025 (Information Processing in Medical Imaging) Assessed Coursework 2018-19 ***
# *** This code is with an Apache 2.0 license, University College London ***
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('default')
def dispImage(img, int_lims=[], ax=None):
    """
    function to display a grey-scale image that is stored in 'standard
    orientation' with y-axis on the 2nd dimension and 0 at the bottom

    INPUTS:   img: image to be displayed
              int_lims: the intensity limits to use when displaying the
                  image, int_lims[0] = min intensity to display, int_lims[1]
                  = max intensity to display [default min and max intensity
                  of image]
              ax: if displaying an image on a subplot grid or on top of a
                  second image, optionally supply the axis on which to display
                  the image.
    OUTPUTS:  ax: the axis object after plotting if an axis object is
                  supplied
    """
    # check if intensity limits have been provided, and if not set to min and
    # max of image
    if not int_lims:
        int_lims = [np.nanmin(img), np.nanmax(img)]
        # check if min and max are same (i.e. all values in img are equal)
        if int_lims[0] == int_lims[1]:
            int_lims[0] -= 1
            int_lims[1] += 1
    # take transpose of image to switch x and y dimensions and display with
    # first pixel having coordinates 0,0
    img = img.T
    if not ax:
        plt.imshow(img, cmap='gray', vmin=int_lims[0], vmax=int_lims[1],
                   origin='lower')
    else:
        ax.imshow(img, cmap='gray', vmin=int_lims[0], vmax=int_lims[1],
                  origin='lower')
    # set axis to be scaled equally (assumes isotropic pixel dimensions), tight
    # around the image
    plt.axis('image')
    plt.tight_layout()
    return ax
def dispBinaryImage(binImg, cmap='Greens_r', ax=None):
    """
    function to display a binary image that is stored in 'standard
    orientation' with y-axis on the 2nd dimension and 0 at the bottom

    INPUTS:   binImg: binary image to be displayed
              ax: if displaying an image on a subplot grid or on top of a
                  second image, optionally supply the axis on which to display
                  the image.
                  E.g.
                      fig = plt.figure()
                      ax = fig.gca()
                      ax = dispImage(ct_image, ax=ax)
                      ax = dispBinaryImage(label_image, ax=ax)
              cmap: color map of the binary image to be displayed
                  (see: https://matplotlib.org/examples/color/colormaps_reference.html)
    OUTPUTS:  ax: the axis object after plotting if an axis object is
                  supplied
    """
    # take transpose of image to switch x and y dimensions and display with
    # first pixel having coordinates 0,0
    binImg = binImg.T
    # set the background pixels to NaNs so that imshow will display them as
    # transparent
    binImg = np.where(binImg == 0, np.nan, binImg)
    if not ax:
        plt.imshow(binImg, cmap=cmap, origin='lower')
    else:
        ax.imshow(binImg, cmap=cmap, origin='lower')
    # set axis to be scaled equally (assumes isotropic pixel dimensions), tight
    # around the image
    plt.axis('image')
    plt.tight_layout()
    return ax
def dispImageAndBinaryOverlays(img, bin_imgs=[], bin_cols=[], int_lims=[], ax=None):
    """
    function to display a grey-scale image with one or more binary images
    overlaid

    INPUTS:   img: image to be displayed
              bin_imgs: a list or np.array containing one or more binary images.
                  must have same dimensions as img
              bin_cols: a list or np.array containing the matplotlib colormaps
                  to use for each binary image, e.g. 'Greens_r', 'Reds_r'.
                  Must have one colormap for each binary image
              int_lims: the intensity limits to use when displaying the
                  image, int_lims[0] = min intensity to display, int_lims[1]
                  = max intensity to display [default min and max intensity
                  of image]
              ax: if displaying an image on a subplot grid or on top of a
                  second image, optionally supply the axis on which to display
                  the image.
    OUTPUTS:  ax: the axis object after plotting if an axis object is
                  supplied
    """
    # check if intensity limits have been provided, and if not set to min and
    # max of image
    if not int_lims:
        int_lims = [np.nanmin(img), np.nanmax(img)]
        # check if min and max are same (i.e. all values in img are equal)
        if int_lims[0] == int_lims[1]:
            int_lims[0] -= 1
            int_lims[1] += 1
    # take transpose of image to switch x and y dimensions and display with
    # first pixel having coordinates 0,0
    img = img.T
    if not ax:
        fig = plt.figure()
        ax = fig.gca()
    ax.imshow(img, cmap='gray', vmin=int_lims[0], vmax=int_lims[1],
              origin='lower')
    for idx, binImg in enumerate(bin_imgs):
        binImg = binImg.T
        # check the binary images and img are the same shape
        if binImg.shape != img.shape:
            print('Error: binary image {} does not have same dimensions as image'.format(idx))
            break
        # set the colormap from bin_cols unless not enough colors have been provided
        try:
            cmap = bin_cols[idx]
        except IndexError:
            cmap = 'Greens_r'
            print('WARNING: not enough colormaps provided - defaulting to Green')
        ax.imshow(np.where(binImg == 0, np.nan, binImg), cmap=cmap,
                  origin='lower')
    # set axis to be scaled equally (assumes isotropic pixel dimensions), tight
    # around the image
    plt.axis('image')
    plt.tight_layout()
    return ax
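# (Example) A short usage sketch for the helpers above; image_2d and mask_2d are
# hypothetical arrays standing in for a grey-scale image and a binary label of the same
# shape, both in the 'standard orientation' described in the docstrings.
if __name__ == '__main__':
    image_2d = np.random.rand(64, 64)
    mask_2d = (image_2d > 0.8).astype(float)
    ax = dispImageAndBinaryOverlays(image_2d, bin_imgs=[mask_2d], bin_cols=['Reds_r'])
    plt.show()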
# (Provided) This uses TensorFlow
# *** Available as part of UCL MPHY0025 (Information Processing in Medical Imaging) Assessed Coursework 2018-19 ***
# *** This code is with an Apache 2.0 license, University College London ***
import tensorflow as tf
import networks
import numpy as np
from matplotlib.pyplot import imread
# 1 - Read images and convert to "standard orientation"
files_image_test = ['../data/test/433.png', '../data/test/441.png']
images = np.stack([imread(fn)[::-1, ...].T for fn in files_image_test], axis=0)
# Normalise the test images so they have zero-mean and unit-variance
images = (images-images.mean(axis=(1, 2), keepdims=True)) / images.std(axis=(1, 2), keepdims=True)
image_size = [images.shape[1], images.shape[2]]
# 2 - Load one of the provided trained networks
model_dir = '../trained/ssd_1e1/'
# model_dir = '../trained/ssd_1e0/'
# model_dir = '../trained/ssd_1e-1/'
# model_dir = '../trained/ssd_1e-2/'
file_model_save = model_dir+'model_saved'
# Restore the computation graph
ph_moving_image = tf.placeholder(tf.float32, [1]+image_size)
ph_fixed_image = tf.placeholder(tf.float32, [1]+image_size)
input_moving_image = tf.expand_dims(ph_moving_image, axis=3)
input_fixed_image = tf.expand_dims(ph_fixed_image, axis=3)
reg_net = networks.RegNet2D(minibatch_size=1, image_moving=input_moving_image, image_fixed=input_fixed_image)
# Reinstate the trained weights (stored in the network model file)
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, file_model_save)
# 5 - Predict the displacement field
testFeed = {ph_moving_image: images[[0], ...], ph_fixed_image: images[[1], ...]}
ddfs, resampling_grids = sess.run([reg_net.ddf, reg_net.grid_warped], feed_dict=testFeed)
# 6a - an example if one uses NumPy for analysis
"""
np.save(model_dir+'ddfs.npy', ddfs)
np.save(model_dir+'resampling_grids.npy', resampling_grids)
"""
# 6b - an example if one uses MATLAB for analysis
"""
from scipy.io import savemat
savemat(model_dir+'reg',
{'ddfs': ddfs,
'resampling_grids': resampling_grids})
"""
# (Model) This can be implemented in MATLAB
import os
import matplotlib.pyplot as plt
import numpy as np
# data
data_dir = '../data'
# training
train_dir = '../data/train'
files_image_train = [fn for fn in os.listdir(train_dir) if os.path.splitext(fn)[0].isdigit()]
# read images from files and convert to standard orientation
images_train = np.stack([plt.imread(os.path.join(train_dir, fn))[::-1, ...].T for fn in files_image_train], axis=0)
# normalise individual images
images_train = (images_train-images_train.mean(axis=(1, 2), keepdims=True)) / images_train.std(axis=(1, 2), keepdims=True)
np.save(os.path.join(data_dir, 'images_train.npy'), images_train)
# test
test_dir = '../data/test'
files_image_test = [fn for fn in os.listdir(test_dir) if os.path.splitext(fn)[0].isdigit()]
files_label0_test = [os.path.splitext(fn)[0]+'_SPINAL_CORD'+os.path.splitext(fn)[1] for fn in files_image_test]
files_label1_test = [os.path.splitext(fn)[0]+'_BRAIN_STEM'+os.path.splitext(fn)[1] for fn in files_image_test]
images_test = np.stack([plt.imread(os.path.join(test_dir, fn))[::-1, ...].T for fn in files_image_test], axis=0)
images_test = (images_test-images_test.mean(axis=(1, 2), keepdims=True)) / images_test.std(axis=(1, 2), keepdims=True)
np.save(os.path.join(data_dir, 'images_test.npy'), images_test)
labels0_test = np.stack([plt.imread(os.path.join(test_dir, fn))[::-1, ...].T for fn in files_label0_test], axis=0)
labels1_test = np.stack([plt.imread(os.path.join(test_dir, fn))[::-1, ...].T for fn in files_label1_test], axis=0)
np.save(os.path.join(data_dir, 'labels_test.npy'), np.stack([labels0_test, labels1_test], axis=3))
# plot test data
for idx in range(len(files_image_test)):
    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(images_test[idx, ...], cmap='gray')
    ax.imshow(np.where(labels0_test[idx, ...] == 0, np.nan, labels0_test[idx, ...]), cmap='Reds_r')
    ax.imshow(np.where(labels1_test[idx, ...] == 0, np.nan, labels1_test[idx, ...]), cmap='Greens_r')
    ax.set_title(files_image_test[idx])
plt.show()
# (Model) This uses TensorFlow
import tensorflow as tf
def normalised_cross_correlation(ts, ps, eps=0.0):
    # zero-mean both image batches; keepdims so the per-image statistics broadcast over [N,H,W,C]
    dp = ps - tf.reduce_mean(ps, axis=[1, 2, 3], keepdims=True)
    dt = ts - tf.reduce_mean(ts, axis=[1, 2, 3], keepdims=True)
    vp = tf.reduce_sum(tf.square(dp), axis=[1, 2, 3], keepdims=True)
    vt = tf.reduce_sum(tf.square(dt), axis=[1, 2, 3], keepdims=True)
    return tf.constant(1.0) - tf.reduce_sum(dp*dt / (tf.sqrt(vp*vt) + eps), axis=[1, 2, 3])
def normalised_cross_correlation2(ts, ps, eps=1e-6):
    # alternative formulation using per-image means and standard deviations
    mean_t = tf.reduce_mean(ts, axis=[1, 2, 3], keepdims=True)
    mean_p = tf.reduce_mean(ps, axis=[1, 2, 3], keepdims=True)
    std_t = tf.sqrt(tf.reduce_mean(tf.square(ts), axis=[1, 2, 3], keepdims=True) - tf.square(mean_t))
    std_p = tf.sqrt(tf.reduce_mean(tf.square(ps), axis=[1, 2, 3], keepdims=True) - tf.square(mean_p))
    return -tf.reduce_mean((ts-mean_t)*(ps-mean_p) / (std_t*std_p+eps), axis=[1, 2, 3])
def sum_square_difference(i1, i2):
    return tf.reduce_mean(tf.square(i1 - i2), axis=[1, 2, 3])  # use mean for normalised regulariser weighting

def gradient_dx(fv):
    return (fv[:, 2:, 1:-1] - fv[:, :-2, 1:-1]) / 2

def gradient_dy(fv):
    return (fv[:, 1:-1, 2:] - fv[:, 1:-1, :-2]) / 2

def gradient_txy(txy, fn):
    return tf.stack([fn(txy[..., i]) for i in [0, 1]], axis=3)

def gradient_norm(displacement, flag_l1=False):
    dtdx = gradient_txy(displacement, gradient_dx)
    dtdy = gradient_txy(displacement, gradient_dy)
    if flag_l1:
        norms = tf.abs(dtdx) + tf.abs(dtdy)
    else:
        norms = dtdx**2 + dtdy**2
    return tf.reduce_mean(norms, [1, 2, 3])

def bending_energy(displacement):
    dtdx = gradient_txy(displacement, gradient_dx)
    dtdy = gradient_txy(displacement, gradient_dy)
    dtdxx = gradient_txy(dtdx, gradient_dx)
    dtdyy = gradient_txy(dtdy, gradient_dy)
    dtdxy = gradient_txy(dtdx, gradient_dy)
    return tf.reduce_mean(dtdxx**2 + dtdyy**2 + 2*dtdxy**2, [1, 2, 3])
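# (Example) A sketch of how the terms above are typically combined into a single
# registration loss. The weight alpha may correspond to the values in the trained model
# names used earlier (e.g. ssd_1e0, ssd_1e-1), but the exact wiring lives in the training
# script, which is not shown here; bending_energy is one choice of regulariser, with
# gradient_norm as the other option defined above.
def registration_loss(warped_moving, fixed, displacement, alpha=1.0):
    # image similarity between the warped moving image and the fixed image,
    # plus a weighted smoothness penalty on the dense displacement field
    return sum_square_difference(warped_moving, fixed) + alpha * bending_energy(displacement)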
# (Provided) This uses TensorFlow
# *** Available as part of UCL MPHY0025 (Information Processing in Medical Imaging) Assessed Coursework 2018-19 ***
# *** This code is with an Apache 2.0 license, University College London ***
import tensorflow as tf
def var_conv_kernel(ch_in, ch_out, name='W', initialiser=None):
    with tf.variable_scope(name):
        k_conv = [3, 3]
        if initialiser is None:
            initialiser = tf.contrib.layers.xavier_initializer()
        return tf.get_variable(name, shape=k_conv+[ch_in]+[ch_out], initializer=initialiser)

def conv_block(input_, ch_in, ch_out, strides=None, name='conv_block'):
    if strides is None:
        strides = [1, 1, 1, 1]
    with tf.variable_scope(name):
        w = var_conv_kernel(ch_in, ch_out)
        return tf.nn.relu(tf.contrib.layers.batch_norm(tf.nn.conv2d(input_, w, strides, "SAME")))

def deconv_block(input_, ch_in, ch_out, shape_out, strides, name='deconv_block'):
    with tf.variable_scope(name):
        w = var_conv_kernel(ch_out, ch_in)
        return tf.nn.relu(tf.contrib.layers.batch_norm(tf.nn.conv2d_transpose(input_, w, shape_out, strides, "SAME")))

def downsample_resnet_block(input_, ch_in, ch_out, name='down_resnet_block'):
    strides1 = [1, 1, 1, 1]
    strides2 = [1, 2, 2, 1]
    k_pool = [1, 2, 2, 1]
    with tf.variable_scope(name):
        h0 = conv_block(input_, ch_in, ch_out, name='W0')
        r1 = conv_block(h0, ch_out, ch_out, name='WR1')
        wr2 = var_conv_kernel(ch_out, ch_out)
        r2 = tf.nn.relu(tf.contrib.layers.batch_norm(tf.nn.conv2d(r1, wr2, strides1, "SAME")) + h0)
        h1 = tf.nn.max_pool(r2, k_pool, strides2, padding="SAME")
        return h1, h0

def upsample_resnet_block(input_, input_skip, ch_in, ch_out, name='up_resnet_block'):
    strides1 = [1, 1, 1, 1]
    strides2 = [1, 2, 2, 1]
    size_out = input_skip.shape.as_list()
    with tf.variable_scope(name):