Commit c7c8e0e0 authored by Fernando Perez-Garcia's avatar Fernando Perez-Garcia

Add loss plots to HTMLs

parent cc8bf34f
import re
from pathlib import Path
import numpy as np
from . import metrics
from ...nifti import load
def get_mip_collage(images_dict, metrics_dict=None, fontsize=48):
def get_mip_collage(images_dict, metrics_dict=None, fontsize=72):
"""
Example:
images_dict = {
......@@ -51,11 +52,16 @@ def get_mip_collage(images_dict, metrics_dict=None, fontsize=48):
]
text = '\n'.join(lines)
font_path = '/usr/share/fonts/truetype/ubuntu/Ubuntu-R.ttf'
font = ImageFont.truetype(font_path, fontsize)
image = Image.new('RGB', (sj, sj), (0, 0, 0))
draw = ImageDraw.Draw(image)
image_size_x, image_size_y = image.size
font = ImageFont.truetype(font_path, fontsize)
text_size_x, text_size_y = draw.multiline_textsize(text, font=font)
while text_size_x >= image_size_x:
fontsize -= 2
font = ImageFont.truetype(font_path, fontsize)
text_size_x, text_size_y = draw.multiline_textsize(text, font=font)
start_x = image_size_x // 2 - text_size_x // 2
start_y = image_size_y // 2 - text_size_y // 2
xy = start_x, start_y
......@@ -116,13 +122,24 @@ def create_confusion_collage(label_path, prediction_path):
return collage_both
def make_html(models_dirs, output_html_path, title=None, force=False):
def make_html(models_dirs,
output_html_path,
title=None,
force=False,
figure_path=None):
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
custom = []
a = custom.append
fig, ax = plt.subplots()
if figure_path is None:
figure_path = output_html_path.parent / f'{output_html_path.stem}.png'
for model_dir in models_dirs:
if title is None:
title = model_dir.parent.name
title = model_dir.relative_to(model_dir.parents[3])
model_dir = Path(model_dir)
model_name = model_dir.name
......@@ -137,14 +154,23 @@ def make_html(models_dirs, output_html_path, title=None, force=False):
'dsa_vert',
)
subject_dir = Path('~/mres_project/subjects').expanduser() / subject
dsa_dir = subject_dir / 'scans' / 'dsa'
# Plot loss
log_path = model_dir / 'model_k_9' / 'training_niftynet_log'
loss_dict = read_training_loss(log_path)
plot_loss(ax, loss_dict['training'], f'Training - {model_name}')
plot_loss(ax, loss_dict['validation'], f'Validation - {model_name}')
# Add model name
a('<div class="row">')
a(f'<h2>{model_name}</h2>')
a('</div>')
# Add link to config
config_path = model_dir / 'config' / 'config.ini'
a('<div class="row">')
a(f'<a href="{config_path}">Open config</a>')
a('</div>')
# Add titles
a('<div class="row">')
for image_type in types:
......@@ -154,6 +180,8 @@ def make_html(models_dirs, output_html_path, title=None, force=False):
# Add images
a('<div class="row">')
subject_dir = Path('~/mres_project/subjects').expanduser() / subject
dsa_dir = subject_dir / 'scans' / 'dsa'
for image_type in types:
subject_image = f'{subject}_{image_type}'
prediction_stem = f'{subject_image}_niftynet_out_seg'
......@@ -175,6 +203,9 @@ def make_html(models_dirs, output_html_path, title=None, force=False):
a('</div>')
a('</div>') # row
plt.legend()
fig.savefig(figure_path)#, dpi=400)
lines = '\n'.join(custom)
text = f"""
......@@ -197,6 +228,9 @@ def make_html(models_dirs, output_html_path, title=None, force=False):
<h1 id="title">{title}</h1>
</div>
</div>
<div class="container-fluid">
<img src="{figure_path}" class="img-fluid">
</div>
<div class="container-fluid">
<!-- Images here -->
{lines}
......@@ -206,3 +240,46 @@ def make_html(models_dirs, output_html_path, title=None, force=False):
"""
Path(output_html_path).write_text(text)
return text
def read_loss_from_text(text, word):
import pandas as pd
pattern = word + r' iter (\d+), loss=(\d\.\d+)'
matches = re.findall(pattern, text)
if not matches:
return None
iterations, losses = np.array(matches).T
iterations = iterations.astype(np.uint16)
losses = losses.astype(np.float32)
series = pd.Series(index=iterations, data=losses)
return series
def read_training_loss(training_log_path):
text = Path(training_log_path).read_text()
training_series = read_loss_from_text(
text, 'training')
validation_series = read_loss_from_text(
text, 'validation')
loss_dict = {
'training': training_series,
'validation': validation_series, # might be None depending of NiftyNet config
}
return loss_dict
def plot_loss(axis, series, label, window=100):
if series is None:
return
# https://stackoverflow.com/a/35710894/3956024
series.plot(
ax=axis,
linewidth=0.25,
alpha=0.25,
grid=True,
zorder=-1,
label='_nolegend_',
)
moving_avg = series.rolling(window).mean()
color = axis.get_lines()[-1].get_color()
moving_avg.plot(ax=axis, color=color, label=label)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment