Commit 5718316b authored by Fernando Perez-Garcia's avatar Fernando Perez-Garcia

Add jobs progress bar

parent 2ceaf73e
from os.path import join
class Model(object):
def __init__(self, model_dir):
self.dir = model_dir
self.config_dir = join(self.dir, 'config')
self.inferred_dir = join(self.dir, 'inferred')
self.inferred_csv_path = join(self.inferred_dir, 'inferred.csv')
model_dir = '/mnt/comic/cluster/project0/vesseg/mres_project/learning/models/highres3dnet_223/scaling/highres3dnet_223_lr_1e-4_scaling_75'
......@@ -3,6 +3,7 @@
import sys
print('Importing libraries...')
from pathlib import Path
from tqdm import tqdm
from vesseg import Model, Job
from vesseg.network.model import DSA, T1, T1_GAD
......@@ -20,6 +21,8 @@ networks = 233, 223, 222
batch_sizes = 1, 2, 3, 4
learning_rates = 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5
jobs = []
for application in applications:
for net in networks:
for batch_size in batch_sizes:
......@@ -51,9 +54,16 @@ for application in applications:
job = Job(model_dir, model.config_path)
job.application = application_name
print(f'Creating {job.path}...')
job.write()
if len(sys.argv) > 1 and sys.argv[1] == '--submit':
print(f'Submitting {model_name}...')
job.submit()
jobs.append(job)
progress_bar = tqdm(jobs)
for job in progress_bar:
progress_bar.set_description(f'Creating {job.path}...')
job.write()
progress_bar = tqdm(jobs)
if len(sys.argv) > 1 and sys.argv[1] == '--submit':
for job in progress_bar:
# progress_bar.set_description(f'Submitting {model_name}...')
job.submit()
from pathlib import Path
import matplotlib.pyplot as plt
from vesseg.network.evaluation import metrics
models_dir = Path('/mnt/comic/cluster/project0/vesseg/mres_project/learning/models/highres3dnet_223/scaling')
inferred_ids = (
# 'rade_dsa_ica_left',
# 'rade_dsa_ica_right',
'rade_dsa_vert',
)
models = sorted(list(models_dir.iterdir()))
dice = {}
for i, model_dir in enumerate(models):
model_name = model_dir.name
print(model_name)
for image_id in inferred_ids:
label_path = f'/home/fernando/mres_project/subjects/rade/scans/dsa/{image_id}_boneless_thresholded_on_mni.nii.gz'
prediction_path = model_dir / 'inferred' / f'{image_id}_niftynet_out_seg.nii.gz'
if not prediction_path.is_file():
print(prediction_path.name, 'does not exist')
continue
confusion = metrics.get_confusion_matrix_elements(label_path, prediction_path)
dice[model_name] = metrics.dice(confusion)
fig, ax = plt.subplots()
for i, (model_name, dice_score) in enumerate(dice.items()):
ax.scatter(i, dice_score, label=model_name)
ax.grid()
plt.legend()
plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment