Commit 0d80d064 authored by Yipeng Hu

ref #6 initial files from ius added

parent 3ac9bb0b
import os
import random
import tensorflow as tf
from matplotlib import pyplot as plt
# Show a grid of randomly chosen frames (one row per randomly chosen subject)
# read from the per-subject HDF5 file.
flag_wsl = True  # True: use the hard-coded WSL mount of the Windows home folder
nSbj = 6  # number of subjects to show (rows of the figure)
nFrm = 8  # number of frames to show per subject (columns of the figure)

if flag_wsl:
    home_dir = '/mnt/c/Users/yhu'  # WSL
else:
    # expanduser works on every platform, so home_dir can never be left
    # undefined as it was when os.name was neither 'nt' nor 'posix'
    home_dir = os.path.expanduser('~')
filename = os.path.join(home_dir, 'Scratch/data/protocol/normalised/protocol_sweep_class_subjects.h5')


def _read_scalar(dataset):
    """Return the [0][0] element of the HDF5 dataset `dataset` as a Python int."""
    return int(tf.keras.utils.HDF5Matrix(filename, dataset)[0][0])


# pick up to nSbj random subjects (fewer when the file holds fewer)
num_subjects = _read_scalar('/num_subjects')
idx_subject = random.sample(range(num_subjects), min(nSbj, num_subjects))

plt.figure()
for iSbj, subject in enumerate(idx_subject):
    num_frames = _read_scalar('/subject%06d_num_frames' % subject)
    # random.sample raises ValueError when asked for more items than exist,
    # so a subject with fewer than nFrm frames simply fills fewer columns
    idx_frame = random.sample(range(num_frames), min(nFrm, num_frames))
    for iFrm, frame_no in enumerate(idx_frame):
        dataset = '/subject%06d_frame%08d' % (subject, frame_no)
        # NOTE(review): the transpose presumably restores the row/column order
        # used when the frame was written - confirm against the writer
        frame = tf.transpose(tf.keras.utils.HDF5Matrix(filename, dataset))
        label = _read_scalar('/subject%06d_label%08d' % (subject, frame_no))
        # the grid keeps its nominal nSbj x nFrm layout; unused cells stay empty
        axs = plt.subplot(nSbj, nFrm, iSbj * nFrm + iFrm + 1)
        axs.set_title('S{}, F{}, C{}'.format(subject, frame_no, label))
        axs.imshow(frame, cmap='gray')
        axs.axis('off')
plt.show()
% script_prepData
% Resolve the input (raw, one folder per class) and output (normalised) folders
% relative to the user's home folder.
if ispc
    homeFolder = getenv('USERPROFILE');
elseif isunix
    homeFolder = getenv('HOME');
end
normFolder = fullfile(homeFolder, 'Scratch/data/protocol/normalised');
% mkdir issues a warning when the folder already exists, so only create it
% when it is missing (keeps re-runs of the script quiet)
if ~exist(normFolder, 'dir')
    mkdir(normFolder);
end
dataFolder = fullfile(homeFolder, 'Scratch/data/protocol/SPE_data_classes');
% class sub-folder names; their order defines the class index
ClassNames = {'1_skull'; '2_abdomen'; '3_heart'; '4_other'};
ClassFolders = cellfun(@(x)fullfile(dataFolder,x),ClassNames,'UniformOutput',false);
%% go through all raw data and obtain the frame_info
% For every readable image in every class folder, record its file name, the
% case (subject) it belongs to and its class, using 0-based case/class indices.
case_ids = {};
% start from an empty struct array that already has the final fields, so that
% frame_info(k).field = ... and [frame_info.field] are well defined even when
% no frame is accepted
frame_info = struct('filename', {}, 'case_name', {}, 'case_idx', {}, ...
                    'class_name', {}, 'class_idx', {});
idx_frame_1 = 0; % 1-based running count of accepted frames
for idx_class_1 = 1:length(ClassFolders)
    frame_names = dir(ClassFolders{idx_class_1});
    % remove '.', '..' and any sub-folder explicitly instead of skipping the
    % first two entries: dir does not guarantee they are listed first
    frame_names = frame_names(~[frame_names.isdir]);
    for j = 1:length(frame_names)
        fname = frame_names(j).name;
        fpath = fullfile(frame_names(j).folder, fname);
        % get rid of the problematic files
        try
            % debug: fprintf('reading No.%d - [%s]\n',j,fpath)
            img = imread(fpath); % figure, imshow(img,[])
        catch
            disp(fpath)
            continue
        end
        % dealing with different date format here
        date_del = strfind(fname,'-');
        if length(date_del)>=6 % case that date_del is repeated
            start0 = date_del(6)+1;
        elseif length(date_del)>=3
            start0 = date_del(3)+1;
        else
            start0 = 1;
        end
        % strip the extension with fileparts: the pattern '.png' passed to
        % regexpi treated '.' as a wildcard and could return no match or
        % several matches, which broke the indexing that followed
        [~, stem] = fileparts(fname);
        newstr = strrep(stem(start0:end),'fr_','');
        % additional check here
        udls = strfind(newstr,'_');
        if length(udls) ~= 2
            warning('Incorrect filename format!, %s',fname);
            if length(udls) < 2
                % udls(2) does not exist: skip the file instead of erroring
                continue
            end
        end
        id = newstr(1:udls(2)-1); % case id = everything before the 2nd underscore
        [~, idx_case_1] = ismember(id, case_ids);
        if idx_case_1==0 % case not seen before: register it as a new one
            idx_case_1 = length(case_ids)+1;
            case_ids{idx_case_1} = id;
        end
        idx_frame_1 = idx_frame_1+1;
        frame_info(idx_frame_1).filename = fname;
        frame_info(idx_frame_1).case_name = id;
        frame_info(idx_frame_1).case_idx = idx_case_1 - 1;   % 0-based
        frame_info(idx_frame_1).class_name = ClassNames{idx_class_1};
        frame_info(idx_frame_1).class_idx = idx_class_1 - 1; % 0-based
    end
end
save(fullfile(normFolder,'frame_info'),'frame_info','dataFolder');
%% now write into files
% Crop every recorded frame and store it, with its class label, in one HDF5
% file organised by subject:
%   /frame_size, /num_classes, /num_subjects          global information
%   /subjectSSSSSS_num_frames                         frames of subject S
%   /subjectSSSSSS_frameFFFFFFFF, ..._labelFFFFFFFF   image and class index
load(fullfile(normFolder,'frame_info')); % restores frame_info and dataFolder
% crop window used below as img(rows, cols), i.e. [rowmin,rowmax,colmin,colmax]
% (it was previously annotated as [xmin,xmax,ymin,ymax])
roi_crop = [47,230,33,288];
frame_size = [roi_crop(2)-roi_crop(1)+1, roi_crop(4)-roi_crop(3)+1];
indices_class = [frame_info(:).class_idx];
num_classes = length(unique(indices_class));
indices_subject = [frame_info(:).case_idx];
num_subjects = length(unique(indices_subject));

% by subject
h5fn_subjects = fullfile(normFolder,'protocol_sweep_class_subjects.h5');
% start from a fresh file (every dataset below is created with h5create);
% only delete when it is present, which avoids the file-not-found warning
if exist(h5fn_subjects, 'file')
    delete(h5fn_subjects);
end

% write global information
GroupName = '/frame_size';
h5create(h5fn_subjects,GroupName,size(frame_size),'DataType','uint32');
h5write(h5fn_subjects,GroupName,uint32(frame_size));
GroupName = '/num_classes';
h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
h5write(h5fn_subjects,GroupName,uint32(num_classes));
GroupName = '/num_subjects';
h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
h5write(h5fn_subjects,GroupName,uint32(num_subjects));

for idx_subject = 0:num_subjects-1 % 0-based indexing
    % 1-based positions in frame_info of the frames of this subject
    indices_frame_1_subject = find(indices_subject==idx_subject);
    num_frames_subject = length(indices_frame_1_subject);
    for idx_frame_subject = 0:num_frames_subject-1 % 0-based within the subject
        idx_frame = indices_frame_1_subject(idx_frame_subject+1);
        filename = fullfile(dataFolder, frame_info(idx_frame).class_name, frame_info(idx_frame).filename);
        img = imread(filename);
        img = img(roi_crop(1):roi_crop(2),roi_crop(3):roi_crop(4));
        GroupName = sprintf('/subject%06d_frame%08d',idx_subject,idx_frame_subject);
        h5create(h5fn_subjects,GroupName,size(img),'DataType','uint8');
        h5write(h5fn_subjects,GroupName,img);
        GroupName = sprintf('/subject%06d_label%08d',idx_subject,idx_frame_subject);
        h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
        h5write(h5fn_subjects,GroupName,uint32(indices_class(idx_frame)));
    end
    GroupName = sprintf('/subject%06d_num_frames',idx_subject);
    h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
    h5write(h5fn_subjects,GroupName,uint32(num_frames_subject));
end
%% obsolete
% % by subject
% h5fn_subjects = fullfile(normFolder,'protocol_sweep_class_subjects.h5'); delete(h5fn_subjects);
% num_frames_per_subject = zeros(1,num_subjects,'uint32');
% for idx_subject = (1:num_subjects)-1 % 0-based indexing
% frame_subject = 0;
% indices_frame_1_subject = find(indices_subject==idx_subject);
% num_frames_per_subject(idx_subject+1) = length(indices_frame_1_subject);
% for idx_frame_1 = indices_frame_1_subject
% filename = fullfile(dataFolder,frame_info(idx_frame_1).class_name,frame_info(idx_frame_1).filename);
% img = imread(filename);
% img = img(roi_crop(1):roi_crop(2),roi_crop(3):roi_crop(4));
% GroupName = sprintf('/subject%06d_frame%08d',idx_subject,frame_subject);
% frame_subject = frame_subject+1;
% h5create(h5fn_subjects,GroupName,size(img),'DataType','uint8');
% h5write(h5fn_subjects,GroupName,img);
% end
% GroupName = sprintf('/subject%06d_class',idx_subject);
% h5create(h5fn_subjects,GroupName,size(indices_frame_1_subject),'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(indices_class(indices_frame_1_subject)));
% end
% % extra info
% GroupName = '/num_frames_per_subject';
% h5create(h5fn_subjects,GroupName,size(num_frames_per_subject),'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(num_frames_per_subject));
% GroupName = '/frame_size';
% h5create(h5fn_subjects,GroupName,size(frame_size),'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(frame_size));
% GroupName = '/num_classes';
% h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(num_classes));
% GroupName = '/num_subjects';
% h5create(h5fn_subjects,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_subjects,GroupName,uint32(num_subjects));
% % by frames
% h5fn_frames = fullfile(normFolder,'protocol_sweep_class_frames.h5'); delete(h5fn_frames);
% for idx_frame_1 = 1:length(frame_info)
% %% now read in image
% filename = fullfile(dataFolder,frame_info(idx_frame_1).class_name,frame_info(idx_frame_1).filename);
% img = imread(filename);
% img = img(roi_crop(1):roi_crop(2),roi_crop(3):roi_crop(4));
% % figure, imshow(img,[])
% GroupName = sprintf('/frame%08d',idx_frame_1-1);
% h5create(h5fn_frames,GroupName,size(img),'DataType','uint8');
% h5write(h5fn_frames,GroupName,img);
% end
% GroupName = '/class';
% h5create(h5fn_frames,GroupName,size(indices_class),'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(indices_class));
% GroupName = '/subject';
% h5create(h5fn_frames,GroupName,size(indices_subject),'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(indices_subject));
% % extra info
% GroupName = '/frame_size';
% h5create(h5fn_frames,GroupName,size(frame_size),'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(frame_size));
% GroupName = '/num_classes';
% h5create(h5fn_frames,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(num_classes));
% GroupName = '/num_subjects';
% h5create(h5fn_frames,GroupName,[1,1],'DataType','uint32');
% h5write(h5fn_frames,GroupName,uint32(num_subjects));
import tensorflow as tf
import random
import os
# import numpy as np
# Restrict TensorFlow to CUDA device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

flag_wsl = True  # True: use the hard-coded WSL mount of the Windows home folder
if flag_wsl:
    home_dir = '/mnt/c/Users/yhu'  # WSL
else:
    # expanduser works on every platform, so home_dir can never be left
    # undefined as it was when os.name was neither 'nt' nor 'posix'
    home_dir = os.path.expanduser('~')
filename = os.path.join(home_dir, 'Scratch/data/protocol/normalised/protocol_sweep_class_subjects.h5')

# Read the global information by indexing the HDF5Matrix directly (as is done
# for the per-subject datasets) rather than through the h5py `.value`
# accessor, and coerce to plain ints so they can be used as layer sizes.
frame_size_data = tf.keras.utils.HDF5Matrix(filename, '/frame_size')
frame_size = [int(frame_size_data[0][0]), int(frame_size_data[1][0])]
num_classes = int(tf.keras.utils.HDF5Matrix(filename, '/num_classes')[0][0])
# place holder for input image frames (single channel)
features_input = tf.keras.Input(shape=frame_size+[1])


def _relu_conv(filters, kernel_size, padding='valid'):
    """Create a ReLU-activated 2-D convolution layer."""
    return tf.keras.layers.Conv2D(filters, kernel_size, activation='relu', padding=padding)


def _residual_pair(shortcut, filters):
    """Run two 'same'-padded 3x3 convolutions and add the input back on."""
    branch = _relu_conv(filters, 3, padding='same')(shortcut)
    branch = _relu_conv(filters, 3, padding='same')(branch)
    return branch + shortcut


def _stage(inputs, filters):
    """One unpadded 3x3 convolution followed by two residual pairs."""
    stage_features = _relu_conv(filters, 3)(inputs)
    stage_features = _residual_pair(stage_features, filters)
    return _residual_pair(stage_features, filters)


# stem: 7x7 convolution, then three pooled stages of residual pairs
features = _relu_conv(32, 7)(features_input)
features = tf.keras.layers.MaxPool2D(3)(features)
features = _stage(features, 64)
features = tf.keras.layers.MaxPool2D(3)(features)
features = _stage(features, 128)
features = tf.keras.layers.MaxPool2D(3)(features)
features = _stage(features, 128)
features = _relu_conv(128, 3)(features)

# classification head; the final layer applies softmax, so despite its name
# logits_output holds class probabilities
features = tf.keras.layers.GlobalAveragePooling2D()(features)
features = tf.keras.layers.Dense(units=256, activation='relu')(features)
features = tf.keras.layers.Dropout(0.5)(features)
logits_output = tf.keras.layers.Dense(units=num_classes, activation='softmax')(features)

# now the model
model = tf.keras.Model(inputs=features_input, outputs=logits_output)
model.summary()
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(
    optimizer=adam,
    loss='sparse_categorical_crossentropy',
    metrics=['SparseCategoricalAccuracy'],
)
# now get the data using a generator
num_subjects = int(tf.keras.utils.HDF5Matrix(filename, '/num_subjects')[0][0])
subject_indices = range(num_subjects)
num_frames_per_subject = 1  # frames drawn at random from each subject per pass


def data_generator():
    """Yield (frame, label) pairs, visiting every subject in subject_indices once.

    For each subject, up to num_frames_per_subject distinct frames are drawn
    at random and every drawn frame is yielded (previously only the first
    draw was used). A subject with fewer frames than requested contributes
    all of its frames; a subject with no frames is skipped rather than
    raising ValueError from random.sample.

    Yields:
        (frame, label): the frame transposed, divided by 255 and given a
        trailing axis of size 1, together with the class index stored for
        that frame.
    """
    for iSbj in subject_indices:
        dataset = '/subject%06d_num_frames' % iSbj
        num_frames = int(tf.keras.utils.HDF5Matrix(filename, dataset)[0][0])
        num_draws = min(num_frames_per_subject, num_frames)
        for idx_frame in random.sample(range(num_frames), num_draws):
            dataset = '/subject%06d_frame%08d' % (iSbj, idx_frame)
            # NOTE(review): frames are presumably 8-bit, so /255 would scale
            # them to [0, 1] - confirm against the writer
            frame = tf.transpose(tf.keras.utils.HDF5Matrix(filename, dataset)) / 255
            dataset = '/subject%06d_label%08d' % (iSbj, idx_frame)
            label = tf.keras.utils.HDF5Matrix(filename, dataset)[0][0]
            yield (tf.expand_dims(frame, axis=2), label)
# wrap the generator as a tf.data pipeline of (frame, label) pairs
dataset = tf.data.Dataset.from_generator(
    generator=data_generator,
    output_types=(tf.float32, tf.int32),
    output_shapes=(frame_size+[1], ()),
)

# training: draw one shuffled batch of num_subjects samples and fit on it,
# with 20% of that batch held out for validation
shuffled_dataset = dataset.shuffle(buffer_size=1024)
dataset_batch = shuffled_dataset.batch(num_subjects)
batch_iterator = iter(dataset_batch)
frame_train, label_train = next(batch_iterator)
num_epochs = int(1e3)
model.fit(frame_train, label_train, epochs=num_epochs, validation_split=0.2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment