Commit 461c7a88 authored by lindawangg

updated training script

parent e4a62c6f
......@@ -20,3 +20,4 @@ test_dups.py
test_COVIDx.txt
train_COVIDx.txt
create_COVIDx_v2.ipynb
data_keras.py
......@@ -6,7 +6,141 @@ import os
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils import shuffle
def _process_csv_file(file):
    with open(file, 'r') as fr:
        files = fr.readlines()
    return files
class BalanceCovidDataset(keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(
            self,
            data_dir,
            csv_file,
            is_training=True,
            batch_size=8,
            input_shape=(224, 224),
            n_classes=3,
            num_channels=3,
            mapping={
                'normal': 0,
                'pneumonia': 1,
                'COVID-19': 2
            },
            shuffle=True,
            augmentation=True,
            covid_percent=0.3,
            class_weights=[1., 1., 6.]
    ):
        'Initialization'
        self.datadir = data_dir
        self.dataset = _process_csv_file(csv_file)
        self.is_training = is_training
        self.batch_size = batch_size
        self.N = len(self.dataset)
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.num_channels = num_channels
        self.mapping = mapping
        self.shuffle = shuffle
        self.covid_percent = covid_percent
        self.class_weights = class_weights
        self.n = 0

        if augmentation:
            self.augmentation = ImageDataGenerator(
                featurewise_center=False,
                featurewise_std_normalization=False,
                rotation_range=10,
                width_shift_range=0.1,
                height_shift_range=0.1,
                horizontal_flip=True,
                brightness_range=(0.9, 1.1),
                zoom_range=(0.85, 1.15),
                fill_mode='constant',
                cval=0.,
            )

        # Split the file list into a non-COVID split (normal + pneumonia) and a COVID-19 split
        datasets = {'normal': [], 'pneumonia': [], 'COVID-19': []}
        for l in self.dataset:
            datasets[l.split()[-1]].append(l)
        self.datasets = [
            datasets['normal'] + datasets['pneumonia'],
            datasets['COVID-19'],
        ]
        print(len(self.datasets[0]), len(self.datasets[1]))

        self.on_epoch_end()
    def __next__(self):
        # Get one batch of data
        batch_x, batch_y, weights = self.__getitem__(self.n)
        # Batch index
        self.n += 1

        # If we have processed the entire dataset, reshuffle and start a new epoch
        if self.n >= self.__len__():
            self.on_epoch_end()
            self.n = 0

        return batch_x, batch_y, weights
    def __len__(self):
        return int(np.ceil(len(self.datasets[0]) / float(self.batch_size)))
    def on_epoch_end(self):
        'Shuffles both data splits after each epoch'
        if self.shuffle:
            for v in self.datasets:
                np.random.shuffle(v)
    def __getitem__(self, idx):
        batch_x, batch_y = np.zeros(
            (self.batch_size, *self.input_shape,
             self.num_channels)), np.zeros(self.batch_size)

        batch_files = self.datasets[0][idx * self.batch_size:(idx + 1) *
                                       self.batch_size]

        # upsample covid cases: overwrite a fraction of the batch with COVID-19 samples
        covid_size = max(int(len(batch_files) * self.covid_percent), 1)
        covid_inds = np.random.choice(np.arange(len(batch_files)),
                                      size=covid_size,
                                      replace=False)
        covid_files = np.random.choice(self.datasets[1],
                                       size=covid_size,
                                       replace=False)
        for i in range(covid_size):
            batch_files[covid_inds[i]] = covid_files[i]

        for i in range(len(batch_files)):
            sample = batch_files[i].split()

            if self.is_training:
                folder = 'train'
            else:
                folder = 'test'

            x = cv2.imread(os.path.join(self.datadir, folder, sample[1]))
            # crop away the top sixth of the image before resizing
            h, w, c = x.shape
            x = x[int(h / 6):, :]
            x = cv2.resize(x, self.input_shape)

            if self.is_training and hasattr(self, 'augmentation'):
                x = self.augmentation.random_transform(x)

            x = x.astype('float32') / 255.0
            y = self.mapping[sample[2]]

            batch_x[i] = x
            batch_y[i] = y

        class_weights = self.class_weights
        weights = np.take(class_weights, batch_y.astype('int64'))

        return batch_x, keras.utils.to_categorical(batch_y, num_classes=self.n_classes), weights
class BalanceDataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
......
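As a quick sanity check of the new loader (not part of this commit), a minimal sketch of exercising BalanceCovidDataset on its own; the data_dir and csv_file values simply reuse defaults visible elsewhere in this diff, and one batch is pulled to confirm the shapes and per-sample weights.
# Hedged usage sketch, assuming the COVIDx images live under data/train and data/test
generator = BalanceCovidDataset(data_dir='data',
                                csv_file='train_COVIDx2.txt',
                                batch_size=8,
                                covid_percent=0.3,
                                class_weights=[1., 1., 12.])
batch_x, batch_y, weights = next(generator)
print(batch_x.shape, batch_y.shape, weights.shape)  # expected: (8, 224, 224, 3) (8, 3) (8,)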
......@@ -4,17 +4,17 @@ import tensorflow as tf
import os, argparse, pathlib
from eval import eval
from data import BalanceDataGenerator
from data import BalanceCovidDataset
parser = argparse.ArgumentParser(description='COVID-Net Training Script')
parser.add_argument('--epochs', default=10, type=int, help='Number of epochs')
parser.add_argument('--lr', default=0.00002, type=float, help='Learning rate')
parser.add_argument('--bs', default=8, type=int, help='Batch size')
parser.add_argument('--weightspath', default='models/COVIDNetv2', type=str, help='Path to output folder')
parser.add_argument('--metaname', default='model.meta_train', type=str, help='Name of ckpt meta file')
parser.add_argument('--ckptname', default='model-2069', type=str, help='Name of model ckpts')
parser.add_argument('--trainfile', default='train_COVIDx.txt', type=str, help='Name of train file')
parser.add_argument('--testfile', default='test_COVIDx.txt', type=str, help='Name of test file')
parser.add_argument('--weightspath', default='models/COVIDNet-CXR-Large', type=str, help='Path to output folder')
parser.add_argument('--metaname', default='model.meta', type=str, help='Name of ckpt meta file')
parser.add_argument('--ckptname', default='model-8485', type=str, help='Name of model ckpts')
parser.add_argument('--trainfile', default='train_COVIDx2.txt', type=str, help='Name of train file')
parser.add_argument('--testfile', default='test_COVIDx2.txt', type=str, help='Name of test file')
parser.add_argument('--name', default='COVIDNet', type=str, help='Name of folder to store training checkpoints')
parser.add_argument('--datadir', default='data', type=str, help='Path to data folder')
......@@ -37,19 +37,7 @@ with open(args.trainfile) as f:
with open(args.testfile) as f:
    testfiles = f.readlines()
generator = BalanceDataGenerator(trainfiles, datadir=args.datadir, class_weights=[1., 1., 25.])
# Create a dataset tensor from the images and the labels
'''dataset = tf.data.Dataset.from_generator(lambda: generator,
                                            output_types=(tf.float32, tf.float32, tf.float32),
                                            output_shapes=([batch_size, 224, 224, 3],
                                                           [batch_size, 3],
                                                           [batch_size]))'''
# Create an iterator over the dataset
#iterator = dataset.make_initializable_iterator()
# Neural Net Input (images, labels, weights)
#batch_x, batch_y, weights = iterator.get_next()
generator = BalanceCovidDataset(data_dir=args.datadir, csv_file=args.trainfile, covid_percent=0.3, class_weights=[1., 1., 12.])
with tf.Session() as sess:
    tf.get_default_graph()
......
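For orientation only, a hedged sketch of how the new generator could drive the restored graph inside the session; the tensor and op names (image_tensor, labels_tensor, sample_weights, train_op) are purely hypothetical stand-ins for whatever the checkpointed COVID-Net graph actually exposes and are not taken from this diff.
# Hypothetical loop; replace the placeholder tensors/ops with the real graph handles.
for epoch in range(args.epochs):
    for _ in range(len(generator)):
        batch_x, batch_y, weights = next(generator)
        sess.run(train_op, feed_dict={image_tensor: batch_x,
                                      labels_tensor: batch_y,
                                      sample_weights: weights})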