import html
import re
from pathlib import Path

import numpy as np

from . import metrics
from ...nifti import load
from ...utils import get_os


# Image types shown (one per column) in the HTML report.
DEFAULT_IMAGE_TYPES = (
    'dsa_ica_left',
    'dsa_ica_right',
    'dsa_vert',
)


def _get_font_path_and_size():
    """Return ``(font_path, fontsize)`` configured for the current OS.

    Returns ``(None, None)`` when ``get_os()`` reports an OS for which no
    font is configured.
    """
    os_name = get_os()  # not ``os``, to avoid shadowing the module name
    if os_name == 'linux':
        return Path('/usr/share/fonts/truetype/ubuntu/Ubuntu-R.ttf'), 48
    if os_name == 'mac':
        fonts_dir = Path('/Users/fernando/Library/Fonts')
        return fonts_dir / 'Meslo LG M DZ Regular for Powerline.ttf', 36
    return None, None


def _get_text_size(draw, text, font):
    """Return the ``(width, height)`` of multiline *text* drawn with *font*.

    ``ImageDraw.multiline_textsize`` was removed in Pillow 10, so
    ``multiline_textbbox`` is preferred. The old method is only used on
    older Pillow versions where the bounding-box method is missing or
    rejects the font type.
    """
    try:
        left, top, right, bottom = draw.multiline_textbbox(
            (0, 0), text, font=font)
    except (AttributeError, ValueError):
        return draw.multiline_textsize(text, font=font)
    return right - left, bottom - top


def _render_metrics_panel(metrics_dict, size):
    """Render Dice, precision and recall as grey text on a black square.

    Args:
        metrics_dict: dict with ``'label_array'`` and ``'prediction_array'``
            keys, passed to the project ``metrics`` module.
        size: side length of the square panel, in pixels.

    Returns:
        A ``(size, size, 3)`` ``uint8`` array.
    """
    from PIL import Image, ImageDraw, ImageFont

    confusion = metrics.get_confusion_matrix_elements_from_arrays(
        metrics_dict['label_array'], metrics_dict['prediction_array'])
    lines = [
        f'Dice: {metrics.dice(confusion):.3f}',
        f'Precision: {metrics.precision(confusion):.3f}',
        f'Recall: {metrics.recall(confusion):.3f}',
    ]
    text = '\n'.join(lines)

    image = Image.new('RGB', (size, size), (0, 0, 0))
    draw = ImageDraw.Draw(image)

    font_path, fontsize = _get_font_path_and_size()
    if font_path is None or not font_path.is_file():
        # No usable TrueType font (unknown OS or font not installed): fall
        # back to Pillow's built-in font instead of crashing. It cannot be
        # resized, so there is no fitting loop.
        font = ImageFont.load_default()
        text_width, text_height = _get_text_size(draw, text, font)
    else:
        font = ImageFont.truetype(str(font_path), fontsize)
        text_width, text_height = _get_text_size(draw, text, font)
        # Shrink the font until the text fits horizontally; stop at a small
        # positive size so truetype() is never asked for size <= 0.
        while text_width >= size and fontsize > 2:
            fontsize -= 2
            font = ImageFont.truetype(str(font_path), fontsize)
            text_width, text_height = _get_text_size(draw, text, font)

    # Centre the text block in the panel.
    xy = size // 2 - text_width // 2, size // 2 - text_height // 2
    draw.multiline_text(
        xy, text, fill=(200, 200, 200), font=font, align='right')
    return np.array(image)


def get_mip_collage(images_dict, metrics_dict=None):
    """Build an RGB maximum-intensity-projection (MIP) collage.

    Each entry of ``images_dict`` maps an RGB channel index (0, 1 or 2) to a
    3D array, or to ``None`` to leave that channel empty. The channels are
    stacked into one RGB volume, which is max-projected along each axis.
    The resulting layout is::

        +---------+------------------+
        | axial   | metrics or black |
        +---------+------------------+
        | coronal | sagittal         |
        +---------+------------------+

    Example::

        images_dict = {
            0: None,
            1: image1,
            2: image2,
        }

    All images are supposed to be on the same (MNI) space, i.e. to share one
    shape. Values are stored as ``uint8`` and multiplied by 255, so the
    arrays are expected to be binary (bool or 0/1).

    Args:
        images_dict: mapping from channel index to 3D array or ``None``.
        metrics_dict: optional dict with ``'label_array'`` and
            ``'prediction_array'``. If given, Dice, precision and recall are
            written in the upper-right panel; otherwise it is black.

    Returns:
        A ``PIL.Image.Image`` with the collage.

    Raises:
        ValueError: if every value in ``images_dict`` is ``None``.
    """
    from PIL import Image

    arrays = [array for array in images_dict.values() if array is not None]
    if not arrays:
        raise ValueError('images_dict must contain at least one array')
    si, sj, sk = arrays[0].shape

    array_rgb = np.zeros((si, sj, sk, 3), np.uint8)
    for channel, array in images_dict.items():
        if array is not None:
            array_rgb[..., channel] = array
    array_rgb *= 255

    sagittal, coronal, axial = [array_rgb.max(axis=i) for i in range(3)]
    sagittal = np.fliplr(np.rot90(sagittal))
    coronal = np.rot90(coronal)
    axial = np.rot90(axial)
    column_left = np.vstack((axial, coronal))

    if metrics_dict is None:
        upper_right = np.zeros((sj, sj, 3), np.uint8)
    else:
        upper_right = _render_metrics_panel(metrics_dict, sj)
    column_right = np.vstack((upper_right, sagittal))

    collage = np.hstack((column_left, column_right))
    return Image.fromarray(collage)


def _get_image_data(image):
    """Return the voxel array of a NiBabel-style image.

    ``get_data()`` was deprecated in NiBabel 3 and removed in NiBabel 5, so
    ``get_fdata()`` is used when the image provides it.
    """
    if hasattr(image, 'get_fdata'):
        return image.get_fdata()
    return image.get_data()


def create_confusion_collage(label_path, prediction_path):
    """Return a MIP collage comparing a prediction with its label.

    The prediction fills the red and blue channels and the label fills the
    green one. In the projections, pixels where both project therefore
    appear white, prediction-only pixels magenta and label-only pixels
    green. Dice, precision and recall are rendered in the upper-right
    panel.

    Suppose everything is on MNI space.

    Args:
        label_path: path to the label image; voxels ``> 0`` are foreground.
        prediction_path: path to the prediction image; singleton dimensions
            are squeezed out and voxels ``> 0`` are foreground.

    Returns:
        A ``PIL.Image.Image`` (see :func:`get_mip_collage`).
    """
    label_array = _get_image_data(load(label_path)) > 0
    prediction_array = _get_image_data(load(prediction_path)).squeeze() > 0
    return get_mip_collage(
        {
            0: prediction_array,
            1: label_array,
            2: prediction_array,
        },
        metrics_dict={
            'label_array': label_array,
            'prediction_array': prediction_array,
        },
    )


def _default_title(models_dirs):
    """Return a report title derived from the first model directory.

    Uses the path relative to its fourth ancestor when the path is deep
    enough, else the whole path; ``''`` when there are no models.
    """
    if not models_dirs:
        return ''
    model_dir = models_dirs[0]
    parents = model_dir.parents
    if len(parents) > 3:
        return str(model_dir.relative_to(parents[3]))
    return str(model_dir)


def _path_to_uri(path):
    """Return an HTML-attribute-safe absolute ``file://`` URI for *path*."""
    return html.escape(Path(path).expanduser().resolve().as_uri())


def _get_collage_path(inferred_dir, dsa_dir, subject_image, model_name,
                      force):
    """Return the confusion collage PNG for one image, computing it if needed.

    The collage is cached next to the prediction in *inferred_dir* and only
    recomputed if it is missing or *force* is true.

    Returns:
        The collage path, or ``None`` if it had to be computed but the
        prediction file does not exist.
    """
    prediction_stem = f'{subject_image}_niftynet_out_seg'
    collage_path = inferred_dir / f'{prediction_stem}.png'
    if collage_path.is_file() and not force:
        return collage_path
    print(f'Computing images for {model_name} - {subject_image}...')
    label_name = f'{subject_image}_boneless_thresholded_on_mni.nii.gz'
    label_path = dsa_dir / label_name
    prediction_path = inferred_dir / f'{prediction_stem}.nii.gz'
    if not prediction_path.is_file():
        print(f'{prediction_path} does not exist')
        return None
    collage = create_confusion_collage(label_path, prediction_path)
    collage.save(collage_path)
    return collage_path


def make_html(
        models_dirs,
        output_html_path,
        title=None,
        force=False,
        figure_path=None,
        subject='rade',
        image_types=DEFAULT_IMAGE_TYPES,
        subjects_dir='~/mres_project/subjects',
        ):
    """Write an HTML report comparing several trained models.

    For each model directory the report shows its name, a link to its
    ``config/config.ini`` and one confusion MIP collage per image type.
    Collages are cached as PNGs in ``<model_dir>/inferred``. The training
    and validation losses of all models, read from
    ``<model_dir>/model_k_9/training_niftynet_log``, are plotted together,
    saved to *figure_path* and embedded in the report. Files are linked
    with absolute ``file://`` URIs, so the report is meant to be opened
    locally.

    Args:
        models_dirs: iterable of model directories (str or Path).
        output_html_path: where to write the HTML file.
        title: report title; derived from the first model directory if None.
        force: recompute collages even if cached PNGs exist.
        figure_path: where to save the loss plot; defaults to
            ``<output_html_path stem>.png`` next to the HTML file.
        subject: subject whose images are shown.
        image_types: image types to show, one column each.
        subjects_dir: directory containing one folder per subject.

    Returns:
        The HTML text that was written.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set()

    # Convert up front so str arguments work everywhere below.
    models_dirs = [Path(model_dir) for model_dir in models_dirs]
    output_html_path = Path(output_html_path)
    if figure_path is None:
        figure_path = output_html_path.parent / f'{output_html_path.stem}.png'
    figure_path = Path(figure_path)
    if title is None:
        title = _default_title(models_dirs)

    dsa_dir = Path(subjects_dir).expanduser() / subject / 'scans' / 'dsa'
    n_columns = max(len(image_types), 1)

    custom = []
    a = custom.append
    fig, ax = plt.subplots()
    for model_dir in models_dirs:
        model_name = model_dir.name
        inferred_dir = model_dir / 'inferred'

        # Plot loss
        log_path = model_dir / 'model_k_9' / 'training_niftynet_log'
        if log_path.is_file():
            loss_dict = read_training_loss(log_path)
            plot_loss(ax, loss_dict['training'], f'Training - {model_name}')
            plot_loss(
                ax, loss_dict['validation'], f'Validation - {model_name}')
        else:
            print(f'{log_path} does not exist')

        # Add model name
        a('<tr>')
        a(f'<th colspan="{n_columns}">'
          f'<h2>{html.escape(model_name)}</h2></th>')
        a('</tr>')

        # Add link to config
        config_path = model_dir / 'config' / 'config.ini'
        a('<tr>')
        a(f'<td colspan="{n_columns}">'
          f'<a href="{_path_to_uri(config_path)}">Open config</a></td>')
        a('</tr>')

        # Add titles
        a('<tr>')
        for image_type in image_types:
            header = image_type.replace('_', ' ').upper()
            a(f'<th>{html.escape(header)}</th>')
        a('</tr>')  # row

        # Add images
        a('<tr>')
        for image_type in image_types:
            subject_image = f'{subject}_{image_type}'
            collage_path = _get_collage_path(
                inferred_dir, dsa_dir, subject_image, model_name, force)
            if collage_path is None:
                # Keep an (empty) cell so the columns stay aligned.
                a('<td>N/A</td>')
                continue
            a('<td>')
            a(f'<img src="{_path_to_uri(collage_path)}" '
              f'alt="{html.escape(subject_image)}" width="100%">')
            a('</td>')
        a('</tr>')  # row

    ax.legend()
    fig.savefig(figure_path)
    plt.close(fig)  # avoid accumulating open figures across calls

    rows = '\n'.join(custom)
    escaped_title = html.escape(str(title))
    text = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{escaped_title}</title>
</head>
<body>
<h1>{escaped_title}</h1>
<img src="{_path_to_uri(figure_path)}" alt="Loss curves">
<table>
{rows}
</table>
</body>
</html>
"""
    output_html_path.write_text(text, encoding='utf-8')
    return text


def read_loss_from_text(text, word):
    """Parse ``'<word> iter <N>, loss=<value>'`` entries from a log text.

    Args:
        text: contents of a NiftyNet training log.
        word: entry prefix, e.g. ``'training'`` or ``'validation'``.

    Returns:
        A ``pandas.Series`` of ``float32`` losses indexed by iteration
        (``int64``), or ``None`` if no entry matches.
    """
    import pandas as pd
    # Accept any non-negative/signed decimal, optionally with an exponent.
    # The previous pattern only matched losses of the form ``d.ddd``.
    pattern = (
        re.escape(word)
        + r' iter (\d+), loss=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )
    matches = re.findall(pattern, text)
    if not matches:
        return None
    iterations, losses = np.array(matches).T
    # int64, not uint16: iteration counts above 65535 would silently wrap.
    iterations = iterations.astype(np.int64)
    losses = losses.astype(np.float32)
    return pd.Series(index=iterations, data=losses)


def read_training_loss(training_log_path):
    """Read training and validation loss series from a NiftyNet log file.

    Returns:
        Dict with ``'training'`` and ``'validation'`` keys; each value is a
        ``pandas.Series`` or ``None`` (see :func:`read_loss_from_text`).
    """
    text = Path(training_log_path).read_text()
    training_series = read_loss_from_text(text, 'training')
    validation_series = read_loss_from_text(text, 'validation')
    loss_dict = {
        'training': training_series,
        # might be None depending of NiftyNet config
        'validation': validation_series,
    }
    return loss_dict


def plot_loss(axis, series, label, window=100):
    """Plot a loss series faintly, plus its moving average.

    The raw curve is drawn thin and translucent and excluded from the
    legend. The ``window``-point rolling mean is drawn in the same colour
    with *label*. Does nothing if *series* is ``None``.
    """
    if series is None:
        return
    # https://stackoverflow.com/a/35710894/3956024
    series.plot(
        ax=axis,
        linewidth=0.25,
        alpha=0.25,
        grid=True,
        zorder=-1,
        label='_nolegend_',
    )
    moving_avg = series.rolling(window).mean()
    # Reuse the colour matplotlib just assigned to the raw curve.
    color = axis.get_lines()[-1].get_color()
    moving_avg.plot(ax=axis, color=color, label=label)