...
 
Commits (5)
#!/usr/bin/env python3
import sys
print('Importing libraries...')
from pathlib import Path
from vesseg import Model, Job
from vesseg.network.model import DSA, T1, T1_GAD
learning_dir = Path('~/mres_project/learning').expanduser()
combinations = (
(DSA,),
(DSA, T1),
(DSA, T1_GAD),
(DSA, T1, T1_GAD),
)
learning_rates = 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5
networks = 222, 223, 233
applications = 'Vessels', 'Vessels_Reuben'
for application in applications:
for net in networks:
for inputs in combinations:
for lr in learning_rates:
models_dir = learning_dir / application / f'highres3dnet_{net}'
network = f'vesseg_networks.highres3dnet_{net}.HighRes3DNet{net}'
application = (
'vesseg_applications'
f'.segmentation_application_{application.lower()}'
f'.SegmentationApplication{application.replace("_", "")}'
)
string = '-'.join(inputs)
model_name = f'highres3dnet_{net}_{string}_lr_{lr}'
print(f'Creating {model_name}...')
model_dir = models_dir / model_name
model = Model(model_dir=model_dir, inputs=inputs)
model.set_images_and_labels_paths()
model.make_csv_files(split_type='subject')
model.config_all()
model.config_training(learning_rate=lr)
model.config_network(network_name=network)
model.write_config_file()
job = Job(model_dir, model.config_path)
job.application = application
job.write()
if len(sys.argv) > 1 and sys.argv[1] == '--submit':
print(f'Submitting {model_name}...')
job.submit()
......@@ -15,6 +15,7 @@ class Job:
self.config_path = config_path
self.train = True
self.infer = True
self.application = 'vesseg_applications.segmentation_application_vessels.SegmentationApplicationVessels'
def write(self, old_cluster=False):
......@@ -80,7 +81,7 @@ class Job:
a(' /share/apps/python-3.6.3-shared/bin/python3) \\')
a(f' -u net_segment.py {action} \\')
a(f' -c "{self.config_path}" \\')
a(f' -a vesseg_applications.segmentation_application_vessels.SegmentationApplicationVessels')
a(f' -a {self.application}')
a('')
if old_cluster:
a('fi')
......
......@@ -453,7 +453,7 @@ class SegmentationApplicationVessels(BaseApplication):
# classification probabilities or argmax classification labels
data_dict = switch_sampler(for_training=False)
image = tf.cast(data_dict['image'], tf.float32)
net_args = {'is_training': True, # self.is_training, (Reuben modif)
net_args = {'is_training': self.is_training, # True # (Reuben modif)
'keep_prob': self.net_param.keep_prob}
net_out = self.net(image, **net_args)
......
# -*- coding: utf-8 -*-
import tensorflow as tf
from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import \
ApplicationNetFactory, InitializerFactory, OptimiserFactory
from niftynet.engine.application_variables import \
CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES
from niftynet.engine.sampler_grid_v2 import GridSampler
from niftynet.engine.sampler_resize_v2 import ResizeSampler
from niftynet.engine.sampler_uniform_v2 import UniformSampler
from niftynet.engine.sampler_weighted_v2 import WeightedSampler
from niftynet.engine.sampler_balanced_v2 import BalancedSampler
from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator
from niftynet.io.image_reader import ImageReader
from niftynet.layer.binary_masking import BinaryMaskingLayer
from niftynet.layer.discrete_label_normalisation import \
DiscreteLabelNormalisationLayer
from niftynet.layer.histogram_normalisation import \
HistogramNormalisationLayer
from niftynet.layer.loss_segmentation import LossFunction
from niftynet.layer.mean_variance_normalisation import \
MeanVarNormalisationLayer
from niftynet.layer.pad import PadLayer
from niftynet.layer.post_processing import PostProcessingLayer
from niftynet.layer.rand_flip import RandomFlipLayer
from niftynet.layer.rand_rotation import RandomRotationLayer
from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer
from niftynet.evaluation.segmentation_evaluator import SegmentationEvaluator
from niftynet.layer.rand_elastic_deform import RandomElasticDeformationLayer
SUPPORTED_INPUT = set(['image', 'label', 'weight', 'sampler', 'inferred'])
class SegmentationApplicationVesselsReuben(BaseApplication):
REQUIRED_CONFIG_SECTION = "SEGMENTATION"
def __init__(self, net_param, action_param, action):
super(SegmentationApplicationVesselsReuben, self).__init__()
tf.logging.info('Starting vessels (no BN) segmentation application')
self.action = action
self.net_param = net_param
self.action_param = action_param
self.data_param = None
self.segmentation_param = None
self.SUPPORTED_SAMPLING = {
'uniform': (self.initialise_uniform_sampler,
self.initialise_grid_sampler,
self.initialise_grid_aggregator),
'weighted': (self.initialise_weighted_sampler,
self.initialise_grid_sampler,
self.initialise_grid_aggregator),
'resize': (self.initialise_resize_sampler,
self.initialise_resize_sampler,
self.initialise_resize_aggregator),
'balanced': (self.initialise_balanced_sampler,
self.initialise_grid_sampler,
self.initialise_grid_aggregator),
}
def initialise_dataset_loader(
self, data_param=None, task_param=None, data_partitioner=None):
self.data_param = data_param
self.segmentation_param = task_param
# initialise input image readers
if self.is_training:
reader_names = ('image', 'label', 'weight', 'sampler')
elif self.is_inference:
# in the inference process use `image` input only
reader_names = ('image',)
elif self.is_evaluation:
reader_names = ('image', 'label', 'inferred')
else:
tf.logging.fatal(
'Action `%s` not supported. Expected one of %s',
self.action, self.SUPPORTED_PHASES)
raise ValueError
try:
reader_phase = self.action_param.dataset_to_infer
except AttributeError:
reader_phase = None
file_lists = data_partitioner.get_file_lists_by(
phase=reader_phase, action=self.action)
self.readers = [
ImageReader(reader_names).initialise(
data_param, task_param, file_list) for file_list in file_lists]
# initialise input preprocessing layers
foreground_masking_layer = BinaryMaskingLayer(
type_str=self.net_param.foreground_type,
multimod_fusion=self.net_param.multimod_foreground_type,
threshold=0.0) \
if self.net_param.normalise_foreground_only else None
mean_var_normaliser = MeanVarNormalisationLayer(
image_name='image', binary_masking_func=foreground_masking_layer) \
if self.net_param.whitening else None
histogram_normaliser = HistogramNormalisationLayer(
image_name='image',
modalities=vars(task_param).get('image'),
model_filename=self.net_param.histogram_ref_file,
binary_masking_func=foreground_masking_layer,
norm_type=self.net_param.norm_type,
cutoff=self.net_param.cutoff,
name='hist_norm_layer') \
if (self.net_param.histogram_ref_file and
self.net_param.normalisation) else None
label_normalisers = None
if self.net_param.histogram_ref_file and \
task_param.label_normalisation:
label_normalisers = [DiscreteLabelNormalisationLayer(
image_name='label',
modalities=vars(task_param).get('label'),
model_filename=self.net_param.histogram_ref_file)]
if self.is_evaluation:
label_normalisers.append(
DiscreteLabelNormalisationLayer(
image_name='inferred',
modalities=vars(task_param).get('inferred'),
model_filename=self.net_param.histogram_ref_file))
label_normalisers[-1].key = label_normalisers[0].key
normalisation_layers = []
if histogram_normaliser is not None:
normalisation_layers.append(histogram_normaliser)
if mean_var_normaliser is not None:
normalisation_layers.append(mean_var_normaliser)
if task_param.label_normalisation and \
(self.is_training or not task_param.output_prob):
normalisation_layers.extend(label_normalisers)
volume_padding_layer = []
if self.net_param.volume_padding_size:
volume_padding_layer.append(PadLayer(
image_name=SUPPORTED_INPUT,
border=self.net_param.volume_padding_size,
mode=self.net_param.volume_padding_mode))
# initialise training data augmentation layers
augmentation_layers = []
if self.is_training:
train_param = self.action_param
if train_param.random_flipping_axes != -1:
augmentation_layers.append(RandomFlipLayer(
flip_axes=train_param.random_flipping_axes))
if train_param.scaling_percentage:
augmentation_layers.append(RandomSpatialScalingLayer(
min_percentage=train_param.scaling_percentage[0],
max_percentage=train_param.scaling_percentage[1],
antialiasing=train_param.antialiasing))
if train_param.rotation_angle or \
train_param.rotation_angle_x or \
train_param.rotation_angle_y or \
train_param.rotation_angle_z:
rotation_layer = RandomRotationLayer()
if train_param.rotation_angle:
rotation_layer.init_uniform_angle(
train_param.rotation_angle)
else:
rotation_layer.init_non_uniform_angle(
train_param.rotation_angle_x,
train_param.rotation_angle_y,
train_param.rotation_angle_z)
augmentation_layers.append(rotation_layer)
if train_param.do_elastic_deformation:
spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
augmentation_layers.append(RandomElasticDeformationLayer(
spatial_rank=spatial_rank,
num_controlpoints=train_param.num_ctrl_points,
std_deformation_sigma=train_param.deformation_sigma,
proportion_to_augment=train_param.proportion_to_deform))
# only add augmentation to first reader (not validation reader)
self.readers[0].add_preprocessing_layers(
volume_padding_layer + normalisation_layers + augmentation_layers)
for reader in self.readers[1:]:
reader.add_preprocessing_layers(
volume_padding_layer + normalisation_layers)
def initialise_uniform_sampler(self):
self.sampler = [[UniformSampler(
reader=reader,
window_sizes=self.data_param,
batch_size=self.net_param.batch_size,
windows_per_image=self.action_param.sample_per_volume,
queue_length=self.net_param.queue_length) for reader in
self.readers]]
def initialise_weighted_sampler(self):
self.sampler = [[WeightedSampler(
reader=reader,
window_sizes=self.data_param,
batch_size=self.net_param.batch_size,
windows_per_image=self.action_param.sample_per_volume,
queue_length=self.net_param.queue_length) for reader in
self.readers]]
def initialise_resize_sampler(self):
self.sampler = [[ResizeSampler(
reader=reader,
window_sizes=self.data_param,
batch_size=self.net_param.batch_size,
shuffle=self.is_training,
smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
queue_length=self.net_param.queue_length) for reader in
self.readers]]
def initialise_grid_sampler(self):
self.sampler = [[GridSampler(
reader=reader,
window_sizes=self.data_param,
batch_size=self.net_param.batch_size,
spatial_window_size=self.action_param.spatial_window_size,
window_border=self.action_param.border,
smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
queue_length=self.net_param.queue_length) for reader in
self.readers]]
def initialise_balanced_sampler(self):
self.sampler = [[BalancedSampler(
reader=reader,
window_sizes=self.data_param,
batch_size=self.net_param.batch_size,
windows_per_image=self.action_param.sample_per_volume,
queue_length=self.net_param.queue_length) for reader in
self.readers]]
def initialise_grid_aggregator(self):
self.output_decoder = GridSamplesAggregator(
image_reader=self.readers[0],
output_path=self.action_param.save_seg_dir,
window_border=self.action_param.border,
interp_order=self.action_param.output_interp_order,
postfix=self.action_param.output_postfix)
def initialise_resize_aggregator(self):
self.output_decoder = ResizeSamplesAggregator(
image_reader=self.readers[0],
output_path=self.action_param.save_seg_dir,
window_border=self.action_param.border,
interp_order=self.action_param.output_interp_order,
postfix=self.action_param.output_postfix)
def initialise_sampler(self):
if self.is_training:
self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
elif self.is_inference:
self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()
def initialise_aggregator(self):
self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()
def initialise_network(self):
w_regularizer = None
b_regularizer = None
reg_type = self.net_param.reg_type.lower()
decay = self.net_param.decay
if reg_type == 'l2' and decay > 0:
from tensorflow.contrib.layers.python.layers import regularizers
w_regularizer = regularizers.l2_regularizer(decay)
b_regularizer = regularizers.l2_regularizer(decay)
elif reg_type == 'l1' and decay > 0:
from tensorflow.contrib.layers.python.layers import regularizers
w_regularizer = regularizers.l1_regularizer(decay)
b_regularizer = regularizers.l1_regularizer(decay)
self.net = ApplicationNetFactory.create(self.net_param.name)(
num_classes=self.segmentation_param.num_classes,
w_initializer=InitializerFactory.get_initializer(
name=self.net_param.weight_initializer),
b_initializer=InitializerFactory.get_initializer(
name=self.net_param.bias_initializer),
w_regularizer=w_regularizer,
b_regularizer=b_regularizer,
acti_func=self.net_param.activation_function)
def connect_data_and_network(self,
outputs_collector=None,
gradients_collector=None):
def switch_sampler(for_training):
with tf.name_scope('train' if for_training else 'validation'):
sampler = self.get_sampler()[0][0 if for_training else -1]
return sampler.pop_batch_op()
if self.is_training:
if self.action_param.validation_every_n > 0:
data_dict = tf.cond(tf.logical_not(self.is_validation),
lambda: switch_sampler(for_training=True),
lambda: switch_sampler(for_training=False))
else:
data_dict = switch_sampler(for_training=True)
image = tf.cast(data_dict['image'], tf.float32)
net_args = {'is_training': self.is_training,
'keep_prob': self.net_param.keep_prob}
net_out = self.net(image, **net_args)
with tf.name_scope('Optimiser'):
optimiser_class = OptimiserFactory.create(
name=self.action_param.optimiser)
self.optimiser = optimiser_class.get_instance(
learning_rate=self.action_param.lr)
loss_func = LossFunction(
n_class=self.segmentation_param.num_classes,
loss_type=self.action_param.loss_type,
softmax=self.segmentation_param.softmax)
data_loss = loss_func(
prediction=net_out,
ground_truth=data_dict.get('label', None),
weight_map=data_dict.get('weight', None))
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if self.net_param.decay > 0.0 and reg_losses:
reg_loss = tf.reduce_mean(
[tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
loss = data_loss + reg_loss
else:
loss = data_loss
# Get all vars
to_optimise = tf.trainable_variables()
vars_to_freeze = \
self.action_param.vars_to_freeze or \
self.action_param.vars_to_restore
if vars_to_freeze:
import re
var_regex = re.compile(vars_to_freeze)
# Only optimise vars that are not frozen
to_optimise = \
[v for v in to_optimise if not var_regex.search(v.name)]
tf.logging.info(
"Optimizing %d out of %d trainable variables, "
"the other variables fixed (--vars_to_freeze %s)",
len(to_optimise),
len(tf.trainable_variables()),
vars_to_freeze)
grads = self.optimiser.compute_gradients(
loss, var_list=to_optimise, colocate_gradients_with_ops=True)
# collecting gradients variables
gradients_collector.add_to_collection([grads])
# collecting output variables
outputs_collector.add_to_collection(
var=data_loss, name='loss',
average_over_devices=False, collection=CONSOLE)
outputs_collector.add_to_collection(
var=data_loss, name='loss',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)
### TENSORBOARD ###
axes = 'sagittal', # 'coronal', 'axial'
# Input images
for channel in range(image.shape[-1]):
channel_array = image[..., channel]
channel_array = tf.expand_dims(channel_array, -1)
for axis, name in enumerate(axes, start=1):
image_mip = tf.reduce_max(channel_array, axis=axis)
outputs_collector.add_to_collection(
var=image_mip, name='input_image_channel_{}_{}'.format(channel, name),
average_over_devices=False, summary_type='image',
collection=TF_SUMMARIES)
limit = 2
image_float = tf.to_float(channel_array)
image_clipped = tf.clip_by_value(image_float, -limit, limit) # image has been whitened
image_clipped += limit
image_clipped /= 2 * limit
image_scaled = 255 * image_clipped
image_uint8 = tf.cast(image_scaled, tf.uint8)
outputs_collector.add_to_collection(
var=image_uint8, name='input_image_channel_{}_{}'.format(channel, name),
average_over_devices=False, summary_type='image3_{}'.format(name),
collection=TF_SUMMARIES)
softmaxed_output = tf.nn.softmax(net_out)
prediction = softmaxed_output[..., 1] # foreground only
for axis, name in enumerate(axes, start=1):
prediction_mip = tf.reduce_max(prediction, axis=axis)
prediction_mip = tf.expand_dims(prediction_mip, 3)
colorized = self.colorize(
value=prediction_mip,
vmin=0,
vmax=1,
cmap='RdBu_r',
)
outputs_collector.add_to_collection(
var=colorized, name='prediction_{}'.format(name),
average_over_devices=False, summary_type='image',
collection=TF_SUMMARIES)
for axis, name in enumerate(axes, start=1):
binary_prediction = tf.round(prediction)
binary_prediction_mip = tf.reduce_max(binary_prediction, axis=axis)
binary_prediction_mip = tf.expand_dims(binary_prediction_mip, 3)
ground_truth = data_dict.get('label', None) # (1, 96, 96, 96, 1)
ground_truth_mip = tf.reduce_max(ground_truth, axis=axis) # (1, 96, 96, 1)
green_magenta_mip = tf.concat(
values=(
binary_prediction_mip,
ground_truth_mip,
binary_prediction_mip,
),
axis=-1,
)
outputs_collector.add_to_collection(
var=green_magenta_mip, name='green_magenta_{}'.format(name),
average_over_devices=False, summary_type='image',
collection=TF_SUMMARIES)
## 3D RGB doesn't seem to be supported
# green_magenta = tf.concat(
# values=(
# binary_prediction,
# ground_truth[..., 0],
# binary_prediction,
# ),
# axis=-1,
# )
# outputs_collector.add_to_collection(
# var=green_magenta, name='green_magenta',
# average_over_devices=False, summary_type='image3_sagittal',
# collection=TF_SUMMARIES)
# outputs_collector.add_to_collection(
# var=image, name='image_output_test',
# average_over_devices=False,
# collection=NETWORK_OUTPUT)
# outputs_collector.add_to_collection(
# var=tf.reduce_mean(image), name='mean_image',
# average_over_devices=False, summary_type='scalar',
# collection=CONSOLE)
elif self.is_inference:
# converting logits into final output for
# classification probabilities or argmax classification labels
data_dict = switch_sampler(for_training=False)
image = tf.cast(data_dict['image'], tf.float32)
net_args = {'is_training': True, # (Reuben modif)
'keep_prob': self.net_param.keep_prob}
net_out = self.net(image, **net_args)
output_prob = self.segmentation_param.output_prob
num_classes = self.segmentation_param.num_classes
if output_prob and num_classes > 1:
post_process_layer = PostProcessingLayer(
'SOFTMAX', num_classes=num_classes)
elif not output_prob and num_classes > 1:
post_process_layer = PostProcessingLayer(
'ARGMAX', num_classes=num_classes)
else:
post_process_layer = PostProcessingLayer(
'IDENTITY', num_classes=num_classes)
net_out = post_process_layer(net_out)
outputs_collector.add_to_collection(
var=net_out, name='window',
average_over_devices=False, collection=NETWORK_OUTPUT)
outputs_collector.add_to_collection(
var=data_dict['image_location'], name='location',
average_over_devices=False, collection=NETWORK_OUTPUT)
self.initialise_aggregator()
def interpret_output(self, batch_output):
if self.is_inference:
return self.output_decoder.decode_batch(
batch_output['window'], batch_output['location'])
return True
def initialise_evaluator(self, eval_param):
self.eval_param = eval_param
self.evaluator = SegmentationEvaluator(self.readers[0],
self.segmentation_param,
eval_param)
def add_inferred_output(self, data_param, task_param):
return self.add_inferred_output_like(data_param, task_param, 'label')
def colorize(self, value, vmin=None, vmax=None, cmap=None):
"""
A utility function for TensorFlow that maps a grayscale image to a matplotlib
colormap for use with TensorBoard image summaries.
By default it will normalize the input value to the range 0..1 before mapping
to a grayscale colormap.
Arguments:
- value: 2D Tensor of shape [height, width] or 3D Tensor of shape
[height, width, 1].
- vmin: the minimum value of the range used for normalization.
(Default: value minimum)
- vmax: the maximum value of the range used for normalization.
(Default: value maximum)
- cmap: a valid cmap named for use with matplotlib's `get_cmap`.
(Default: 'gray')
Example usage:
```
output = tf.random_uniform(shape=[256, 256, 1])
output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis')
tf.summary.image('output', output_color)
```
Returns a 3D tensor of shape [height, width, 3].
"""
import matplotlib
import matplotlib.cm
import numpy as np
# normalize
vmin = tf.reduce_min(value) if vmin is None else vmin
vmax = tf.reduce_max(value) if vmax is None else vmax
value = (value - vmin) / (vmax - vmin) # vmin..vmax
# squeeze last dim if it exists
value = tf.squeeze(value)
# quantize
indices = tf.to_int32(tf.round(value * 255))
# gather
cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
# colors = tf.constant(cm.colors, dtype=tf.float32)
colors = cm(np.arange(256))[:, :3]
colors = tf.constant(colors, dtype=tf.float32)
value = tf.gather(colors, indices)
value = tf.expand_dims(value, 0)
return value
......@@ -55,7 +55,8 @@ SEGMENTATION_NIFTYNET = '.'.join([
'segmentation_application',
'SegmentationApplication'
])
SEGMENTATION_APP = 'segmentation_application_mine.SegmentationApplicationMine'
SEGMENTATION_APP = 'vesseg_applications.segmentation_application_vessels.SegmentationApplicationVessels'
# Random seed
......@@ -80,9 +81,9 @@ VOXEL_SPACING = 0.466 # most DSAs
NEAREST_NEIGHBOR = 0
LINEAR = 1 # true?
X_INTERPOLATION = 3
TRAINING_IMAGE_WINDOW_SIZE = 64
TRAINING_IMAGE_WINDOW_SIZE = 96
TRAINING_LABEL_WINDOW_SIZE = TRAINING_IMAGE_WINDOW_SIZE
INFERENCE_WINDOW_SIZE = 64
INFERENCE_WINDOW_SIZE = 96
# System
NUM_THREADS = 2
......@@ -90,14 +91,14 @@ NUM_THREADS = 2
# Network
# NETWORK_NAME = UNET_SMALL
NETWORK_NAME = HIGHRES3DNET_SMALLER
BATCH_SIZE = 1
BATCH_SIZE = 4
WINDOW_SAMPLING = 'weighted'
VOLUME_PADDING_SIZE = TRAINING_IMAGE_WINDOW_SIZE // 2
# Training
MAX_NUM_ITERATIONS = 10000
SAVE_EVERY_N = MAX_NUM_ITERATIONS // 10
LEARNING_RATE = 1e-3
LEARNING_RATE = 1e-1
TENSORBOARD_EVERY_N = 1
VALIDATION_EVERY_N = 5
LATEST_CHECKPOINT = -1
......@@ -183,6 +184,8 @@ class Model:
self.config = ConfigParser()
self.application = SEGMENTATION_APP
def __repr__(self):
return self.name
......