Commit d499658d authored by Fernando Perez-Garcia's avatar Fernando Perez-Garcia

Add script to train many networks

parent 5b09c379
#!/usr/bin/env python3
import sys
print('Importing libraries...')
from pathlib import Path
from vesseg import Model, Job
from vesseg.network.model import DSA, T1, T1_GAD
learning_dir = Path('~/mres_project/learning').expanduser()
combinations = (
(DSA,),
(DSA, T1),
(DSA, T1_GAD),
(DSA, T1, T1_GAD),
)
learning_rates = 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5
networks = 222, 223, 233
applications = 'Vessels', 'Vessels_Reuben'
for application in applications:
for net in networks:
for inputs in combinations:
for lr in learning_rates:
models_dir = learning_dir / application / f'highres3dnet_{net}'
network = f'vesseg_networks.highres3dnet_{net}.HighRes3DNet{net}'
application = (
'vesseg_applications'
f'.segmentation_application_{application.lower()}'
f'.SegmentationApplication{application.replace("_", "")}'
)
string = '-'.join(inputs)
model_name = f'highres3dnet_{net}_{string}_lr_{lr}'
print(f'Creating {model_name}...')
model_dir = models_dir / model_name
model = Model(model_dir=model_dir, inputs=inputs)
model.set_images_and_labels_paths()
model.make_csv_files(split_type='subject')
model.config_all()
model.config_training(learning_rate=lr)
model.config_network(network_name=network)
model.write_config_file()
job = Job(model_dir, model.config_path)
job.application = application
job.write()
if len(sys.argv) > 1 and sys.argv[1] == '--submit':
print(f'Submitting {model_name}...')
job.submit()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment