Commit 4b353d51 authored by mathpluscode

use black to reformat the code

parent b06b379c
Pipeline #3304 failed with stages in 53 seconds
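Note: every change below is mechanical. black normalizes string literals to double quotes and reflows statements that exceed the configured line length; the wrapping in this diff suggests a line length of 79 rather than black's default of 88, though the project's black configuration is not shown here. A minimal sketch of reproducing one change with black's Python API (assuming a recent black release, which exposes format_str and Mode):

    import black

    src = "logo_file = 'project-icon.png'\n"
    # line_length=79 is an assumption inferred from the wrapping in this diff
    out = black.format_str(src, mode=black.Mode(line_length=79))
    print(out)  # logo_file = "project-icon.png"
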
@@ -26,32 +26,31 @@ import sys
# sys.path.insert(0, os.path.abspath('.'))
working_dir = os.path.abspath(os.path.dirname(__file__))
root_dir_rel = os.path.join('..')
root_dir_rel = os.path.join("..")
root_dir_abs = os.path.abspath(root_dir_rel)
module_path = root_dir_abs
sys.path.insert(0, module_path)
logo_file = 'project-icon.png'
logo_path = os.path.join('..', logo_file)
logo_file = "project-icon.png"
logo_path = os.path.join("..", logo_file)
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = [
'tests',
'run_*',
'setup.py',
'_build',
'Thumbs.db',
'.DS_Store',
'_verion.py',
'versioneer.py'
"tests",
"run_*",
"setup.py",
"_build",
"Thumbs.db",
".DS_Store",
"_verion.py",
"versioneer.py",
]
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
static_folder = 'static'
static_folder = "static"
html_static_path = [static_folder]
@@ -60,20 +59,21 @@ def generate_apidocs(*args):
global working_dir, module_path
output_path = working_dir
apidoc_command_path = 'sphinx-apidoc'
if hasattr(sys, 'real_prefix'): # called from a virtualenv
apidoc_command_path = os.path.join(sys.prefix, 'bin', 'sphinx-apidoc')
apidoc_command_path = "sphinx-apidoc"
if hasattr(sys, "real_prefix"): # called from a virtualenv
apidoc_command_path = os.path.join(sys.prefix, "bin", "sphinx-apidoc")
apidoc_command_path = os.path.abspath(apidoc_command_path)
subprocess.check_call(
[apidoc_command_path, '--force', '--separate'] +
['-o', output_path, module_path] +
[os.path.join(root_dir_abs, pattern) for pattern in exclude_patterns])
[apidoc_command_path, "--force", "--separate"]
+ ["-o", output_path, module_path]
+ [os.path.join(root_dir_abs, pattern) for pattern in exclude_patterns]
)
def setup(app):
# Hook to allow for automatic generation of API docs
# before doc deployment begins.
app.connect('builder-inited', generate_apidocs)
app.connect("builder-inited", generate_apidocs)
# -- General configuration ------------------------------------------------
@@ -85,39 +85,41 @@ def setup(app):
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.imgmath']
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.viewcode",
"sphinx.ext.imgmath",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ".rst"
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# This allows modules to be indexed under the submodule name rather than all appearing under yfmil3id2019
modindex_common_prefix = [
'yfmil3id2019.'
]
modindex_common_prefix = ["yfmil3id2019."]
# General information about the project.
project = u'yunguanfu-mil3id2019'
project = u"yunguanfu-mil3id2019"
copyright = u"2019, University College London"
author = u'Yunguan Fu'
author = u"Yunguan Fu"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = u''
version = u""
# The full version, including alpha/beta/rc tags.
release = u''
release = u""
# The short X.Y version.
# version = yfmil3id2019.__version__
@@ -132,7 +134,7 @@ release = u''
language = None
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
@@ -144,37 +146,37 @@ todo_include_todos = False
# a list of builtin themes.
#
# html_theme = 'alabaster'
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
doc_black = '#0A0A0A'
doc_red = '#E03C31'
doc_gray = '#979999'
doc_blue = '#00627D'
doc_dark_red = '#C53B2C'
doc_white = '#FEFEFE'
doc_black = "#0A0A0A"
doc_red = "#E03C31"
doc_gray = "#979999"
doc_blue = "#00627D"
doc_dark_red = "#C53B2C"
doc_white = "#FEFEFE"
html_theme_options = {
'footerbgcolor': doc_gray,
'footertextcolor': doc_black,
'sidebarbgcolor': doc_white,
'sidebartextcolor': doc_black,
'sidebarlinkcolor': doc_red,
'relbarbgcolor': doc_white,
'relbartextcolor': doc_black,
'relbarlinkcolor': doc_red,
'bgcolor': doc_white,
'textcolor': doc_black,
'linkcolor': doc_red,
'visitedlinkcolor': doc_dark_red,
'headbgcolor': doc_white,
'headtextcolor': doc_black,
'headlinkcolor': doc_red,
'codebgcolor': doc_blue,
'codetextcolor': doc_black,
'stickysidebar': 'true',
"footerbgcolor": doc_gray,
"footertextcolor": doc_black,
"sidebarbgcolor": doc_white,
"sidebartextcolor": doc_black,
"sidebarlinkcolor": doc_red,
"relbarbgcolor": doc_white,
"relbartextcolor": doc_black,
"relbarlinkcolor": doc_red,
"bgcolor": doc_white,
"textcolor": doc_black,
"linkcolor": doc_red,
"visitedlinkcolor": doc_dark_red,
"headbgcolor": doc_white,
"headtextcolor": doc_black,
"headlinkcolor": doc_red,
"codebgcolor": doc_blue,
"codetextcolor": doc_black,
"stickysidebar": "true",
}
html_logo = logo_path
@@ -182,7 +184,7 @@ html_logo = logo_path
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['static']
html_static_path = ["static"]
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
@@ -190,9 +192,9 @@ html_static_path = ['static']
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
'**': [
'relations.html', # needs 'show_related': True theme option to display
'searchbox.html',
"**": [
"relations.html", # needs 'show_related': True theme option to display
"searchbox.html",
]
}
@@ -200,7 +202,7 @@ html_sidebars = {
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'yunguanfu-mil3id2019doc'
htmlhelp_basename = "yunguanfu-mil3id2019doc"
# -- Options for LaTeX output ---------------------------------------------
@@ -208,15 +210,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
@@ -226,9 +225,13 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
('index', 'yunguanfu-mil3id2019.tex',
u'yunguanfu-mil3id2019 Documentation',
u'Yunguan Fu', 'manual'),
(
"index",
"yunguanfu-mil3id2019.tex",
u"yunguanfu-mil3id2019 Documentation",
u"Yunguan Fu",
"manual",
)
]
# -- Options for manual page output ---------------------------------------
@@ -236,9 +239,13 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'yunguanfu-mil3id2019',
u'yunguanfu-mil3id2019 Documentation',
[u'Yunguan Fu'], 1)
(
"index",
"yunguanfu-mil3id2019",
u"yunguanfu-mil3id2019 Documentation",
[u"Yunguan Fu"],
1,
)
]
@@ -248,10 +255,13 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'yunguanfu-mil3id2019',
u'yunguanfu-mil3id2019 Documentation',
u'Yunguan Fu',
'yunguanfu-mil3id2019',
'One line description of project.',
'Miscellaneous'),
(
"index",
"yunguanfu-mil3id2019",
u"yunguanfu-mil3id2019 Documentation",
u"Yunguan Fu",
"yunguanfu-mil3id2019",
"One line description of project.",
"Miscellaneous",
)
]
@@ -7,58 +7,40 @@ from setuptools import setup, find_packages
import versioneer
# Get the long description
with open('README.rst') as f:
with open("README.rst") as f:
long_description = f.read()
setup(
name='yunguanfu-mil3id2019',
name="yunguanfu-mil3id2019",
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(),
description="Implements Yunguan Fu's MIL3ID paper presented at MICCAI 2019",
long_description=long_description,
long_description_content_type='text/x-rst',
url='https://weisslab.cs.ucl.ac.uk/WEISS/SoftwareRepositories/yunguanfu-mil3id2019',
author='Yunguan Fu',
author_email='yunguan.fu.18@ucl.ac.uk',
license='BSD-3 license',
long_description_content_type="text/x-rst",
url="https://weisslab.cs.ucl.ac.uk/WEISS/SoftwareRepositories/yunguanfu-mil3id2019",
author="Yunguan Fu",
author_email="yunguan.fu.18@ucl.ac.uk",
license="BSD-3 license",
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Intended Audience :: Healthcare Industry',
'Intended Audience :: Information Technology',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: BSD License',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 3',
'Topic :: Scientific/Engineering :: Information Analysis',
'Topic :: Scientific/Engineering :: Medical Science Apps.',
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Healthcare Industry",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Medical Science Apps.",
],
keywords='medical imaging',
packages=find_packages(
exclude=[
'doc',
'tests',
]
),
install_requires=[
'six>=1.10',
'numpy>=1.11',
],
keywords="medical imaging",
packages=find_packages(exclude=["doc", "tests"]),
install_requires=["six>=1.10", "numpy>=1.11"],
entry_points={
'console_scripts': [
'yfmil3id2019_train=yfmil3id2019.ui.yfmil3id2019_train_command_line:main',
'yfmil3id2019_test=yfmil3id2019.ui.yfmil3id2019_test_command_line:main',
],
"console_scripts": [
"yfmil3id2019_train=yfmil3id2019.ui.yfmil3id2019_train_command_line:main",
"yfmil3id2019_test=yfmil3id2019.ui.yfmil3id2019_test_command_line:main",
]
},
)
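Note: the console_scripts entry points above are what expose yfmil3id2019_train and yfmil3id2019_test as shell commands once the package is installed. The wrapper that setuptools generates for the first entry behaves like this sketch (main itself is defined in the referenced module, not shown in this diff):

    # equivalent of the generated `yfmil3id2019_train` executable
    import sys

    from yfmil3id2019.ui.yfmil3id2019_train_command_line import main

    sys.exit(main())
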
@@ -2,5 +2,6 @@
"""yunguanfu-mil3id2019"""
from ._version import get_versions
__version__ = get_versions()['version']
__version__ = get_versions()["version"]
del get_versions
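Note: versioneer writes _version.py and derives the version string from git tags, so the version is available at runtime. A minimal usage sketch (the printed value is illustrative; it depends on the repository's tags):

    import yfmil3id2019

    print(yfmil3id2019.__version__)  # e.g. "0.1.0"
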
@@ -6,20 +6,32 @@ from yfmil3id2019.src.util import make_dir
class CrossValidationRun:
def __init__(self, folders_lbl_train, folders_unlbl_train, folders_lbl_eval):
def __init__(
self, folders_lbl_train, folders_unlbl_train, folders_lbl_eval
):
self.folders_lbl_train = folders_lbl_train
self.folders_unlbl_train = folders_unlbl_train
self.folders_lbl_eval = folders_lbl_eval
self.mean_std_folders = [x.replace('labeled', 'mean_std') for x in folders_lbl_train]
self.mean_std_folders = [
x.replace("labeled", "mean_std") for x in folders_lbl_train
]
def __repr__(self):
return get_folder_name_from_paths(self.folders_lbl_train) + '/' \
+ (get_folder_name_from_paths(self.folders_unlbl_train) if self.folders_unlbl_train is not None else '') + '/' \
+ get_folder_name_from_paths(self.folders_lbl_eval)
return (
get_folder_name_from_paths(self.folders_lbl_train)
+ "/"
+ (
get_folder_name_from_paths(self.folders_unlbl_train)
if self.folders_unlbl_train is not None
else ""
)
+ "/"
+ get_folder_name_from_paths(self.folders_lbl_eval)
)
def generate_mean_std(self, dir_run):
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
from matplotlib.image import imsave
@@ -30,11 +42,11 @@ class CrossValidationRun:
# read files
for folder in self.mean_std_folders:
mean = plt.imread(folder + '/train_mean.png')
std = plt.imread(folder + '/train_std.png')
mean = plt.imread(folder + "/train_mean.png")
std = plt.imread(folder + "/train_std.png")
mean = mean[:, :, :3]
std = std[:, :, :3]
with open(folder + '/num_img.txt', 'r') as f:
with open(folder + "/num_img.txt", "r") as f:
num = int(f.read().splitlines()[0])
means.append(mean)
stds.append(std)
@@ -51,7 +63,7 @@ class CrossValidationRun:
std = np.sqrt(std)
# save file
mean_std_path = [dir_run + '/mean.png', dir_run + '/std.png']
mean_std_path = [dir_run + "/mean.png", dir_run + "/std.png"]
imsave(mean_std_path[0], mean, vmin=0, vmax=1)
imsave(mean_std_path[1], std, vmin=0, vmax=1)
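Note: the lines elided between the two hunks above combine the per-fold statistics before std = np.sqrt(std). A sketch of the standard count-weighted pooling of means and standard deviations (the textbook formula, not necessarily the exact elided code):

    import numpy as np

    def pool_mean_std(means, stds, nums):
        """Combine per-fold pixel means/stds, weighted by image counts."""
        w = np.asarray(nums, dtype=np.float64)
        w /= w.sum()
        mean = sum(wi * m for wi, m in zip(w, means))
        # E[X^2] = sum_i w_i * (std_i^2 + mean_i^2), so var = E[X^2] - mean^2
        var = sum(wi * (s**2 + m**2) for wi, s, m in zip(w, stds, means)) - mean**2
        return mean, np.sqrt(var)
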
@@ -60,9 +72,9 @@
def split_train_eval(folders_labeled, folders_unlabeled, param):
# retrieve parameters
train = param['train']
leave_out = param['leave_out']
overfit = param['overfit']
train = param["train"]
leave_out = param["leave_out"]
overfit = param["overfit"]
num_folders = len(folders_labeled)
# extend labeled folders
@@ -71,14 +83,22 @@ def split_train_eval(folders_labeled, folders_unlabeled, param):
# the train folders will be [0, 1], [2, 3], [4, 5, 6]
runs = []
indices = [x for x in range(num_folders)]
indicies_evals = np.array_split(np.arange(num_folders), num_folders // leave_out)
indicies_evals = np.array_split(
np.arange(num_folders), num_folders // leave_out
)
for run_id in range(num_folders // leave_out):
indicies_eval = list(indicies_evals[run_id])
indicies_train = [x for x in indices if x not in indicies_eval]
fs_lbl_train = [folders_labeled[x] for x in indicies_train]
fs_unlbl_train = [folders_unlabeled[x] for x in indicies_train] if folders_unlabeled is not None else None
fs_unlbl_train = (
[folders_unlabeled[x] for x in indicies_train]
if folders_unlabeled is not None
else None
)
fs_lbl_eval = [folders_labeled[x] for x in indicies_eval]
runs.append(CrossValidationRun(fs_lbl_train, fs_unlbl_train, fs_lbl_eval))
runs.append(
CrossValidationRun(fs_lbl_train, fs_unlbl_train, fs_lbl_eval)
)
if train >= 0:
runs = [runs[train]]
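Note: split_train_eval performs leave-p-out cross-validation over subject folders, and a non-negative param["train"] selects a single run. A worked sketch, ignoring the elided folder-extension logic and assuming the elided tail simply returns runs (folder names are illustrative):

    # 6 folders with leave_out = 2 give 3 runs; np.array_split yields the
    # eval index groups [0, 1], [2, 3], [4, 5], each trained on the other 4.
    param = {"train": -1, "leave_out": 2, "overfit": False}
    folders = ["s0", "s1", "s2", "s3", "s4", "s5"]
    runs = split_train_eval(folders, None, param)  # -> 3 CrossValidationRun objects
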
@@ -91,32 +111,52 @@ def split_train_eval(folders_labeled, folders_unlabeled, param):
def get_folder_name_from_paths(paths):
return '_'.join([x.split('/')[-1] for x in paths])
return "_".join([x.split("/")[-1] for x in paths])
def save_predict_results(results, dir_run, name):
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
from matplotlib.image import imsave
dir_pred = dir_run + '/preds/%s/' % name
dir_pred = dir_run + "/preds/%s/" % name
make_dir(dir_pred)
with open(dir_pred + 'metric.log', 'w+') as f:
with open(dir_pred + "metric.log", "w+") as f:
for i, result in enumerate(results):
images = result['images']
masks = result['masks'] # [0,1]
preds = result['preds'] # [0,1]
logits = result['logits'] # [0,1]
images = (images - np.min(images)) / (np.max(images) - np.min(images))
imsave(dir_pred + '/%d_image.png' % i, images)
imsave(dir_pred + '/%d_mask.png' % i, masks, vmin=0, vmax=1, cmap='gray')
imsave(dir_pred + '/%d_prob.png' % i, preds, vmin=0, vmax=1, cmap='gray')
imsave(dir_pred + '/%d_pred.png' % i, np.round(preds), vmin=0, vmax=1, cmap='gray')
images = result["images"]
masks = result["masks"] # [0,1]
preds = result["preds"] # [0,1]
logits = result["logits"] # [0,1]
images = (images - np.min(images)) / (
np.max(images) - np.min(images)
)
imsave(dir_pred + "/%d_image.png" % i, images)
imsave(
dir_pred + "/%d_mask.png" % i,
masks,
vmin=0,
vmax=1,
cmap="gray",
)
imsave(
dir_pred + "/%d_prob.png" % i,
preds,
vmin=0,
vmax=1,
cmap="gray",
)
imsave(
dir_pred + "/%d_pred.png" % i,
np.round(preds),
vmin=0,
vmax=1,
cmap="gray",
)
metrics = seg_metric_np(preds, masks)
line = '%d|' % i
line = "%d|" % i
for k, v in metrics.items():
line += k + '=' + '%f,' % v
f.write(line + '\n')
line += k + "=" + "%f," % v
f.write(line + "\n")
@@ -25,13 +25,13 @@ def apply_augmentation_in_model(x, mean, std, ch, affine, config_aug, mode):
x = adjust_color(x, ch, config_aug)
# standardize
if config_aug['standardize']:
if config_aug["standardize"]:
x = standardize(x, mean, std, ch)
if mode == tf.estimator.ModeKeys.TRAIN:
# affine, should also be the last augmentation
p = config_aug['affine']['prob']
scale = config_aug['affine']['scale']
p = config_aug["affine"]["prob"]
scale = config_aug["affine"]["scale"]
if affine:
if scale > 0:
if np.random.rand() < p:
@@ -50,10 +50,16 @@ def adjust_color(x, ch, config_aug):
def _adjust_color(x):
"""perform at most one color adjustment for speed"""
color_adjusted = True
if np.random.rand() < config_aug['contrast']['prob']:
x = tf.image.random_contrast(image=x, lower=config_aug['contrast']['lower'], upper=config_aug['contrast']['upper'])
elif np.random.rand() < config_aug['brightness']['prob']:
x = tf.image.random_brightness(image=x, max_delta=config_aug['brightness']['max_delta'])
if np.random.rand() < config_aug["contrast"]["prob"]:
x = tf.image.random_contrast(
image=x,
lower=config_aug["contrast"]["lower"],
upper=config_aug["contrast"]["upper"],
)
elif np.random.rand() < config_aug["brightness"]["prob"]:
x = tf.image.random_brightness(
image=x, max_delta=config_aug["brightness"]["max_delta"]
)
else:
color_adjusted = False
if color_adjusted:
@@ -76,7 +82,7 @@ def adjust_color(x, ch, config_aug):
image2 = _adjust_color(image2)
x = tf.concat([image1, image2, mask], axis=-1)
else:
raise ValueError('Unknown input channel %d.' % ch)
raise ValueError("Unknown input channel %d." % ch)
return x
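Note: the np.random.rand() calls in _adjust_color run in Python while the graph is being built, so the chosen branch is fixed per constructed graph rather than redrawn per batch. If a fresh draw per step were intended, a TF1-style sketch would move the randomness into the graph (a hypothetical helper, not code from this commit):

    import tensorflow as tf

    def random_contrast_in_graph(x, prob, lower, upper):
        # the coin flip is a graph op, so it is re-sampled on every session run
        flip = tf.random_uniform([]) < prob
        return tf.cond(
            flip,
            lambda: tf.image.random_contrast(x, lower=lower, upper=upper),
            lambda: x,
        )
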
@@ -97,7 +103,7 @@ def standardize(x, mean, std, ch):
image2 = (image2 - mean) / std
x = tf.concat([image1, image2, mask], axis=-1)
else:
raise ValueError('Unknown input channel %d.' % ch)
raise ValueError("Unknown input channel %d." % ch)
return x
@@ -112,11 +118,17 @@ def apply_affine_transform(images, scale, return_fn):
sh = images.get_shape().as_list()
batch_size = sh[0]
size = sh[1:3]
A = get_affine_transform_batch(size, scale, batch_size) # shape = [batch_size, 2, 3]
A = get_affine_transform_batch(
size, scale, batch_size
) # shape = [batch_size, 2, 3]
A = tf.convert_to_tensor(A, dtype=tf.float32)
A = flatten(A) # shape = [batch_size, 6]
A = tf.concat([A, tf.zeros([batch_size, 2], tf.float32)], axis=1) # shape = [batch_size, 8]
images = tf.contrib.image.transform(images=images, transforms=tf.convert_to_tensor(A))
A = tf.concat(
[A, tf.zeros([batch_size, 2], tf.float32)], axis=1
) # shape = [batch_size, 8]
images = tf.contrib.image.transform(
images=images, transforms=tf.convert_to_tensor(A)
)
return images
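Note: tf.contrib.image.transform expects 8 projective parameters [a0, a1, a2, b0, b1, b2, c0, c1] per image, i.e. a flattened 3x3 matrix whose bottom-right entry is fixed at 1, which is why two zeros (c0 = c1 = 0) are appended to each flattened 2x3 affine matrix. A numpy sketch of that layout (flatten is assumed to reshape [batch, 2, 3] into [batch, 6]):

    import numpy as np

    A = np.array([[[1.0, 0.0, 5.0],
                   [0.0, 1.0, -3.0]]], dtype=np.float32)  # [1, 2, 3] affine matrix
    flat = A.reshape(1, 6)  # [a0, a1, a2, b0, b1, b2]
    proj = np.concatenate([flat, np.zeros((1, 2), np.float32)], axis=1)  # [1, 8]
    # each row of proj now matches the transforms argument of tf.contrib.image.transform
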
@@ -136,8 +148,12 @@ def get_joint_affine_transform_batch(size, scale, batch_size):
apply A1, then A12 is equivalent to apply A2
they all have shape = [batch_size, 8]
"""
A = get_affine_transform_batch(size=size, scale=scale, batch_size=batch_size * 2) # [batch_size*2, 2, 3], np array
A = np.concatenate([A, np.zeros([batch_size * 2, 1, 3])], axis=1) # [batch_size*2, 3, 3], np array
A = get_affine_transform_batch(
size=size, scale=scale, batch_size=batch_size * 2
) # [batch_size*2, 2, 3], np array
A = np.concatenate(
[A, np.zeros([batch_size * 2, 1, 3])], axis=1
) # [batch_size*2, 3, 3], np array
A[:, 2, 2] = 1
A = tf.convert_to_tensor(A, dtype=tf.float32)
@@ -174,14 +190,26 @@ def get_affine_transform_batch(size, scale, batch_size):
"""
W = size[1] / size[0]
H = 1
coords_orig = np.array([[[W / 2, H / 2],
[-W / 2, H / 2],
[-W / 2, -H / 2],
[W / 2, -H / 2]]], dtype=np.float32) # [1, 4, 2]
coords_orig = np.tile(coords_orig, (batch_size, 1, 1)) # [batch_size, 4, 2]
offset = np.random.uniform(-scale, scale, [batch_size, 4, 2]) # [batch_size, 4, 2]
coords_orig = np.array(
[[[W / 2, H / 2], [-W / 2, H / 2], [-W / 2, -H / 2], [W / 2, -H / 2]]],
dtype=np.float32,
) # [1, 4, 2]
coords_orig = np.tile(
coords_orig, (batch_size, 1, 1)
) # [batch_size, 4, 2]
offset = np.random.uniform(
-scale, scale, [batch_size, 4, 2]
) # [batch_size, 4, 2]
coords_new = coords_orig + offset # [batch_size, 4, 2]
coords_orig = np.concatenate([coords_orig, np.ones((batch_size, 4, 1))], axis=2) # [batch_size, 4, 3]
A = np.stack([calculate_affine_matrix(coords_orig[k, :, :], coords_new[k, :, :]) for k in range(batch_size)], axis=0) # [batch_size, 2, 3]
coords_orig = np.concatenate(
[coords_orig, np.ones((batch_size, 4, 1))], axis=2
) # [batch_size, 4, 3]
A = np.stack(
[
calculate_affine_matrix(coords_orig[k, :, :], coords_new[k, :, :])
for k in range(batch_size)
],
axis=0,
) # [batch_size, 2, 3]
A = A.astype(np.float32)
return A
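Note: calculate_affine_matrix itself is not part of this diff. Given the shapes above ([4, 3] homogeneous source coordinates, [4, 2] displaced targets), a minimal least-squares sketch of what it has to compute might look like:

    import numpy as np

    def calculate_affine_matrix(coords_orig, coords_new):
        """Solve coords_orig @ A.T ~= coords_new for the [2, 3] affine matrix A."""
        # 4 points give 8 equations for 6 unknowns, so solve in the least-squares sense
        A_transposed, *_ = np.linalg.lstsq(coords_orig, coords_new, rcond=None)
        return A_transposed.T  # [2, 3]
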
@@ -16,9 +16,11 @@ def get_labeled_folders(cwd, config):
:param config: dict of the config file
:return: paths of folders
"""
data_dir = cwd + config["dir"]["data"] + 'img/labeled'
data_dir = cwd + config["dir"]["data"] + "img/labeled"
folders = [f.path for f in os.scandir(data_dir) if f.is_dir()]
folders = [x for x in folders if 'HLS' not in x] # remove folders containing HLS in the name
folders = [
x for x in folders if "HLS" not in x
] # remove folders containing HLS in the name
folders = sorted(folders)
return folders
@@ -33,9 +35,12 @@ def get_unlabeled_folders(folders_labeled, config):
:param config:
:return: same order as folders_labeled, but some folders might not exist
"""
if not config['data']['ssl']['activate']:
if not config["data"]["ssl"]["activate"]:
return None
folders = [x.replace('labeled', 'unlabeled/fps%d' % config['data']['ssl']['fps']) for x in folders_labeled]
folders = [
x.replace("labeled", "unlabeled/fps%d" % config["data"]["ssl"]["fps"])
for x in folders_labeled
]