A Walkthrough of the First Half of the ESANet Code

# -*- coding: utf-8 -*-
"""
.. codeauthor:: Mona Koehler 
.. codeauthor:: Daniel Seichter 
"""
import argparse
from datetime import datetime
import json
import pickle
import os
import sys
import time
import warnings

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim
from torch.optim.lr_scheduler import OneCycleLR
# from torch.optim.lr_scheduler import StepLR
# from src.lr_policy import PolyLR
from src.args import ArgumentParserRGBDSegmentation
from src.build_model import build_model
from src import utils
from src.prepare_data import prepare_data
from src.utils import save_ckpt, save_ckpt_every_epoch
from src.utils import load_ckpt
from src.utils import print_log
import tensorflow as tf
from src.logger import CSVLogger
from src.confusion_matrix import ConfusionMatrixTensorflow

# select GPU 2 and make CUDA kernel launches synchronous (easier debugging)
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# TensorFlow is only needed for ConfusionMatrixTensorflow; disable eager mode
tf.compat.v1.disable_eager_execution()


def parse_args():
    parser = ArgumentParserRGBDSegmentation(
        description='Efficient RGBD Indoor Semantic Segmentation (Training)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.set_common_args()
    args = parser.parse_args()

    # The provided learning rate refers to the default batch size of 8.
    # When using different batch sizes we need to adjust the learning rate
    # accordingly:
    # if args.batch_size != 8: #lr = 0.0025
    #     args.lr = args.lr * args.batch_size / 8
    #     warnings.warn(f'Adapting learning rate to {args.lr} because provided '
    #                   f'batch size differs from default batch size of 8.')

    return args


def train_main():
    args = parse_args()

    # directory for storing weights and other training related files
    training_starttime = datetime.now().strftime("%d_%m_%Y-%H_%M_%S-%f")
    # checkpoints are stored under results/<dataset>, e.g. results/nyuv2
    ckpt_dir = os.path.join(args.results_dir, args.dataset,
                            f'checkpoints_{training_starttime}')
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(os.path.join(ckpt_dir, 'confusion_matrices'), exist_ok=True)
    # vars() turns the args namespace into a dict of key/value pairs;
    # json.dump writes that dict to args.json in ckpt_dir, with the keys
    # sorted alphabetically and an indent of 4 spaces
    with open(os.path.join(ckpt_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)
    # write the full command line to argsv.txt in ckpt_dir: sys.argv holds
    # everything typed after `python`, including train.py itself.
    # [sys.argv explained](https://www.cnblogs.com/aland-1415/p/6613449.html)
    with open(os.path.join(ckpt_dir, 'argsv.txt'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')
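    # Side note: a minimal sketch (not part of ESANet) of how the saved
    # args.json could be restored later, e.g. for evaluation; reload_args is
    # a hypothetical helper:
    #
    #   def reload_args(ckpt_dir):
    #       with open(os.path.join(ckpt_dir, 'args.json')) as f:
    #           return argparse.Namespace(**json.load(f))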

    # when using multi-scale supervision, the labels need to be downsampled.
    label_downsampling_rates = [16, 8, 4]

    # data preparation ---------------------------------------------------------
    data_loaders = prepare_data(args, ckpt_dir)

    # source of prepare_data:
##########################################################################################################################
	# -*- coding: utf-8 -*-
"""
.. codeauthor:: Mona Koehler 
.. codeauthor:: Daniel Seichter 
"""
import copy
import os
import pickle
from torch.utils.data import DataLoader
from src import preprocessing
from src.datasets import Cityscapes
from src.datasets import NYUv2
from src.datasets import SceneNetRGBD
from src.datasets import SUNRGBD

def prepare_data(args, ckpt_dir=None, with_input_orig=False, split=None):
    train_preprocessor_kwargs = {}
    if args.dataset == 'sunrgbd':
        Dataset = SUNRGBD
        dataset_kwargs = {}
        valid_set = 'test'
    # the NYUv2 dataset
    elif args.dataset == 'nyuv2':
        Dataset = NYUv2
        dataset_kwargs = {'n_classes': 40}
        valid_set = 'test'
    elif args.dataset == 'cityscapes':
        Dataset = Cityscapes
        dataset_kwargs = {
            'n_classes': 19,
            'disparity_instead_of_depth': True
        }
        valid_set = 'valid'
    elif args.dataset == 'cityscapes-with-depth':
        Dataset = Cityscapes
        dataset_kwargs = {
            'n_classes': 19,
            'disparity_instead_of_depth': False
        }
        valid_set = 'valid'
    elif args.dataset == 'scenenetrgbd':
        Dataset = SceneNetRGBD
        dataset_kwargs = {'n_classes': 13}
        valid_set = 'valid'
        if args.width == 640 and args.height == 480:
            # for SceneNetRGBD, we additionally scale up the images by a
            # factor of 2
            train_preprocessor_kwargs['train_random_rescale'] = (1.0*2, 1.4*2)
    else:
        raise ValueError(f"Unknown dataset: `{args.dataset}`")
    # Random-rescale parameters for the image preprocessing stage; with the
    # default values (1.0 and 1.4) the branch below is skipped:
    # self.add_argument('--aug_scale_min', default=1.0, type=float,
    #                   help='the minimum scale for random rescaling the training data.')
    # self.add_argument('--aug_scale_max', default=1.4, type=float,
    #                   help='the maximum scale for random rescaling the training data.')
    if args.aug_scale_min != 1 or args.aug_scale_max != 1.4:
        train_preprocessor_kwargs['train_random_rescale'] = (
            args.aug_scale_min, args.aug_scale_max)

    if split in ['valid', 'test']:
        valid_set = split
    # args.raw_depth is False by default, so the refined depth is used
    if args.raw_depth:
        # We can not expect the model to predict depth values that are just
        # interpolated and not really there. It is better to let the model only
        # predict the measured depth values and ignore the rest.
        depth_mode = 'raw'
    else:
        depth_mode = 'refined'

    # train data (we follow the NYUv2 branch as the example here)
##################################################################################################################
# -*- coding: utf-8 -*-
"""
.. codeauthor:: Daniel Seichter 
"""
import os
import cv2
import numpy as np
from ..dataset_base import DatasetBase
from .nyuv2 import NYUv2Base

class NYUv2(NYUv2Base, DatasetBase):
    def __init__(self,
                 data_dir=None,
                 n_classes=40,
                 split='train',
                 depth_mode='refined',
                 with_input_orig=False):
        super(NYUv2, self).__init__()
        assert split in self.SPLITS
        assert n_classes in self.N_CLASSES
        assert depth_mode in ['refined', 'raw']

        self._n_classes = n_classes
        self._split = split
        self._depth_mode = depth_mode
        self._with_input_orig = with_input_orig
        self._cameras = ['kv1']
		# if a data directory is given, expand '~' and make sure it exists
        if data_dir is not None:
            data_dir = os.path.expanduser(data_dir)
            assert os.path.exists(data_dir)
            self._data_dir = data_dir

            # load filenames
            # fp = the split file list inside the data directory, here train.txt:
            # SPLIT_FILELIST_FILENAMES = {SPLITS[0]: 'train.txt', SPLITS[1]: 'test.txt'}
            fp = os.path.join(self._data_dir,
                              self.SPLIT_FILELIST_FILENAMES[self._split])
            # np.loadtxt reads train.txt; the entries come back as strings
            self._filenames = np.loadtxt(fp, dtype=str)
        else:
            print(f"Loaded {self.__class__.__name__} dataset without files")

        # load class names
        # getattr fetches the attribute CLASS_NAMES_40, i.e. the full list of
        # class names:
        ######################################################
            CLASS_NAMES_40 = ['void',
                      'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
                      'table', 'door', 'window', 'bookshelf', 'picture',
                      'counter', 'blinds', 'desk', 'shelves', 'curtain',
                      'dresser', 'pillow', 'mirror', 'floor mat', 'clothes',
                      'ceiling', 'books', 'refridgerator', 'television',
                      'paper', 'towel', 'shower curtain', 'box', 'whiteboard',
                      'person', 'night stand', 'toilet', 'sink', 'lamp',
                      'bathtub', 'bag',
                      'otherstructure', 'otherfurniture', 'otherprop']
        ######################################################
        self._class_names = getattr(self, f'CLASS_NAMES_{self._n_classes}')

        # load class colors
        # getattr fetches CLASS_COLORS_{n}, the color assigned to each class;
        # the shorter CLASS_COLORS_13 is quoted below as an example, for
        # n_classes=40 the analogous CLASS_COLORS_40 list is used:
        ######################################################
            CLASS_COLORS_13 = [[0, 0, 0],
                       [0, 0, 255],
                       [232, 88, 47],
                       [0, 217, 0],
                       [148, 0, 240],
                       [222, 241, 23],
                       [255, 205, 205],
                       [0, 223, 228],
                       [106, 135, 204],
                       [116, 28, 41],
                       [240, 35, 235],
                       [0, 166, 156],
                       [249, 139, 0],
                       [225, 228, 194]]
   		######################################################
        # np.array converts the color list to dtype 'uint8'
        self._class_colors = np.array(
            getattr(self, f'CLASS_COLORS_{self._n_classes}'),
            dtype='uint8'
        )
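        # Side note: since index 0 of the color table is void, a label map can
        # be colorized with a single fancy-indexing operation (illustrative
        # sketch, `dataset` is a hypothetical NYUv2 instance):
        #
        #   label = dataset.load_label(0)          # (H, W), values 0..40
        #   colored = dataset.class_colors[label]  # (H, W, 3), dtype uint8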

        # note that mean and std differ depending on the selected depth_mode
        # however, the impact is marginal, therefore, we decided to use the
        # stats for refined depth for both cases
        # stats for raw: mean: 2769.0187903686697, std: 1350.4174149841133
        # mean and standard deviation of the depth values
        self._depth_mean = 2841.94941272766
        self._depth_std = 1417.259428167227
       
    @property
    def cameras(self):
        return self._cameras
    # the @property decorators below expose the private attributes read-only

    @property
    def class_names(self):
        return self._class_names

    @property
    def class_names_without_void(self):
        return self._class_names[1:]

    @property
    def class_colors(self):
        return self._class_colors

    @property
    def class_colors_without_void(self):
        return self._class_colors[1:]

    @property
    def n_classes(self):
        return self._n_classes + 1

    @property
    def n_classes_without_void(self):
        return self._n_classes

    @property
    def split(self):
        return self._split

    @property
    def depth_mode(self):
        return self._depth_mode

    @property
    def depth_mean(self):
        return self._depth_mean

    @property
    def depth_std(self):
        return self._depth_std

    @property
    def source_path(self):
        return os.path.abspath(os.path.dirname(__file__))

    @property
    def with_input_orig(self):
        return self._with_input_orig
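    # Side note (illustrative; `dataset` is a hypothetical NYUv2 instance):
    #
    #   dataset.n_classes               # 41 = 40 classes + void
    #   dataset.n_classes_without_void  # 40
    #   dataset.class_names[0]          # 'void'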
    # helper: load <data_dir>/<split>/<directory>/<filename>.png;
    # if the image has 3 dimensions (i.e. a color image), convert BGR to RGB
    def _load(self, directory, filename):
        fp = os.path.join(self._data_dir,
                          self.split,
                          directory,
                          f'{filename}.png')
        # IMREAD_UNCHANGED keeps the image exactly as stored (BGR for color)
        im = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
        if im.ndim == 3:
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        return im
        
    # load an RGB image via _load:
    # directory = self.RGB_DIR = 'RGB', filename = self._filenames[idx],
    # i.e. an image of the split folder as listed in train.txt
    def load_image(self, idx):
        return self._load(self.RGB_DIR, self._filenames[idx])
        
    # load a depth image via _load (raw or refined directory)
    def load_depth(self, idx):
        if self._depth_mode == 'raw':
            return self._load(self.DEPTH_RAW_DIR, self._filenames[idx])
        else:
            return self._load(self.DEPTH_DIR, self._filenames[idx])
            
    # load a label via _load:
    # directory = self.LABELS_DIR_FMT.format(self._n_classes); with
    # LABELS_DIR_FMT = 'labels_{:d}' this resolves to labels_40, so the
    # (non-colorized) labels are read from the labels_40 folder of the split
    def load_label(self, idx):
        return self._load(self.LABELS_DIR_FMT.format(self._n_classes),
                          self._filenames[idx])
                          
    # the dataset length is the number of entries in train.txt / test.txt
    def __len__(self):
        return len(self._filenames)
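    # Side note: __getitem__ is inherited from DatasetBase (not quoted here).
    # Roughly, it builds a sample dict from the loaders above and pipes it
    # through the assigned preprocessor; a simplified sketch of the idea:
    #
    #   def __getitem__(self, idx):
    #       sample = {'image': self.load_image(idx),
    #                 'depth': self.load_depth(idx),
    #                 'label': self.load_label(idx)}
    #       return self.preprocessor(sample)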

  	##################################################################################################################
    # Dataset = NYUv2 here, so this call instantiates NYUv2 with the
    # arguments below; all NYUv2 methods and properties (including the
    # @property getters) are then available on train_data.
    train_data = Dataset(
        data_dir=args.dataset_dir,
        split='train',
        depth_mode=depth_mode,
        with_input_orig=with_input_orig,
        **dataset_kwargs
    )
    # training-data preprocessing: preprocessing.get_preprocessor builds the
    # transform pipeline shown next
	###########################################################################################
	# -*- coding: utf-8 -*-
"""
.. codeauthor:: Mona Koehler 
.. codeauthor:: Daniel Seichter 

This code is partially adapted from RedNet
(https://github.com/JindongJiang/RedNet/blob/master/RedNet_data.py)
"""
import cv2
import matplotlib
import matplotlib.colors
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
def get_preprocessor(depth_mean,
                     depth_std,
                     depth_mode='refined',
                     height=None,
                     width=None,
                     phase='train',  # training phase
                     train_random_rescale=(1.0, 1.4)):
    assert phase in ['train', 'test']
	
    # in the training phase, apply the augmentation pipeline below
    if phase == 'train':
        transform_list = [
            RandomRescale(train_random_rescale),
            RandomCrop(crop_height=height, crop_width=width),
            RandomHSV((0.9, 1.1),
                      (0.9, 1.1),
                      (25, 25)),
            RandomFlip(),
            ToTensor(),
            Normalize(depth_mean=depth_mean,
                      depth_std=depth_std,
                      depth_mode=depth_mode),
            MultiScaleLabel(downsampling_rates=[16, 8, 4])
        ]
    # test phase
    else:
        if height is None and width is None:
            transform_list = []
        else:  # first rescale, then convert to tensor, then normalize
            transform_list = [Rescale(height=height, width=width)]
        transform_list.extend([
            ToTensor(),
            Normalize(depth_mean=depth_mean,
                      depth_std=depth_std,
                      depth_mode=depth_mode)
        ])
    # finally transforms.Compose chains all transforms together
    transform = transforms.Compose(transform_list)
    return transform
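# Side note: a minimal usage sketch (illustrative only). In the test phase the
# pipeline is Rescale -> ToTensor -> Normalize, so a plain dict of NumPy
# arrays goes in and a dict of tensors comes out:
#
#   pre = get_preprocessor(depth_mean=2841.95, depth_std=1417.26,
#                          height=480, width=640, phase='test')
#   sample = {'image': np.zeros((530, 730, 3), dtype='uint8'),
#             'depth': np.zeros((530, 730), dtype='float32')}
#   out = pre(sample)
#   # out['image']: (3, 480, 640), out['depth']: (1, 480, 640)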
   
   
class Rescale:
    def __init__(self, height, width):
        self.height = height
        self.width = width
    # __call__ makes the transform object callable, which is what
    # transforms.Compose expects from every step
    def __call__(self, sample): 
        image, depth = sample['image'], sample['depth']
        # resize the RGB image with bilinear interpolation and the depth map
        # with nearest-neighbor interpolation; note cv2.resize takes
        # (width, height)
        image = cv2.resize(image, (self.width, self.height),
                           interpolation=cv2.INTER_LINEAR)
        depth = cv2.resize(depth, (self.width, self.height),
                           interpolation=cv2.INTER_NEAREST)

        sample['image'] = image
        sample['depth'] = depth

        if 'label' in sample:
            label = sample['label']
            # rescale the label to the same size as image and depth
            label = cv2.resize(label, (self.width, self.height),
                               interpolation=cv2.INTER_NEAREST)
            sample['label'] = label

        return sample
        
# randomly rescale the whole sample
class RandomRescale:
    def __init__(self, scale):
        self.scale_low = min(scale)
        self.scale_high = max(scale)

    def __call__(self, sample):
        image, depth, label = sample['image'], sample['depth'], sample['label']
        # draw a random scale factor in [scale_low, scale_high]
        target_scale = np.random.uniform(self.scale_low, self.scale_high)
        # multiply the image height and width by the scale factor, then round
        # and cast to int
        target_height = int(round(target_scale * image.shape[0]))
        target_width = int(round(target_scale * image.shape[1]))
        # resize to the target size: bilinear for the image, nearest-neighbor
        # for depth and label
        image = cv2.resize(image, (target_width, target_height),
                           interpolation=cv2.INTER_LINEAR)
        depth = cv2.resize(depth, (target_width, target_height),
                           interpolation=cv2.INTER_NEAREST)
        label = cv2.resize(label, (target_width, target_height),
                           interpolation=cv2.INTER_NEAREST)

        sample['image'] = image
        sample['depth'] = depth
        sample['label'] = label

        return sample
# random crop
class RandomCrop:
    def __init__(self, crop_height, crop_width):
        self.crop_height = crop_height
        self.crop_width = crop_width
        # fallback: rescale the sample directly to the crop size
        self.rescale = Rescale(self.crop_height, self.crop_width)

    def __call__(self, sample):
        image, depth, label = sample['image'], sample['depth'], sample['label']
        h = image.shape[0]
        w = image.shape[1]
        # if the image is not larger than the crop size, simply rescale it
        if h <= self.crop_height or w <= self.crop_width:
            # simply rescale instead of random crop as image is not large enough
            sample = self.rescale(sample)
        # otherwise the image is larger than the crop size:
        else:
            # pick a random top-left corner for the crop
            i = np.random.randint(0, h - self.crop_height)
            j = np.random.randint(0, w - self.crop_width)
            image = image[i:i + self.crop_height, j:j + self.crop_width, :]
            depth = depth[i:i + self.crop_height, j:j + self.crop_width]
            label = label[i:i + self.crop_height, j:j + self.crop_width]
            sample['image'] = image
            sample['depth'] = depth
            sample['label'] = label
        return sample
        
class RandomHSV:
    def __init__(self, h_range, s_range, v_range):
        assert isinstance(h_range, (list, tuple)) and \
               isinstance(s_range, (list, tuple)) and \
               isinstance(v_range, (list, tuple))
        self.h_range = h_range
        self.s_range = s_range
        self.v_range = v_range

    def __call__(self, sample):
        img = sample['image']
        img_hsv = matplotlib.colors.rgb_to_hsv(img)
        img_h = img_hsv[:, :, 0]
        img_s = img_hsv[:, :, 1]
        img_v = img_hsv[:, :, 2]

        # hue and saturation are scaled multiplicatively (clipped to [0, 1]),
        # value is shifted additively (clipped to [0, 255])
        h_random = np.random.uniform(min(self.h_range), max(self.h_range))
        s_random = np.random.uniform(min(self.s_range), max(self.s_range))
        v_random = np.random.uniform(-min(self.v_range), max(self.v_range))
        img_h = np.clip(img_h * h_random, 0, 1)
        img_s = np.clip(img_s * s_random, 0, 1)
        img_v = np.clip(img_v + v_random, 0, 255)
        img_hsv = np.stack([img_h, img_s, img_v], axis=2)
        img_new = matplotlib.colors.hsv_to_rgb(img_hsv)

        sample['image'] = img_new

        return sample
class RandomFlip:
    def __call__(self, sample):
        image, depth, label = sample['image'], sample['depth'], sample['label']
        # with probability 0.5, flip image, depth, and label left-right
        if np.random.rand() > 0.5:
            image = np.fliplr(image).copy()
            depth = np.fliplr(depth).copy()
            label = np.fliplr(label).copy()

        sample['image'] = image
        sample['depth'] = depth
        sample['label'] = label

        return sample
# normalization
class Normalize:
    def __init__(self, depth_mean, depth_std, depth_mode='refined'):
        assert depth_mode in ['refined', 'raw']
        self._depth_mode = depth_mode #'refined'
        self._depth_mean = [depth_mean]
        self._depth_std = [depth_std]

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = image / 255
        # normalize each RGB channel with the ImageNet mean and std
        image = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
        if self._depth_mode == 'raw':
            depth_0 = depth == 0

            depth = torchvision.transforms.Normalize(
                mean=self._depth_mean, std=self._depth_std)(depth)

            # set invalid values back to zero again
            depth[depth_0] = 0

        else:
            # depth has only a single channel
            depth = torchvision.transforms.Normalize(
                mean=self._depth_mean, std=self._depth_std)(depth)
        # depth = depth / torch.max(depth)

        sample['image'] = image
        sample['depth'] = depth
        return sample
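# Side note: with the refined NYUv2 stats above, normalization maps a raw
# depth reading d to (d - 2841.95) / 1417.26, so e.g. d = 4259.21 becomes
# roughly +1.0 and d = 1424.69 roughly -1.0.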
        
# convert numpy arrays to tensors
class ToTensor:
    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = image.transpose((2, 0, 1))  # HWC -> CHW
        depth = np.expand_dims(depth, 0).astype('float32')
        # convert from numpy to torch tensors
        sample['image'] = torch.from_numpy(image).float()
        sample['depth'] = torch.from_numpy(depth).float()

        if 'label' in sample:
            label = sample['label']
            sample['label'] = torch.from_numpy(label).float()

        return sample

# multi-scale labels, used for the deep supervision
class MultiScaleLabel:
    def __init__(self, downsampling_rates=None):
        if downsampling_rates is None:
            self.downsampling_rates = [16, 8, 4]
        else:
            self.downsampling_rates = downsampling_rates

    def __call__(self, sample):
        label = sample['label']

        h, w = label.shape

        sample['label_down'] = dict()

        # Nearest neighbor interpolation
        for rate in self.downsampling_rates:
            # resize the label to (w // rate, h // rate)
            label_down = cv2.resize(label.numpy(), (w // rate, h // rate),
                                    interpolation=cv2.INTER_NEAREST)
            # convert label_down to a tensor; one entry per downsampling rate
            sample['label_down'][rate] = torch.from_numpy(label_down)

        return sample
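# Side note: at the default NYUv2 training resolution of 480x640, the
# downsampled label sizes are 30x40 (rate 16), 60x80 (rate 8) and
# 120x160 (rate 4), one per side output used for the multi-scale supervision.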

	###########################################################################################
   
    # the object returned by get_preprocessor applies the whole chain of
    # transforms defined above
    train_preprocessor = preprocessing.get_preprocessor(
        height=args.height,
        width=args.width,
        depth_mean=train_data.depth_mean,
        depth_std=train_data.depth_std,
        depth_mode=depth_mode,
        phase='train',
        **train_preprocessor_kwargs
    )
    # attach the preprocessor to the NYUv2 train split
    train_data.preprocessor = train_preprocessor

    if ckpt_dir is not None:
        pickle_file_path = os.path.join(ckpt_dir, 'depth_mean_std.pickle')
        if os.path.exists(pickle_file_path):
            with open(pickle_file_path, 'rb') as f:
                depth_stats = pickle.load(f)
            print(f'Loaded depth mean and std from {pickle_file_path}')
            print(depth_stats)
        else:
            # dump depth stats
            depth_stats = {'mean': train_data.depth_mean,
                           'std': train_data.depth_std}
            with open(pickle_file_path, 'wb') as f:
                pickle.dump(depth_stats, f)
    else:
        depth_stats = {'mean': train_data.depth_mean,
                       'std': train_data.depth_std}

    # valid data
    valid_preprocessor = preprocessing.get_preprocessor(
        height=args.height,
        width=args.width,
        depth_mean=depth_stats['mean'],
        depth_std=depth_stats['std'],
        depth_mode=depth_mode,
        phase='test'
    )

    if args.valid_full_res:
        valid_preprocessor_full_res = preprocessing.get_preprocessor(
            depth_mean=depth_stats['mean'],
            depth_std=depth_stats['std'],
            depth_mode=depth_mode,
            phase='test'
        )

    valid_data = Dataset(
        data_dir=args.dataset_dir,
        split=valid_set,
        depth_mode=depth_mode,
        with_input_orig=with_input_orig,
        **dataset_kwargs
    )

    valid_data.preprocessor = valid_preprocessor

    if args.dataset_dir is None:
        # no path to the actual data was passed -> we cannot create dataloader,
        # return the valid dataset and preprocessor object for inference only
        if args.valid_full_res:
            return valid_data, valid_preprocessor_full_res
        else:
            return valid_data, valid_preprocessor

    # create the data loaders
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              drop_last=True,
                              shuffle=True)

    # for validation we can use higher batch size as activations do not
    # need to be saved for the backwards pass
    batch_size_valid = args.batch_size_valid or args.batch_size
    valid_loader = DataLoader(valid_data,
                              batch_size=batch_size_valid,
                              num_workers=args.workers,
                              shuffle=False)

    if args.valid_full_res:
        valid_loader_full_res = copy.deepcopy(valid_loader)
        valid_loader_full_res.dataset.preprocessor = valid_preprocessor_full_res
        return train_loader, valid_loader, valid_loader_full_res

    return train_loader, valid_loader
#####################################################################################################################
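    # Side note: a sketch of what one training batch looks like (illustrative;
    # shapes assume the default 480x640 crop and default collation of the
    # nested label_down dict):
    #
    #   batch = next(iter(train_loader))
    #   batch['image'].shape           # (B, 3, 480, 640)
    #   batch['depth'].shape           # (B, 1, 480, 640)
    #   batch['label'].shape           # (B, 480, 640)
    #   batch['label_down'][16].shape  # (B, 30, 40)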
    if args.valid_full_res:
        train_loader, valid_loader, valid_loader_full_res = data_loaders
    else:
        train_loader, valid_loader = data_loaders
        valid_loader_full_res = None

    cameras = train_loader.dataset.cameras
    n_classes_without_void = train_loader.dataset.n_classes_without_void
    if args.class_weighting != 'None':
        class_weighting = train_loader.dataset.compute_class_weights(
            weight_mode=args.class_weighting,
            c=args.c_for_logarithmic_weighting)
    else:
        class_weighting = np.ones(n_classes_without_void)
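    # Side note: a minimal sketch of logarithmic class weighting (one of the
    # supported weight modes); assuming `class_probs` holds the pixel
    # frequency of each class without void, the ENet-style formula is:
    #
    #   class_weighting = 1 / np.log(args.c_for_logarithmic_weighting + class_probs)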

    # model building -----------------------------------------------------------
    
