Commit 9c797908 authored by Paul McInnis's avatar Paul McInnis

transfer learning script to produce COVIDNet-Risk from COVIDNet

parent 29d6267c
......@@ -16,7 +16,7 @@ The COVID-19 pandemic continues to have a devastating effect on the health and w
For a detailed description of the methodology behind COVID-Net and a full description of the COVIDx dataset, please click [here](
Currently, the COVID-Net team is working on COVID-RiskNet, a deep neural network tailored for COVID-19 risk stratification. Stay tuned as we make it available soon.
Currently, the COVID-Net team is working on **COVID-RiskNet**, a deep neural network tailored for COVID-19 risk stratification. Currently this is available as a work-in-progress via included `` script, help to contribute data and we can improve this tool.
If you would like to **contribute COVID-19 x-ray images**, please submit to Lets all work together to stop the spread of COVID-19!
......@@ -136,7 +136,17 @@ TF training script from a pretrained model:
3. To inference, `python --weightspath models/COVID-Netv2 --metaname model.meta_eval --ckptname model-2069 --imagepath assets/ex-covid.jpeg`
4. For more options and information, `python --help`
### Steps for Training COVIDNet-Risk
COVIDNet-Risk uses the same architecture as the existing COVIDNet - but instead it predicts the *"number of days since symptom onset"\** for a diagnosed COVID-19 patient based on their chest radiography (same data as COVIDNet). By performing offset stratification, we aim to provide an estimate of prognosis for the patient. Note that the initial dataset is fairly small at the time of writing and we hope to see more results as data increases.
1. Complete data creation and training for COVIDNet (see Training above)
2. run `` (see `-h` for argument help)
*\* note that definition varies between data sources*
## Results
These are the final results for COVID-Net Small and COVID-Net Large.
### COVIDNet Small
"""Perform transfer-learning for offset stratification with a provided COVID-Net
From the trained weights of a COVID-Net for COVID-19 identification in radiographs,
this tool performs transfer learning to re-use these weights for stratification of
patient offset (# of days since symptoms began *)
Steps to use this:
1. follow instructions for building data dir with train & test subdirs
2. train your network with the script
3. run this script and pass the path to your trained network (defaults should suffice)
(*) FIXME: It seems that the definition of offset varies between data sources! (account for this)
TODO: Make this script more general so that it can be used to transfer learn for other applications
import argparse
from collections import namedtuple
import cv2
import os
from typing import List, Tuple, Dict, Any
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import tensorflow as tf
from data import BalanceDataGenerator
# We will create a checkpoint which has initial values for these variables
IMAGE_SHAPE = (224, 224, 3)
INPUT_TENSOR_NAME = "input_1:0"
OUTPUT_TENSOR_NAME = "dense_3/Softmax:0"
SAMPLE_WEIGHTS = "dense_3_sample_weights:0"
def get_parse_fn(num_classes: int, augment: bool = False):
def parse_function(imagepath: str, label: int):
"""Parse a single element of the stratification dataset"""
# TODO add augmentation here ideally
image_decoded = tf.image.resize_images(
tf.image.decode_jpeg(, IMAGE_SHAPE[-1]), IMAGE_SHAPE[:2])
return (
tf.image.convert_image_dtype(image_decoded, dtype=tf.float32) / 255.0, # x
tf.one_hot(label, num_classes), # y
tf.convert_to_tensor(1.0, dtype=tf.float32), # sample_weights TODO: verify this is right
return parse_function
def parse_split(split_txt_path: str) -> Tuple[List[str], List[int]]:
"""Read the offsets for COVID patients based on the files in our split"""
# FIXME: ideally we should just store the offset in the split as well or read it from CSV by id.
# FIXME: we need to add pretrained weights + .txts for split with well-distributed offset.
files, labels = [], [],
for split_entry in open(split_txt_path).readlines():
_, image_file, diagnosis = split_entry.strip().split() # TODO: txts should just contain ids
if diagnosis == 'COVID-19':
patient = csv[csv["filename"] == image_file]
recorded_offset = patient['offset'].item()
if not np.isnan(recorded_offset):
offset = stratify(int(recorded_offset))
image_path = os.path.abspath(
os.path.join(args.chestxraydir, 'images', image_file))
assert os.path.exists(image_path), "Missing file {}".format(image_path)
return files, labels
def eval_net(sess: tf.Session, dataset_dict: Dict[str, Any], test_files: List[str],
test_labels: List[int]) -> None:
"""Evaluate the network"""
# Reset eval iterator['iterator'].initializer)
# Eval
preds, all_labels = [], []
num_evaled = 0
while True:
images, labels, sample_weights =['gn_op'])
pred =
feed_dict={INPUT_TENSOR_NAME: images, SAMPLE_WEIGHTS: sample_weights}
num_evaled += len(pred)
except tf.errors.OutOfRangeError:
print("\tevaluated {} images.".format(num_evaled))
matrix = confusion_matrix(all_labels, np.concatenate(preds)).astype('float')
per_class_acc = [
matrix[i,i]/np.sum(matrix[i,:]) if np.sum(matrix[i,:]) else 0 for i in range(len(matrix))
print("confusion matrix:\n{}\nper-class accuracies:\n{}".format(matrix, per_class_acc))
if __name__ == "__main__":
# Input args NOTE: the params here differ from thise in - we are fine-tuning
parser = argparse.ArgumentParser(description='COVIDNet-Risk Transfer Learning Script (offset).')
parser.add_argument('--classes', default=4, type=int,
help='Number of classes to stratify offset into.')
parser.add_argument('--stratification', type=int, nargs='+', default=[3, 5, 10],
help='Stratification points (days), i.e. "5 10" produces stratification of'
': 0o <-0c-> 5o <-1c-> 10o -2c-> via >= comparison (o=offset, c=class).')
parser.add_argument('--epochs', default=10, type=int,
help='Number of epochs (less since we\'re effectively fine-tuning).')
parser.add_argument('--lr', default=0.000002, type=float, help='Learning rate.')
parser.add_argument('--batch-size', default=8, type=int, help='Train batch-size')
parser.add_argument('--eval-batch-size', default=8, type=int, help='Eval batch-size')
parser.add_argument('--evaliterval', default=3, type=int,
help='# of epochs to train before running evaluation. NOTE: we only save'
'after evaluation. This can be disabled when more test data is available')
parser.add_argument('--input-weights-dir', default='models/COVIDNetv2', type=str,
help='Path to input folder containing a trained COVID-Netv2 checkpoint')
parser.add_argument('--input-meta-name', default='model.meta', type=str,
help='Name of meta file within <input-weights-dir>')
parser.add_argument('--outputdir', default='models/COVIDNet-Risk', type=str,
help='Path to output folder.')
parser.add_argument('--trainfile', default='train_COVIDx.txt', type=str,
help='Name of train file. NOTE: stock split is insufficient at this time.')
parser.add_argument('--testfile', default='test_COVIDx.txt', type=str,
help='Name of test file. NOTE: stock split is insufficient at this time.')
parser.add_argument('--name', default='COVIDNet-Risk', type=str,
help='Name of folder to store training checkpoints.')
parser.add_argument('--chestxraydir', default='../covid-chestxray-dataset', type=str,
help='Path to the chestxray images directory for COVID-19 patients.')
args = parser.parse_args()
# Check inputs
assert os.path.exists(args.input_weights_dir), "Missing file {}".format(args.input_weights_dir)
assert os.path.exists(os.path.join(args.input_weights_dir, args.input_meta_name)), \
"Missing file {}".format(args.input_meta_name)
# Format and define a stratification method based on our points
# TODO we could do a different amount of stratification but we have to add our own dense layers
assert len(args.stratification) == 3, "Must pass exactly 3 offset stratification points"
if args.stratification[0] != 0:
stratification = np.array([0, *args.stratification])
stratification = np.array(args.stratification)
num_classes = len(stratification)
stratify = lambda offset: np.where(offset >= stratification)[0][-1]
# Read CSV of dataset
assert os.path.exists(args.chestxraydir), "please clone "\
" and pass path to dir as --chestxraydir"
csv = pd.read_csv(os.path.join(args.chestxraydir, "metadata.csv"), nrows=None)
# Get the image filepaths and labels for training and testing split
train_files, train_labels = parse_split(args.trainfile)
assert len(train_files) >= 0 and len(train_files) == len(train_labels)
test_files, test_labels = parse_split(args.testfile)
assert len(test_files) >= 0 and len(test_labels) == len(test_files)
print("collected {} training and {} test cases for transfer-learning".format(
len(train_files), len(test_files)))
# Init augmentation fn - FIXME: we need a way to put this in a parse_fn for
# augmentation_fn = tf.keras.preprocessing.image.ImageDataGenerator(
# featurewise_center=False,
# featurewise_std_normalization=False,
# rotation_range=10,
# width_shift_range=0.1,
# height_shift_range=0.1,
# horizontal_flip=True,
# brightness_range=(0.9, 1.1),
# fill_mode='constant',
# cval=0.,
# )
# < define generator from augmentation_fn + cv loads? >
# dataset = generator,
# output_types=(tf.float32, tf.float32, tf.float32),
# output_shapes=([batch_size, 224, 224, 3],
# [batch_size, 3],
# [batch_size]))
# Output path creation for this run with lr param in name
train_dir = os.path.join(args.outputdir, + '-lr' + str(
os.makedirs(args.outputdir, exist_ok=True)
print('Output: ' + train_dir)
# Train
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
# Import meta graph
tf.train.import_meta_graph(os.path.join(args.input_weights_dir, args.input_meta_name))
# Restore pre-trained vars which are not in our VARS_TO_FORGET list
restore_vars_list, init_vars_list = [], []
for var in graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
restore_saver = tf.train.Saver(var_list=restore_vars_list)
restore_saver.restore(sess, tf.train.latest_checkpoint(args.input_weights_dir))
existing_vars = sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
# Get some I/O tensors
image_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
labels_tensor = graph.get_tensor_by_name("dense_3_target:0")
sample_weights = graph.get_tensor_by_name(SAMPLE_WEIGHTS)
pred_tensor = graph.get_tensor_by_name("dense_3/MatMul:0")
# Define tf.datasets
datasets = {}
for is_training, files, labels in zip(
[True, False], [train_files, test_files], [train_labels, test_labels]):
dataset =, labels))
dataset =
if is_training:
dataset = dataset.shuffle(15)
dataset = dataset.batch(args.batch_size if is_training else args.eval_batch_size)
if is_training:
dataset = dataset.repeat()
iterator = dataset.make_initializable_iterator()
datasets['train' if is_training else 'test'] = {
'dataset': dataset,
'iterator': iterator,
'gn_op': iterator.get_next(),
# Define loss and optimizer
loss_op = tf.reduce_mean(
logits=pred_tensor, labels=labels_tensor) * sample_weights
optimizer = tf.train.AdamOptimizer(
train_op = optimizer.minimize(loss_op)
optim_vars = list(
set(sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) - set(existing_vars))
# Initialize the optimizer + dsi + vars in our VARS_TO_FORGET list + init_vars_list))
# save base model
saver = tf.train.Saver(), os.path.join(train_dir, 'model'))
print('Saved pre-trained model with re-initialized output layers.')
print('Baseline eval:')
eval_net(sess, datasets['test'], test_files, test_labels)
# Training cycle
# TODO: we need a training method that we can re-use. below very similar to
# FIXME: we need to consider freezing vars for all but dense layers.
print('Transfer Learning Started.')
print('\ttrain samples: {}\n\ttest samples: {}\n\tstratification: {}\n'.format(
len(train_files), len(test_files), args.stratification))['train']['iterator'].initializer)
num_batches = len(train_files) // args.batch_size
progbar = tf.keras.utils.Progbar(num_batches)
for epoch in range(args.epochs):
# Train
print("Fine-Tuning on 1 epoch = {} images.".format(len(train_files)))
for i in range(num_batches):
batch_x, batch_y, weights =['train']['gn_op'])
image_tensor: batch_x,
labels_tensor: batch_y,
sample_weights: weights,
progbar.update(i + 1)
# Evaluate + save
if epoch % args.evaliterval == 0:
pred =, feed_dict={image_tensor:batch_x})
loss =
pred_tensor: pred,
labels_tensor: batch_y,
sample_weights: weights,
print("Epoch:", '%04d' % (epoch + 1), "Minibatch loss=", "{:.9f}".format(loss))
eval_net(sess, datasets['test'], test_files, test_labels)
os.path.join(train_dir, 'model'),
global_step=epoch + 1,
print('Saving checkpoint at epoch {}'.format(epoch + 1))
print("Transfer Learning Finished!\n\tcheckpoint: '{}'".format(train_dir))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment