visualization.py 9.08 KB
Newer Older
1
import re
2 3
from pathlib import Path
import numpy as np
Fernando Perez-Garcia's avatar
Fernando Perez-Garcia committed
4
from . import metrics
5
from ...nifti import load
6
from ...utils import get_os
7 8


9
def get_mip_collage(images_dict, metrics_dict=None):
    """Return a PIL image tiling maximum intensity projections (MIPs) of volumes.

    Each non-``None`` value of ``images_dict`` is written into the RGB channel
    given by its key, e.g.::

        images_dict = {
            0: None,    # red channel left empty
            1: image1,  # green
            2: image2,  # blue
        }

    The MIPs of the resulting RGB volume along each axis are tiled as::

        axial   | metrics text (or black)
        coronal | sagittal

    All volumes are assumed to be in the same (MNI) space.

    Args:
        images_dict: mapping from channel index (0, 1 or 2) to a 3D array or
            ``None``. Every non-``None`` array must have the shape of the
            first one. Arrays are stored as ``uint8`` and then multiplied by
            255, so binary masks (boolean or 0/1) are expected.
        metrics_dict: optional dict with keys ``'label_array'`` and
            ``'prediction_array'``. When given, Dice, precision and recall
            computed from them are rendered in the upper-right panel.

    Returns:
        ``PIL.Image.Image`` with the collage.

    Raises:
        ValueError: if every value in ``images_dict`` is ``None``.
    """
    from PIL import Image

    array_rgb = _images_to_rgb_volume(images_dict)

    sagittal, coronal, axial = [array_rgb.max(axis=i) for i in range(3)]
    sagittal = np.fliplr(np.rot90(sagittal))
    coronal = np.rot90(coronal)
    axial = np.rot90(axial)
    column_left = np.vstack((axial, coronal))

    # The rotated sagittal MIP has shape (sk, sj, 3), so a square (sj, sj)
    # panel stacks on top of it with matching width.
    panel_size = array_rgb.shape[1]
    if metrics_dict is None:
        upper_right = np.zeros((panel_size, panel_size, 3), np.uint8)
    else:
        upper_right = _render_metrics_panel(metrics_dict, panel_size)
    column_right = np.vstack((upper_right, sagittal))

    collage = np.hstack((column_left, column_right))
    return Image.fromarray(collage)


def _images_to_rgb_volume(images_dict):
    """Stack the non-None volumes of images_dict into a (si, sj, sk, 3) uint8 array."""
    arrays = [array for array in images_dict.values() if array is not None]
    if not arrays:
        raise ValueError(
            'images_dict must contain at least one array that is not None')
    si, sj, sk = arrays[0].shape
    array_rgb = np.zeros((si, sj, sk, 3), np.uint8)
    for channel, array in images_dict.items():
        if array is not None:
            array_rgb[..., channel] = array
    array_rgb *= 255
    return array_rgb


def _render_metrics_panel(metrics_dict, size):
    """Return a (size, size, 3) uint8 image with Dice/precision/recall text centered."""
    from PIL import Image, ImageDraw

    confusion = metrics.get_confusion_matrix_elements_from_arrays(
        metrics_dict['label_array'], metrics_dict['prediction_array'])
    lines = [
        f'Dice: {metrics.dice(confusion):.3f}',
        f'Precision: {metrics.precision(confusion):.3f}',
        f'Recall: {metrics.recall(confusion):.3f}',
    ]
    text = '\n'.join(lines)

    image = Image.new('RGB', (size, size), (0, 0, 0))
    draw = ImageDraw.Draw(image)
    font = _get_fitting_font(draw, text, max_width=size)
    text_size_x, text_size_y = _get_multiline_text_size(draw, text, font)
    xy = size // 2 - text_size_x // 2, size // 2 - text_size_y // 2
    draw.multiline_text(
        xy, text, fill=(200, 200, 200), font=font, align='right')
    return np.array(image)


def _get_font_spec():
    """Return (font_path, fontsize) for the current OS, or (None, None) if unknown."""
    os_name = get_os()  # not called `os` to avoid shadowing the stdlib module
    if os_name == 'linux':
        return Path('/usr/share/fonts/truetype/ubuntu/Ubuntu-R.ttf'), 48
    if os_name == 'mac':
        # NOTE(review): user-specific font location; consider making configurable.
        fonts_dir = Path('/Users/fernando/Library/Fonts')
        return fonts_dir / 'Meslo LG M DZ Regular for Powerline.ttf', 36
    return None, None


def _get_fitting_font(draw, text, max_width, min_fontsize=2):
    """Return a font for text, shrunk in steps of 2 until narrower than max_width.

    Falls back to PIL's default font when no TrueType font is configured for
    this OS or the configured file does not exist.
    """
    from PIL import ImageFont

    font_path, fontsize = _get_font_spec()
    if font_path is None or not font_path.is_file():
        return ImageFont.load_default()
    font = ImageFont.truetype(str(font_path), fontsize)
    text_size_x, _ = _get_multiline_text_size(draw, text, font)
    # Stop at min_fontsize so truetype is never asked for a non-positive size.
    while text_size_x >= max_width and fontsize - 2 >= min_fontsize:
        fontsize -= 2
        font = ImageFont.truetype(str(font_path), fontsize)
        text_size_x, _ = _get_multiline_text_size(draw, text, font)
    return font


def _get_multiline_text_size(draw, text, font):
    """Return (width, height) of multiline text across Pillow versions."""
    try:
        left, top, right, bottom = draw.multiline_textbbox(
            (0, 0), text, font=font)
    except (AttributeError, ValueError):
        # multiline_textbbox is missing before Pillow 8, and older versions
        # reject bitmap fonts in it; multiline_textsize was removed in Pillow 10.
        return draw.multiline_textsize(text, font=font)
    return right - left, bottom - top


def create_confusion_collage(label_path, prediction_path):
    """Return a MIP collage comparing a label mask with a prediction mask.

    Both volumes are binarised with ``> 0`` and assumed to be in the same
    (MNI) space. Channels are assigned as follows:
    - the prediction goes to the red and blue channels;
    - the label goes to the green channel.

    Each channel is max-projected independently, so in a projection:
    - pixels set only by the prediction appear magenta;
    - pixels set only by the label appear green;
    - pixels set by both appear white.

    The upper-right panel shows Dice, precision and recall of the prediction
    against the label.

    Args:
        label_path: path to the reference segmentation.
        prediction_path: path to the predicted segmentation.

    Returns:
        The ``PIL.Image.Image`` returned by ``get_mip_collage``.
    """
    # NOTE(review): if `load` returns a nibabel image, `get_data()` is
    # deprecated (removed in nibabel 5); consider `get_fdata()` — confirm.
    label_array = load(label_path).get_data() > 0
    # Presumably the prediction file carries extra singleton dimensions
    # (hence the squeeze) — confirm against the inference output.
    prediction_array = load(prediction_path).get_data().squeeze() > 0

    images_dict = {
        0: prediction_array,
        1: label_array,
        2: prediction_array,
    }
    metrics_dict = {
        'label_array': label_array,
        'prediction_array': prediction_array,
    }
    return get_mip_collage(images_dict, metrics_dict=metrics_dict)


134 135 136 137 138 139 140 141
def make_html(models_dirs,
              output_html_path,
              title=None,
              force=False,
              figure_path=None,
              subject='rade',
              image_types=(
                  'dsa_ica_left',
                  'dsa_ica_right',
                  'dsa_vert',
              ),
              subjects_dir='~/mres_project/subjects'):
    """Write an HTML report with loss curves and confusion collages per model.

    For every model directory:
    - its training loss, and its validation loss if logged, read from
      ``model_k_9/training_niftynet_log``, is added to one shared figure;
    - a section is added to the page with the model name, a link to
      ``config/config.ini`` and one confusion collage per image type.

    Collages are cached as ``inferred/<stem>.png``. They are only recomputed
    when missing or when ``force`` is true.

    Args:
        models_dirs: iterable of model directories (``str`` or ``Path``).
        output_html_path: where the HTML file is written (``str`` or ``Path``).
        title: page title. Defaults to the first model directory relative to
            its fourth ancestor (``parents[3]``).
        force: recompute collages even if the PNG files already exist.
        figure_path: where the loss figure is saved. Defaults to a ``.png``
            with the same stem as, and next to, ``output_html_path``.
        subject: subject whose DSA images are shown.
        image_types: image types of the subject; one column per type.
        subjects_dir: directory holding one folder per subject (``~`` is
            expanded).

    Returns:
        The HTML text that was written.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set()

    output_html_path = Path(output_html_path)
    if figure_path is None:
        figure_path = output_html_path.parent / f'{output_html_path.stem}.png'
    # Same for every model, so computed once.
    dsa_dir = Path(subjects_dir).expanduser() / subject / 'scans' / 'dsa'

    custom = []
    add = custom.append

    fig, ax = plt.subplots()
    try:
        for model_dir in models_dirs:
            # Convert first: relative_to/parents below require a Path.
            model_dir = Path(model_dir)
            if title is None:
                title = model_dir.relative_to(model_dir.parents[3])
            model_name = model_dir.name
            inferred_dir = model_dir / 'inferred'

            # Plot loss
            log_path = model_dir / 'model_k_9' / 'training_niftynet_log'
            loss_dict = read_training_loss(log_path)
            plot_loss(ax, loss_dict['training'], f'Training - {model_name}')
            plot_loss(ax, loss_dict['validation'], f'Validation - {model_name}')

            # Add model name
            add('<div class="row">')
            add(f'<h2>{model_name}</h2>')
            add('</div>')

            # Add link to config
            config_path = model_dir / 'config' / 'config.ini'
            add('<div class="row">')
            add(f'<a href="{config_path}">Open config</a>')
            add('</div>')

            # Add titles
            add('<div class="row">')
            for image_type in image_types:
                string = image_type.replace('_', ' ').upper()
                add(f'<div class="col-sm">{string}</div>')
            add('</div>')  # row

            # Add images
            add('<div class="row">')
            for image_type in image_types:
                subject_image = f'{subject}_{image_type}'
                prediction_stem = f'{subject_image}_niftynet_out_seg'
                collage_path = inferred_dir / f'{prediction_stem}.png'
                if not collage_path.is_file() or force:
                    print(f'Computing images for {model_name} - {subject_image}...')
                    label_name = f'{subject_image}_boneless_thresholded_on_mni.nii.gz'
                    label_path = dsa_dir / label_name
                    prediction_path = inferred_dir / f'{prediction_stem}.nii.gz'
                    if not prediction_path.is_file():
                        print(f'{prediction_path} does not exist')
                        # Keep an empty cell so later images stay under their titles.
                        add('<div class="col-sm"></div>')
                        continue
                    collage = create_confusion_collage(label_path, prediction_path)
                    collage.save(collage_path)
                add('<div class="col-sm">')
                add(f'<img src="{collage_path}" class="img-fluid">')
                add('</div>')
            add('</div>')  # row

        ax.legend()
        fig.savefig(figure_path)
    finally:
        # pyplot keeps every open figure alive until it is closed explicitly.
        plt.close(fig)

    lines = '\n'.join(custom)

    text = f"""
    <!doctype html>
    <html lang="en">
        <head>
            <!-- Required meta tags -->
            <meta charset="utf-8">
            <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
            <!-- Bootstrap CSS -->
            <link rel="stylesheet"
                href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css"
                integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm"
                crossorigin="anonymous">
            <title>{title}</title>
        </head>
        <body style="background-color: white">
            <div class="jumbotron">
                <div class="container">
                    <h1 id="title">{title}</h1>
                </div>
            </div>
            <div class="container-fluid">
                <img src="{figure_path}" class="img-fluid">
            </div>
            <div class="container-fluid">
                <!-- Images here -->
                {lines}
            </div>
        </body>
    </html>
    """
    output_html_path.write_text(text)
    return text
252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294


def read_loss_from_text(text, word):
    """Parse ``<word> iter <N>, loss=<value>`` entries from a NiftyNet log.

    Args:
        text: contents of the training log.
        word: phase prefix to search for, e.g. ``'training'`` or
            ``'validation'``. It is matched literally, not as a regex.

    Returns:
        ``pandas.Series`` of float32 losses indexed by iteration number
        (int64), or ``None`` if no matching entry is found.
    """
    import pandas as pd
    # Accept any decimal or scientific-notation loss (e.g. 12.3, 1e-3), not
    # only values with a single digit before the decimal point.
    number = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
    pattern = re.escape(word) + r' iter (\d+), loss=(' + number + ')'
    matches = re.findall(pattern, text)
    if not matches:
        return None
    iterations, losses = np.array(matches).T
    # int64: uint16 cannot represent iterations above 65535.
    iterations = iterations.astype(np.int64)
    losses = losses.astype(np.float32)
    return pd.Series(index=iterations, data=losses)


def read_training_loss(training_log_path):
    """Read the training and validation loss curves from a NiftyNet log file.

    Args:
        training_log_path: path (``str`` or ``Path``) to the log text file.

    Returns:
        Dict with keys ``'training'`` and ``'validation'``. Each key maps to
        ``read_loss_from_text`` applied to the log for that phase.

        An entry is ``None`` when the log has no lines for that phase. Whether
        validation lines are logged at all depends on the NiftyNet config.
    """
    log_text = Path(training_log_path).read_text()
    phases = ('training', 'validation')
    return {phase: read_loss_from_text(log_text, phase) for phase in phases}


def plot_loss(axis, series, label, window=100):
    """Draw a loss curve and its moving average on ``axis``.

    The raw ``series`` is drawn as a thin, faint line that is kept out of the
    legend. Its rolling mean over ``window`` points is drawn in the same color
    with ``label``. Nothing is drawn when ``series`` is ``None``.
    """
    if series is None:
        return
    # '_nolegend_' hides the raw curve from the legend:
    # https://stackoverflow.com/a/35710894/3956024
    raw_kwargs = {
        'ax': axis,
        'linewidth': 0.25,
        'alpha': 0.25,
        'grid': True,
        'zorder': -1,
        'label': '_nolegend_',
    }
    series.plot(**raw_kwargs)
    # Reuse the color matplotlib just assigned to the raw curve.
    raw_color = axis.get_lines()[-1].get_color()
    smoothed = series.rolling(window).mean()
    smoothed.plot(ax=axis, color=raw_color, label=label)