From 7c3c0b2aad6957a3c3bd295549072256bef9ed32 Mon Sep 17 00:00:00 2001
From: Fernando Perez-Garcia
Date: Wed, 13 Feb 2019 18:32:27 +0000
Subject: [PATCH] Add custom spatial scaling layer

---
 vesseg/network/layer/rand_spatial_scaling.py | 111 ++++++++++++++++++
 1 file changed, 111 insertions(+)
 create mode 100644 vesseg/network/layer/rand_spatial_scaling.py

diff --git a/vesseg/network/layer/rand_spatial_scaling.py b/vesseg/network/layer/rand_spatial_scaling.py
new file mode 100644
index 0000000..2e08130
--- /dev/null
+++ b/vesseg/network/layer/rand_spatial_scaling.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, print_function
+
+import warnings
+
+import tensorflow as tf
+import numpy as np
+import scipy.ndimage as ndi
+
+from niftynet.layer.base_layer import RandomisedLayer
+
+warnings.simplefilter("ignore", UserWarning)
+warnings.simplefilter("ignore", RuntimeWarning)
+
+
+class RandomSpatialScalingLayer(RandomisedLayer):
+    """
+    Generate a randomised scaling factor along each spatial dimension
+    for data augmentation.
+    """
+
+    def __init__(self,
+                 min_percentage=-10.0,
+                 max_percentage=10.0,
+                 antialiasing=True,
+                 name='random_spatial_scaling',
+                 isotropic=True):
+        super(RandomSpatialScalingLayer, self).__init__(name=name)
+        assert min_percentage <= max_percentage
+        # Clip so that the resulting zoom factor is always positive
+        self._min_percentage = max(min_percentage, -99.9)
+        self._max_percentage = max_percentage
+        self.antialiasing = antialiasing
+        self._rand_zoom = None
+        self.isotropic = isotropic
+
+    def randomise(self, spatial_rank=3):
+        spatial_rank = int(np.floor(spatial_rank))
+        if self.isotropic:
+            # Draw a single factor and repeat it for every spatial dim
+            one_rand_zoom = np.random.uniform(low=self._min_percentage,
+                                              high=self._max_percentage)
+            rand_zoom = np.array(spatial_rank * [one_rand_zoom])
+        else:
+            rand_zoom = np.random.uniform(low=self._min_percentage,
+                                          high=self._max_percentage,
+                                          size=(spatial_rank,))
+        tf.logging.info('Random zoom: {}'.format(rand_zoom))
+        self._rand_zoom = (rand_zoom + 100.0) / 100.0
+
+    def _get_sigma(self, zoom):
+        """
+        Compute the optimal standard deviation for the Gaussian
+        antialiasing kernel.
+
+        Cardoso et al., "Scale factor point spread function matching:
+        beyond aliasing in image resampling", MICCAI 2015
+        """
+        k = 1 / zoom
+        # The matched kernel has FWHM sqrt(k**2 - 1) voxels;
+        # FWHM = 2 * sqrt(2 * ln(2)) * sigma converts it to a std dev
+        variance = (k ** 2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
+        sigma = np.sqrt(variance)
+        return sigma
+
+    def _apply_transformation(self, image, interp_order=3):
+        if interp_order < 0:
+            return image
+        assert self._rand_zoom is not None
+        full_zoom = np.array(self._rand_zoom)
+        while len(full_zoom) < image.ndim:
+            full_zoom = np.hstack((full_zoom, [1.0]))
+        # Antialiasing is only needed when downsampling (zoom < 1)
+        is_undersampling = all(full_zoom[:3] < 1)
+        run_antialiasing_filter = self.antialiasing and is_undersampling
+        if run_antialiasing_filter:
+            sigma = self._get_sigma(full_zoom[:3])
+        if image.ndim == 4:
+            output = []
+            for mod in range(image.shape[-1]):
+                to_scale = ndi.gaussian_filter(image[..., mod], sigma) if \
+                    run_antialiasing_filter else image[..., mod]
+                scaled = ndi.zoom(to_scale, full_zoom[:3], order=interp_order)
+                output.append(scaled[..., np.newaxis])
+            return np.concatenate(output, axis=-1)
+        elif image.ndim == 3:
+            to_scale = ndi.gaussian_filter(image, sigma) \
+                if run_antialiasing_filter else image
+            scaled = ndi.zoom(
+                to_scale, full_zoom[:3], order=interp_order)
+            return scaled[..., np.newaxis]
+        else:
+            raise NotImplementedError(
+                'random scaling is only implemented for 3D and 4D images')
+
+    def layer_op(self, inputs, interp_orders, *args, **kwargs):
+        if inputs is None:
+            return inputs
+
+        if isinstance(inputs, dict) and isinstance(interp_orders, dict):
+            for (field, image) in inputs.items():
+                transformed_data = []
+                interp_order = interp_orders[field][0]
+                for mod_i in range(image.shape[-1]):
+                    scaled_data = self._apply_transformation(
+                        image[..., mod_i], interp_order)
+                    transformed_data.append(scaled_data[..., np.newaxis])
+                inputs[field] = np.concatenate(transformed_data, axis=-1)
+        else:
+            raise NotImplementedError("unknown input format")
+        return inputs
--
GitLab
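A minimal sketch of how the layer could be exercised on its own, calling
layer_op directly to keep the example free of TensorFlow graph scaffolding.
The 'image' field name, the NiftyNet-style 5D shape (x, y, z, time,
modality), and the {field: [interp_order]} layout are illustrative
assumptions, not part of this patch:

    import numpy as np

    from vesseg.network.layer.rand_spatial_scaling import \
        RandomSpatialScalingLayer

    layer = RandomSpatialScalingLayer(
        min_percentage=-10.0,  # up to 10% shrinking
        max_percentage=10.0,   # up to 10% enlargement
        antialiasing=True,     # smooth before downsampling only
        isotropic=True,        # same factor for x, y and z
    )

    # One volume with a single time point and a single modality
    volume = np.random.rand(32, 32, 32, 1, 1).astype(np.float32)
    inputs = {'image': volume}
    interp_orders = {'image': [3]}  # cubic spline interpolation

    layer.randomise(spatial_rank=3)  # draws a zoom factor in [0.9, 1.1]
    outputs = layer.layer_op(inputs, interp_orders)  # updates inputs in place

    # Spatial dims are rescaled by the drawn factor, e.g. 32 -> 29..35;
    # a zoom of 0.9 would use sigma = sqrt(((1/0.9)**2 - 1) / (8 * ln 2)),
    # i.e. about 0.21 voxels of Gaussian smoothing before resampling.
    print(outputs['image'].shape)

Keeping isotropic=True preserves the aspect ratio of anatomical structures;
independent per-axis factors are drawn only when isotropic=False.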