"""
.. codeauthor:: Mona Koehler
.. codeauthor:: Daniel Seichter
"""
import argparse
from datetime import datetime
import json
import os
import pickle
import sys
import time
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim
from torch.optim.lr_scheduler import OneCycleLR
from src.args import ArgumentParserRGBDSegmentation
from src.build_model import build_model
from src import utils
from src.prepare_data import prepare_data
from src.utils import load_ckpt, print_log, save_ckpt, save_ckpt_every_epoch
import tensorflow as tf
from src.logger import CSVLogger
from src.confusion_matrix import ConfusionMatrixTensorflow
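# run-time configuration: pin training to GPU 2, report CUDA errors
# synchronously, and disable TF eager execution for the TF1-style graph
# code used by ConfusionMatrixTensorflow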
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
tf.compat.v1.disable_eager_execution()
def parse_args():
parser = ArgumentParserRGBDSegmentation(
        description='Efficient RGBD Indoor Semantic Segmentation (Training)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.set_common_args()
args = parser.parse_args()
return args
def train_main():
args = parse_args()
training_starttime = datetime.now().strftime("%d_%m_%Y-%H_%M_%S-%f")
ckpt_dir = os.path.join(args.results_dir, args.dataset,
f'checkpoints_{training_starttime}')
os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(os.path.join(ckpt_dir, 'confusion_matrices'), exist_ok=True)
with open(os.path.join(ckpt_dir, 'args.json'), 'w') as f:
json.dump(vars(args), f, sort_keys=True, indent=4)
with open(os.path.join(ckpt_dir, 'argsv.txt'), 'w') as f:
f.write(' '.join(sys.argv))
f.write('\n')
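    # labels are also needed at 1/16, 1/8 and 1/4 of the input resolution
    # for the decoder side outputs (deep supervision); see MultiScaleLabel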
label_downsampling_rates = [16, 8, 4]
data_loaders = prepare_data(args, ckpt_dir)
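
# --- inlined for reference: src.prepare_data (imported at the top) ----------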
"""
.. codeauthor:: Mona Koehler
.. codeauthor:: Daniel Seichter
"""
import copy
import os
import pickle
from torch.utils.data import DataLoader
from src import preprocessing
from src.datasets import Cityscapes
from src.datasets import NYUv2
from src.datasets import SceneNetRGBD
from src.datasets import SUNRGBD
def prepare_data(args, ckpt_dir=None, with_input_orig=False, split=None):
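    """Create the datasets and data loaders selected via ``args``.

    Returns ``(train_loader, valid_loader)`` (plus a full-resolution
    validation loader if ``args.valid_full_res``); if ``args.dataset_dir``
    is None, only the validation data and its preprocessor are returned.
    """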
train_preprocessor_kwargs = {}
if args.dataset == 'sunrgbd':
Dataset = SUNRGBD
dataset_kwargs = {}
valid_set = 'test'
elif args.dataset == 'nyuv2':
Dataset = NYUv2
dataset_kwargs = {'n_classes': 40}
valid_set = 'test'
elif args.dataset == 'cityscapes':
Dataset = Cityscapes
dataset_kwargs = {
'n_classes': 19,
'disparity_instead_of_depth': True
}
valid_set = 'valid'
elif args.dataset == 'cityscapes-with-depth':
Dataset = Cityscapes
dataset_kwargs = {
'n_classes': 19,
'disparity_instead_of_depth': False
}
valid_set = 'valid'
elif args.dataset == 'scenenetrgbd':
Dataset = SceneNetRGBD
dataset_kwargs = {'n_classes': 13}
valid_set = 'valid'
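        # SceneNetRGBD frames are stored at 320x240; when training at
        # 640x480 they are scaled up by a factor of 2 before the random
        # rescaling below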
if args.width == 640 and args.height == 480:
train_preprocessor_kwargs['train_random_rescale'] = (1.0*2, 1.4*2)
else:
raise ValueError(f"Unknown dataset: `{args.dataset}`")
if args.aug_scale_min != 1 or args.aug_scale_max != 1.4:
train_preprocessor_kwargs['train_random_rescale'] = (
args.aug_scale_min, args.aug_scale_max)
if split in ['valid', 'test']:
valid_set = split
if args.raw_depth:
depth_mode = 'raw'
else:
depth_mode = 'refined'
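
# --- inlined for reference: the NYUv2 dataset class (src.datasets) ----------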
"""
.. codeauthor:: Daniel Seichter
"""
import os
import cv2
import numpy as np
from ..dataset_base import DatasetBase
from .nyuv2 import NYUv2Base
class NYUv2(NYUv2Base, DatasetBase):
def __init__(self,
data_dir=None,
n_classes=40,
split='train',
depth_mode='refined',
with_input_orig=False):
super(NYUv2, self).__init__()
assert split in self.SPLITS
assert n_classes in self.N_CLASSES
assert depth_mode in ['refined', 'raw']
self._n_classes = n_classes
self._split = split
self._depth_mode = depth_mode
self._with_input_orig = with_input_orig
self._cameras = ['kv1']
if data_dir is not None:
data_dir = os.path.expanduser(data_dir)
assert os.path.exists(data_dir)
self._data_dir = data_dir
fp = os.path.join(self._data_dir,
self.SPLIT_FILELIST_FILENAMES[self._split])
self._filenames = np.loadtxt(fp, dtype=str)
else:
print(f"Loaded {self.__class__.__name__} dataset without files")
CLASS_NAMES_40 = ['void',
'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
'table', 'door', 'window', 'bookshelf', 'picture',
'counter', 'blinds', 'desk', 'shelves', 'curtain',
'dresser', 'pillow', 'mirror', 'floor mat', 'clothes',
'ceiling', 'books', 'refridgerator', 'television',
'paper', 'towel', 'shower curtain', 'box', 'whiteboard',
'person', 'night stand', 'toilet', 'sink', 'lamp',
'bathtub', 'bag',
'otherstructure', 'otherfurniture', 'otherprop']
self._class_names = getattr(self, f'CLASS_NAMES_{self._n_classes}')
CLASS_COLORS_13 = [[0, 0, 0],
[0, 0, 255],
[232, 88, 47],
[0, 217, 0],
[148, 0, 240],
[222, 241, 23],
[255, 205, 205],
[0, 223, 228],
[106, 135, 204],
[116, 28, 41],
[240, 35, 235],
[0, 166, 156],
[249, 139, 0],
[225, 228, 194]]
self._class_colors = np.array(
getattr(self, f'CLASS_COLORS_{self._n_classes}'),
dtype='uint8'
)
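        # precomputed depth statistics of the training split (depth values
        # are stored in millimetres)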
self._depth_mean = 2841.94941272766
self._depth_std = 1417.259428167227
@property
def cameras(self):
return self._cameras
@property
def class_names(self):
return self._class_names
@property
def class_names_without_void(self):
return self._class_names[1:]
@property
def class_colors(self):
return self._class_colors
@property
def class_colors_without_void(self):
return self._class_colors[1:]
@property
def n_classes(self):
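        # +1 accounts for the void class at index 0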
return self._n_classes + 1
@property
def n_classes_without_void(self):
return self._n_classes
@property
def split(self):
return self._split
@property
def depth_mode(self):
return self._depth_mode
@property
def depth_mean(self):
return self._depth_mean
@property
def depth_std(self):
return self._depth_std
@property
def source_path(self):
return os.path.abspath(os.path.dirname(__file__))
@property
def with_input_orig(self):
return self._with_input_orig
def _load(self, directory, filename):
fp = os.path.join(self._data_dir,
self.split,
directory,
f'{filename}.png')
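        # IMREAD_UNCHANGED preserves 16-bit depth/label encodings; OpenCV
        # returns color images as BGR, hence the conversion below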
im = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
if im.ndim == 3:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
return im
def load_image(self, idx):
return self._load(self.RGB_DIR, self._filenames[idx])
def load_depth(self, idx):
if self._depth_mode == 'raw':
return self._load(self.DEPTH_RAW_DIR, self._filenames[idx])
else:
return self._load(self.DEPTH_DIR, self._filenames[idx])
def load_label(self, idx):
return self._load(self.LABELS_DIR_FMT.format(self._n_classes),
self._filenames[idx])
def __len__(self):
return len(self._filenames)
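
# --- src.prepare_data, continued ---------------------------------------------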
train_data = Dataset(
data_dir=args.dataset_dir,
split='train',
depth_mode=depth_mode,
with_input_orig=with_input_orig,
**dataset_kwargs
)
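
# --- inlined for reference: src.preprocessing ---------------------------------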
"""
.. codeauthor:: Mona Koehler
.. codeauthor:: Daniel Seichter
This code is partially adapted from RedNet
(https://github.com/JindongJiang/RedNet/blob/master/RedNet_data.py)
"""
import cv2
import matplotlib
import matplotlib.colors
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
def get_preprocessor(depth_mean,
depth_std,
depth_mode='refined',
height=None,
width=None,
phase='train',
train_random_rescale=(1.0, 1.4)):
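    """Compose the preprocessing pipeline for the given phase.

    Training: random rescale and crop, HSV jitter, random horizontal flip,
    tensor conversion, normalization and multi-scale label generation.
    Test: optional rescale to (height, width), tensor conversion and
    normalization.
    """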
assert phase in ['train', 'test']
if phase == 'train':
transform_list = [
RandomRescale(train_random_rescale),
RandomCrop(crop_height=height, crop_width=width),
RandomHSV((0.9, 1.1),
(0.9, 1.1),
(25, 25)),
RandomFlip(),
ToTensor(),
Normalize(depth_mean=depth_mean,
depth_std=depth_std,
depth_mode=depth_mode),
MultiScaleLabel(downsampling_rates=[16, 8, 4])
]
else:
if height is None and width is None:
transform_list = []
else:
transform_list = [Rescale(height=height, width=width)]
transform_list.extend([
ToTensor(),
Normalize(depth_mean=depth_mean,
depth_std=depth_std,
depth_mode=depth_mode)
])
transform = transforms.Compose(transform_list)
return transform
class Rescale:
def __init__(self, height, width):
self.height = height
self.width = width
def __call__(self, sample):
image, depth = sample['image'], sample['depth']
image = cv2.resize(image, (self.width, self.height),
interpolation=cv2.INTER_LINEAR)
depth = cv2.resize(depth, (self.width, self.height),
interpolation=cv2.INTER_NEAREST)
sample['image'] = image
sample['depth'] = depth
if 'label' in sample:
label = sample['label']
label = cv2.resize(label, (self.width, self.height),
interpolation=cv2.INTER_NEAREST)
sample['label'] = label
return sample
class RandomRescale:
def __init__(self, scale):
self.scale_low = min(scale)
self.scale_high = max(scale)
def __call__(self, sample):
image, depth, label = sample['image'], sample['depth'], sample['label']
target_scale = np.random.uniform(self.scale_low, self.scale_high)
target_height = int(round(target_scale * image.shape[0]))
target_width = int(round(target_scale * image.shape[1]))
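        # bilinear interpolation for color; nearest neighbor for depth and
        # labels so that no values are interpolated between classes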
image = cv2.resize(image, (target_width, target_height),
interpolation=cv2.INTER_LINEAR)
depth = cv2.resize(depth, (target_width, target_height),
interpolation=cv2.INTER_NEAREST)
label = cv2.resize(label, (target_width, target_height),
interpolation=cv2.INTER_NEAREST)
sample['image'] = image
sample['depth'] = depth
sample['label'] = label
return sample
class RandomCrop:
def __init__(self, crop_height, crop_width):
self.crop_height = crop_height
self.crop_width = crop_width
self.rescale = Rescale(self.crop_height, self.crop_width)
def __call__(self, sample):
image, depth, label = sample['image'], sample['depth'], sample['label']
h = image.shape[0]
w = image.shape[1]
if h <= self.crop_height or w <= self.crop_width:
sample = self.rescale(sample)
else:
i = np.random.randint(0, h - self.crop_height)
j = np.random.randint(0, w - self.crop_width)
image = image[i:i + self.crop_height, j:j + self.crop_width, :]
depth = depth[i:i + self.crop_height, j:j + self.crop_width]
label = label[i:i + self.crop_height, j:j + self.crop_width]
sample['image'] = image
sample['depth'] = depth
sample['label'] = label
return sample
class RandomHSV:
def __init__(self, h_range, s_range, v_range):
assert isinstance(h_range, (list, tuple)) and \
isinstance(s_range, (list, tuple)) and \
isinstance(v_range, (list, tuple))
self.h_range = h_range
self.s_range = s_range
self.v_range = v_range
def __call__(self, sample):
img = sample['image']
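        # on a 0..255 RGB image, rgb_to_hsv yields h and s in [0, 1] but v in
        # [0, 255], hence the different jitter and clipping ranges below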
img_hsv = matplotlib.colors.rgb_to_hsv(img)
img_h = img_hsv[:, :, 0]
img_s = img_hsv[:, :, 1]
img_v = img_hsv[:, :, 2]
h_random = np.random.uniform(min(self.h_range), max(self.h_range))
s_random = np.random.uniform(min(self.s_range), max(self.s_range))
v_random = np.random.uniform(-min(self.v_range), max(self.v_range))
img_h = np.clip(img_h * h_random, 0, 1)
img_s = np.clip(img_s * s_random, 0, 1)
img_v = np.clip(img_v + v_random, 0, 255)
img_hsv = np.stack([img_h, img_s, img_v], axis=2)
img_new = matplotlib.colors.hsv_to_rgb(img_hsv)
sample['image'] = img_new
return sample
class RandomFlip:
def __call__(self, sample):
image, depth, label = sample['image'], sample['depth'], sample['label']
if np.random.rand() > 0.5:
image = np.fliplr(image).copy()
depth = np.fliplr(depth).copy()
label = np.fliplr(label).copy()
sample['image'] = image
sample['depth'] = depth
sample['label'] = label
return sample
class Normalize:
def __init__(self, depth_mean, depth_std, depth_mode='refined'):
assert depth_mode in ['refined', 'raw']
self._depth_mode = depth_mode
self._depth_mean = [depth_mean]
self._depth_std = [depth_std]
def __call__(self, sample):
image, depth = sample['image'], sample['depth']
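        # scale RGB to [0, 1] and normalize with the usual ImageNet
        # statistics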
image = image / 255
image = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
if self._depth_mode == 'raw':
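            # remember invalid measurements (depth == 0) and restore them to
            # zero after normalization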
depth_0 = depth == 0
depth = torchvision.transforms.Normalize(
mean=self._depth_mean, std=self._depth_std)(depth)
depth[depth_0] = 0
else:
depth = torchvision.transforms.Normalize(
mean=self._depth_mean, std=self._depth_std)(depth)
sample['image'] = image
sample['depth'] = depth
return sample
class ToTensor:
def __call__(self, sample):
image, depth = sample['image'], sample['depth']
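        # HWC -> CHW for PyTorch; depth gets an explicit channel dimension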
image = image.transpose((2, 0, 1))
depth = np.expand_dims(depth, 0).astype('float32')
sample['image'] = torch.from_numpy(image).float()
sample['depth'] = torch.from_numpy(depth).float()
if 'label' in sample:
label = sample['label']
sample['label'] = torch.from_numpy(label).float()
return sample
class MultiScaleLabel:
def __init__(self, downsampling_rates=None):
if downsampling_rates is None:
self.downsampling_rates = [16, 8, 4]
else:
self.downsampling_rates = downsampling_rates
def __call__(self, sample):
label = sample['label']
h, w = label.shape
sample['label_down'] = dict()
for rate in self.downsampling_rates:
label_down = cv2.resize(label.numpy(), (w // rate, h // rate),
interpolation=cv2.INTER_NEAREST)
sample['label_down'][rate] = torch.from_numpy(label_down)
return sample
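
# --- src.prepare_data, continued ---------------------------------------------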
train_preprocessor = preprocessing.get_preprocessor(
height=args.height,
width=args.width,
depth_mean=train_data.depth_mean,
depth_std=train_data.depth_std,
depth_mode=depth_mode,
phase='train',
**train_preprocessor_kwargs
)
train_data.preprocessor = train_preprocessor
if ckpt_dir is not None:
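        # persist the depth statistics next to the checkpoints so that later
        # evaluation/inference reuses the exact training-time normalization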
pickle_file_path = os.path.join(ckpt_dir, 'depth_mean_std.pickle')
if os.path.exists(pickle_file_path):
with open(pickle_file_path, 'rb') as f:
depth_stats = pickle.load(f)
print(f'Loaded depth mean and std from {pickle_file_path}')
print(depth_stats)
else:
depth_stats = {'mean': train_data.depth_mean,
'std': train_data.depth_std}
with open(pickle_file_path, 'wb') as f:
pickle.dump(depth_stats, f)
else:
depth_stats = {'mean': train_data.depth_mean,
'std': train_data.depth_std}
valid_preprocessor = preprocessing.get_preprocessor(
height=args.height,
width=args.width,
depth_mean=depth_stats['mean'],
depth_std=depth_stats['std'],
depth_mode=depth_mode,
phase='test'
)
if args.valid_full_res:
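        # a second preprocessor without rescaling, for validating on images
        # at their original (full) resolution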
valid_preprocessor_full_res = preprocessing.get_preprocessor(
depth_mean=depth_stats['mean'],
depth_std=depth_stats['std'],
depth_mode=depth_mode,
phase='test'
)
valid_data = Dataset(
data_dir=args.dataset_dir,
split=valid_set,
depth_mode=depth_mode,
with_input_orig=with_input_orig,
**dataset_kwargs
)
valid_data.preprocessor = valid_preprocessor
if args.dataset_dir is None:
if args.valid_full_res:
return valid_data, valid_preprocessor_full_res
else:
return valid_data, valid_preprocessor
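    # drop_last=True discards the final incomplete batch so every step sees
    # a full batch (e.g. for stable BatchNorm statistics)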
train_loader = DataLoader(train_data,
batch_size=args.batch_size,
num_workers=args.workers,
drop_last=True,
shuffle=True)
batch_size_valid = args.batch_size_valid or args.batch_size
valid_loader = DataLoader(valid_data,
batch_size=batch_size_valid,
num_workers=args.workers,
shuffle=False)
if args.valid_full_res:
valid_loader_full_res = copy.deepcopy(valid_loader)
valid_loader_full_res.dataset.preprocessor = valid_preprocessor_full_res
return train_loader, valid_loader, valid_loader_full_res
return train_loader, valid_loader
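
# --- train_main, continued ----------------------------------------------------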
if args.valid_full_res:
train_loader, valid_loader, valid_loader_full_res = data_loaders
else:
train_loader, valid_loader = data_loaders
valid_loader_full_res = None
cameras = train_loader.dataset.cameras
n_classes_without_void = train_loader.dataset.n_classes_without_void
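    # optional class weighting to counter class imbalance; the string 'None'
    # disables it (cf. c_for_logarithmic_weighting for logarithmic weighting)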
if args.class_weighting != 'None':
class_weighting = train_loader.dataset.compute_class_weights(
weight_mode=args.class_weighting,
c=args.c_for_logarithmic_weighting)
else:
class_weighting = np.ones(n_classes_without_void)