[高光谱] (6w字巨详细) GitHub开源项目Hyperspectral-Classification的解析

文章目录

  • 项目简介
  • 项目各模块和函数的解析
  • utils.py
      • get_device(ordinal)
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • open_file(dataset)
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出
          • 代码:
          • 解析:
      • convert_to_color_()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • convert_from_color_()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • display_predictions()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • display_dataset()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • explore_spectrums()
      • plot_spectrums()
      • build_dataset()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • get_random_pos()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析
      • sliding_window()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • count_sliding_window()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • grouper()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • metrics()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
            • 计算混淆矩阵
            • 计算分类准确率
            • 计算F1 score
            • 计算kappa系数
            • 返回结果
      • show_results()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • sample_gt()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码:
          • 解析:
      • compute_imf_weights()
      • camel_to_snake()
  • module.py
      • _addindent()
      • class Module(object)
  • model.py
    • class Baseline(nn.Module)
      • 属性:
      • 方法:
        • weight_init()
        • __ init__()
        • forward(self, x)
    • class HuEtAl(nn.Module)
      • 属性:
      • 方法:
        • weight_init()
        • _get_final_flattened_size()
        • __ init__()
        • forward()
    • get_model()
        • 功能:
        • 输入和输出:
          • 输入:
          • 输出:
        • 代码和解析:
    • val()
        • 功能:
        • 输入和输出:
          • 输入:
          • 输出:
        • 代码和解析:
          • 函数定义
          • 初始化和预操作
          • 开始检测
            • 保持梯度
            • 选择device
            • 根据不同方式获取预测值
            • 统计accuracy和total
            • 返回 accuracy / total
    • save_model()
    • train()
        • 功能:
        • 输入和输出:
          • 输入:
          • 输出:
        • 代码和解析:
          • 函数的定义和信息
          • 损失函数的鲁棒性检测
          • 初始化部分变量
          • 训练网络
            • 开始epoch循环
            • 设置模型为训练模式
            • `epoch`中按`batch`训练
            • `data` `target` to `device`
            • 正向传播
            • 反向传播
            • 计算损失
            • 绘制Training loss和Validation accuracy曲线
        • vis.line
            • 迭代变量加一
            • 回收无用变量
          • 计算 avg_loss,val_accuracies,metric
          • Save the weights
    • test()
        • 功能:
        • 输入和输出:
          • 输入:
          • 输出:
        • 代码和解析:
          • 函数定义
          • 模型设置为test模式
          • 提取超参数
          • 初始化返回结果 probs
          • 计算迭代总数 iterations
          • 开始迭代
          • 提取数据:
          • 获取预测值 output
          • 统计结果
  • inference.py
  • datasets.py
      • DATASETS_CONFIG + 更新
          • 数据集配置
          • 更新数据集配置
      • class TqdmUpTo(tqdm)
      • get_dataset()
          • 输入:
          • 输出:
        • 代码和解析:
          • 初始化参数:
          • 下载数据集:
          • 读取数据集+预处理:
            • 数据集读取:
            • 处理NaN的情况:
            • Normalization 归一化:
            • 返回值:
      • class HyperX(torch.utils.data.Dataset)
        • __ init__(self, data, gt, **hyperparams):
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码和解析:
            • 读取img、gt和超参数
            • 监督方式:
            • 获取索引:
        • flip(*arrays)
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码和解析:
        • radiation_noise()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码和解析:
        • mixture_noise()
        • __ len__()
        • __ getitem__()
          • 功能:
          • 输入和输出:
            • 输入:
            • 输出:
          • 代码和解析:
            • 获取图像块:
            • 数据增强:
            • `data`和`label`转为**ndarray**类型:
            • **ndarray**转**tensor**:
            • Extract the center label if needed:
            • 返回值:
            • 数据增强:
            • `data`和`label`转为**ndarray**类型:
            • **ndarray**转**tensor**:
            • Extract the center label if needed:
            • 返回值:

GitHub链接: Hyperspectral-Classification Pytorch。

项目简介

项目的作者是Xidian university,是基于PyTorch的高光谱图像地物目标的分类程序。该项目兼容Python 2.7和Python 3.5+,基于PyTorch深度学习和GPU计算框架,并使用Visdom可视化服务器。

预定义的公开的数据集有:

  • 帕维亚大学
  • 帕维亚中心
  • 肯尼迪航天中心
  • 印度松树
  • 博茨瓦纳

用户也可添加自定义的数据集,示例是“数据融合大赛2018的高光谱数据集”DFC2018_HSI。开发人员应该为CUSTOM_DATASETS_CONFIG变量添加一个新条目,并为其用例定义特定的数据加载器。

该工具实现了scikit-learn库中的几个SVM变体以及PyTorch中实现的许多最先进的深度网络:

  • SVM(带网格搜索的线性,RBF和多核)
  • SGD(使用随机梯度下降的线性SVM进行快速优化)
    基线神经网络(4个完全连接的层,有丢失)
  • 1D CNN(用于高光谱图像分类的深度卷积神经网络,Hu等人,Journal of Sensors 2015)
  • 半监督的1D CNN(Autoencodeurs pour la visualization d’images hyperspectrales,Boulch et al。,GRETSI 2017)
  • 2D CNN(用于图像分类和频带选择的高光谱CNN,应用于人脸识别,Sharma等,技术报告2018)
  • 半监督2D CNN(用于高光谱图像分类的半监督卷积神经网络,Liu等,遥感信函2017)
  • 3D CNN(用于遥感图像分类的三维深度学习方法,Hamida等,TGRS 2018)
  • 3D FCN(基于上下文深度CNN的高光谱分类,Lee和Kwon,IGARSS 2016)
  • 3D CNN(基于卷积神经网络的深度特征提取和高光谱图像分类,Chen等,TGRS 2016)
  • 3D CNN(三维卷积神经网络的高光谱图像的光谱 - 空间分类,Li等,遥感2017)
  • 3D CNN(HSI-CNN:用于高光谱图像的新型卷积神经网络,Luo等,ICPR 2018)
  • 多尺度3D CNN(用于高光谱图像分类的多尺度3D深度卷积神经网络,He等,ICIP 2017)

用户也可以通过修改models.py文件来添加自定义深层网络。这意味着为自定义深层网络创建一个新类并更改该get_model功能。


项目各模块和函数的解析

utils.py


get_device(ordinal)

功能:

根据输入参数,判断device为CPU或GPU。

输入和输出:
输入:
  • ordinal:一个int类型的数,表示用哪个GPU
输出:
  • device:一个超参数,表示运算的位置(CPU or GPU)
代码:
def get_device(ordinal):
    # Use GPU ?
    if ordinal < 0:
        print("Computation on CPU")
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        print("Computation on CUDA GPU device {}".format(ordinal))
        device = torch.device('cuda:{}'.format(ordinal))
    else:
        print("/!\\ CUDA was requested but is not available! Computation will go on CPU. /!\\")
        device = torch.device('cpu')
    return device
解析:

其实就是一个简单的分支结构:

  • ordinal < 0:CPU
  • ordinal < 0orch.cuda.is_available() == True:GPU
  • ordinal < 0orch.cuda.is_available() == False:CPU

open_file(dataset)

功能:

打开指定的数据集的文件。

输入和输出:
输入:
  • dataset:数据集文件的完整路径,比如C:\Datasets\OwnData\OwnData.mat
输出

(以读取.mat为例,因为读取的以.mat文件居多):

  • 一个以变量名,以数据的字典dictionary
代码:
def open_file(dataset):
    _, ext = os.path.splitext(dataset)
    ext = ext.lower()
    if ext == '.mat':
        # Load Matlab array
        return io.loadmat(dataset)
    elif ext == '.tif' or ext == '.tiff':
        # Load TIFF file
        return misc.imread(dataset)
    elif ext == '.hdr':
        img = spectral.open_image(dataset)
        return img.load()
    else:
        raise ValueError("Unknown file format: {}".format(ext))
解析:

最重要的是 _, ext = os.path.splitext(dataset)中的os.path.splitext(path)函数。
该函数将输入的路径path拆分为文件名 + 扩展名,并依次作为返回值。_, ext表示只获取扩展名,存入变量ext。之后就是根据不同的扩展名选择不同的打开方式。

需要注意的是,打开.mat文件,返回值是一个以变量名,以数据的字典dictionary。要取出其中的数据,需要通过字典操作,通过访问来获取,比如img = open_file(folder + 'OwnData.mat')['Data']


convert_to_color_()

功能:

将标签数组转换为RGB颜色编码图像。

输入和输出:
输入:
  • arr_2d: int类型的二维的标签数组(int 2D array of labels)
  • palette: 每个标签对应的RGB元组,三个值(dict of colors used (label number -> RGB tuple) )
输出:
  • int RGB格式的彩色编码标签的2D图像(int 2D images of color-encoded labels in RGB format)
代码:
def convert_to_color_(arr_2d, palette=None):
    """Convert an array of labels to RGB color-encoded image.

    Args:
        arr_2d: int 2D array of labels
        palette: dict of colors used (label number -> RGB tuple)    # 哪个标签对应什么样的颜色(RGB三个值)

    Returns:
        arr_3d: int 2D images of color-encoded labels in RGB format     # RGB三通道图像

    """
    arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)    # 确定维度和编码方式,行列为arr_2d的行列,编码方式为uint8
    # 异常报错
    if palette is None:
        raise Exception("Unknown color palette")

    for c, i in palette.items():
        m = arr_2d == c
        arr_3d[m] = i

    return arr_3d
解析:

(暂略)

convert_from_color_()

功能:

RGB编码图像转换为灰度标签。

输入和输出:
输入:
  • arr_3d: int 2D image of color-coded labels on 3 channels
  • palette:dict of colors used (RGB tuple -> label number)
输出:
  • arr_2d: int 2D array of labels
代码:
def convert_from_color_(arr_3d, palette=None):
    """Convert an RGB-encoded image to grayscale labels.

    Args:
        arr_3d: int 2D image of color-coded labels on 3 channels
        palette: dict of colors used (RGB tuple -> label number)

    Returns:
        arr_2d: int 2D array of labels

    """
    if palette is None:
        raise Exception("Unknown color palette")

    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for c, i in palette.items():
        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
        arr_2d[m] = i

    return arr_2d
解析:

(暂略)

display_predictions()

功能:

使用visdom可视化服务来可视化预测结果。

输入和输出:
输入:
  • pred:预测结果,二维
  • vis:vis服务
  • gt:ground truth
  • caption:图表名称
输出:
  • visdom服务器网址显示图表
代码:
def display_predictions(pred, vis, gt=None, caption=""):        # caption 字幕
    if gt is None:
        vis.images([np.transpose(pred, (2, 0, 1))],
                    opts={'caption': caption})
    else:
        vis.images([np.transpose(pred, (2, 0, 1)),
                    np.transpose(gt, (2, 0, 1))],
                    nrow=2,
                    opts={'caption': caption})
解析:

函数整体是一个简单的分支结构,分为gt is Nonegt is not None两种情况。

gt is None时:
vis.images()函数绘制一个列表images。它需要一个输入B x C x H x WC->channelH->heightW->width)张量或list of images全部相同的大小。它使大小的图像(B / Nrow,Nrow)的网格。

vis.images()的可调参数如下:

  • nrow:连续的图像数量
  • padding:在图像周围填充,四边均匀填充
  • opts.jpgquality:JPG质量(number0-100;默认= 100)
  • opts.caption:图像的标题

所以vis.images([np.transpose(pred, (2, 0, 1))], opts={'caption': caption})中主要有两部分

  • [np.transpose(pred, (2, 0, 1))]表示要可视化的图。
  • opts={'caption': caption}表示可选的操作。

np.transpose()是交换矩阵维度的数组(详见另一篇博客维度交换函数——a.transpose(m,n,r)),因为原始的图像的维度排序默认是H×W×C,而vis.images()要求的是C×H×W。所以张量维度的维度排序就要从(0,1,2)变为(2,0,1),而这通过np.transpose()函数实现。

opts={'caption': caption}就是给图加标题。

gt is not None时:
predgt都要通过np.transpose()函数来进行维度交换,另外参数nrow需要指定为2

display_dataset()

功能:

选择3个波段作为RGB波段,显示RGB合成图像。

输入和输出:
输入:
  • img: 3D hyperspectral image
  • gt: 2D array labels
  • bands: tuple of RGB bands to select
  • labels: list of label class names
  • palette: dict of colors
  • display (optional): type of display, if any

但是,只有imgbands这两个变量被用到了,其他4个变量都没有用到。

输出:
  • visdom服务器网址显示图表
代码:
def display_dataset(img, gt, bands, labels, palette, vis):
    """Display the specified dataset.

    Args:
        img: 3D hyperspectral image
        gt: 2D array labels
        bands: tuple of RGB bands to select
        labels: list of label class names
        palette: dict of colors
        display (optional): type of display, if any

    """
    print("Image has dimensions {}x{} and {} channels".format(*img.shape))
    rgb = spectral.get_rgb(img, bands)          # 从SpyFile对象或numpy数组中提取RGB数据以供显示。
    rgb /= np.max(rgb)                          # 最大值化处理
    rgb = np.asarray(255 * rgb, dtype='uint8')  # 转为ndarray类型

    # Display the RGB composite image       显示RGB合成图像
    caption = "RGB (bands {}, {}, {})".format(*bands)       # *来拆解变量
    # send to visdom server
    vis.images([np.transpose(rgb, (2, 0, 1))],
                opts={'caption': caption})
解析:

首先通过rgb = spectral.get_rgb(img, bands)来获取img中的指定的波段bands,来作为RGB波段。然后最大值化处理rgb /= np.max(rgb) ,将数值放缩到[0,1]之间。然后通过rgb = np.asarray(255 * rgb, dtype='uint8')来将rgb放缩到[0,255]之间,同时设定dtype=‘uint8’,即uint8编码的RGB图像

之后就是visdom server的操作,先设置标题 caption = "RGB (bands {}, {}, {})".format(*bands),其中format(*bands)通过*将列表类型(我猜的)的bands拆解,分别输出。然后调用vis.images()来将rgb可视化,参数解析见上面的display_predictions()函数部分。


explore_spectrums()

(暂略)


plot_spectrums()

(暂略)


build_dataset()

功能:

根据图像和蒙版创建训练样本列表。

输入和输出:
输入:
  • mat: 3D hyperspectral matrix to extract the spectrums from # 用来提取光谱的高光谱矩阵
  • gt: 2D ground truth
  • ignored_labels (optional): list of classes to ignore, e.g. 0 to remove
输出:
  • 根据图像和蒙版创建训练样本列表(Create a list of training samples based on an image and a mask.)
代码:
ef build_dataset(mat, gt, ignored_labels=None):
    """Create a list of training samples based on an image and a mask.

    Args:
        mat: 3D hyperspectral matrix to extract the spectrums from      # 用来提取光谱的高光谱矩阵
        gt: 2D ground truth
        ignored_labels (optional): list of classes to ignore, e.g. 0 to remove
        unlabeled pixels
        return_indices (optional): bool set to True to return the indices of
        the chosen samples

    """
    samples = []
    labels = []
    # Check that image and ground truth have the same 2D dimensions
    assert mat.shape[:2] == gt.shape[:2]    # 检查维度是否相符,比如PaviaU的mat和gt都是(610, 340)

    for label in np.unique(gt):
        if label in ignored_labels:
            continue
        else:
            indices = np.nonzero(gt == label)       # 返回同一类标签的全部索引。(对gt每个元素判断是否为label,是的话为1否则为0,然后提取全部的非零元素的索引
            samples += list(mat[indices])
            labels += len(indices[0]) * [label]
    return np.asarray(samples), np.asarray(labels)
解析:

首先检查数组维度是否相同,通过assert关键字实现,其中assert condition 等于 if not condition: raise AssertionError()
assert mat.shape[:2] == gt.shape[:2]来检查matgt的前两个维度是否相同。
mat.shape[:2]为提取mat数组的前两个维度。这里专门强调一个事情,通过open_file()读取得到的tensor,维度的排序是H × W × Channel

np.unique()函数返回值为The sorted unique values,类型为ndarray
之后遍历np.unique()的返回值,通过np.nonzero(gt == label)获取每次遍历的gtgt == label的元素的索引,返回为indices

之后将mat中对应索引的元素通过samples += list(mat[indices])来扩充到samples中。

下面通过一个简单的实例来说明:

import random
import numpy as np

mat = np.array([[0,0,0,0,0],[0,100,200,300,0],[0,200,300,200,0],[0,300,200,100,0],[0,0,0,0,0]])
gt = np.array([[0,0,0,0,0],[0,1,2,3,0],[0,2,3,2,0],[0,3,2,1,0],[0,0,0,0,0]])
ignored_labels = [0]
samples = []
labels = []

# Check that image and ground truth have the same 2D dimensions
assert mat.shape[:2] == gt.shape[:2]    # 检查维度是否相符,比如PaviaU的mat和gt都是(610, 340)

for label in np.unique(gt):
    if label in ignored_labels:
        continue
    else:
        indices = np.nonzero(gt == label)       # 返回同一类标签的全部索引。(对gt每个元素判断是否为label,是的话为1否则为0,然后提取全部的非零元素的索引
        samples += list(mat[indices])
        labels += len(indices[0]) * [label]
print(mat)
# [[  0   0   0   0   0]
#  [  0 100 200 300   0]
#  [  0 200 300 200   0]
#  [  0 300 200 100   0]
#  [  0   0   0   0   0]]
print(gt)
# [[0 0 0 0 0]
#  [0 1 2 3 0]
#  [0 2 3 2 0]
#  [0 3 2 1 0]
#  [0 0 0 0 0]]
print(samples)
# [100, 100, 200, 200, 200, 200, 300, 300, 300]
print(labels)
# [1, 1, 2, 2, 2, 2, 3, 3, 3]

get_random_pos()

功能:

随机返回输入图像的一个corner(Return the corners of a random window in the input image)

输入和输出:
输入:
  • img: 2D (or more) image, e.g. RGB or grayscale image
  • window_shape: (width, height) tuple of the window
输出:
  • xmin, xmax, ymin, ymax: tuple of the corners of the window

    代表corner位置的两个点(左下角和右上角),表现为两个参数。

代码:
def get_random_pos(img, window_shape):
    """ Return the corners of a random window in the input image

    Args:
        img: 2D (or more) image, e.g. RGB or grayscale image
        window_shape: (width, height) tuple of the window

    Returns:
        xmin, xmax, ymin, ymax: tuple of the corners of the window

    """
    # 思路:先随机找到一个点,然后在此基础上加上网格的w和h两个点构成一个网格
    w, h = window_shape
    W, H = img.shape[:2]        # 获取img的前两个维度
    x1 = random.randint(0, W - w - 1)       # 前闭后闭区间内产生随机数
    x2 = x1 + w
    y1 = random.randint(0, H - h - 1)
    y2 = y1 + h
    return x1, x2, y1, y2
解析

先通过 w, h = window_shape从输入元组中提取widthheight,再通过 W, H = img.shape[:2]获取img的前两个维度WH

然后生成左下角点的位置,所用函数是random.randint(),在前闭后闭区间产生随机数。W维度的随机数的范围是(0, W - w - 1)H维度同理。

将左下角(x1,y1)分别加上wh,则得到右下角(x2,y2)。这样就表示了一个corner。

最后将xmin, xmax, ymin, ymax这4个作为函数的返回值返回。


sliding_window()

功能:

生成在输入图像上滑动窗口生成器(Sliding window generator over an input image)

输入和输出:
输入:
  • image: 2D+ image to slide the window on, e.g. RGB or hyperspectral
  • step: int stride of the sliding window
  • window_size: int tuple, width and height of the window
  • with_data (optional): bool set to True to return both the data and the corner indices
输出:

with_data 为真时,返回image[x:x + w, y:y + h], x, y, w, h, 即窗口的数据和窗口的位置参数。当with_data 为假时,返回x, y, w, h,即仅仅返回窗口的位置参数。

代码:
def sliding_window(image, step=10, window_size=(20, 20), with_data=True):
    """Sliding window generator over an input image.        # 在输入图像上滑动窗口生成器

    Args:
        image: 2D+ image to slide the window on, e.g. RGB or hyperspectral
        step: int stride of the sliding window
        window_size: int tuple, width and height of the window
        with_data (optional): bool set to True to return both the data and the
        corner indices
    Yields:
        ([data], x, y, w, h) where x and y are the top-left corner of the
        window, (w,h) the window size

    """
    # slide a window across the image
    w, h = window_size
    W, H = image.shape[:2]
    offset_w = (W - w) % step
    offset_h = (H - h) % step
    for x in range(0, W - w + offset_w, step):
        if x + w > W:
            x = W - w
        for y in range(0, H - h + offset_h, step):
            if y + h > H:
                y = H - h

            if with_data:
                yield image[x:x + w, y:y + h], x, y, w, h
            else:
                yield x, y, w, h
解析:

通过 w, h = window_sizeW, H = image.shape[:2]分别获得窗口大小的参数和图像的尺寸。然后定义offset_woffset_h使得窗口在合适的范围滑动。

关键字yield来创建一个生成器。

yield的函数是一个生成器,而不是一个函数了,这个生成器有一个函数就是next函数,next就相当于“下一步”生成哪个数,这一次的next开始的地方是接着上一次的next停止的地方执行的。所以调用next的时候,生成器并不会从该函数的开始执行,只是接着上一步停止的地方开始,然后遇到yield后,return出要生成的数,此步就结束。

对于yield的详细解释见python中yield的用法详解——最简单,最清晰的解释。

对于这个函数来说,每调用一次这个函数,窗口就会从上一次的位置滑动一个步长。


count_sliding_window()

功能:

计算图像中的窗口数(Count the number of windows in an image.)

输入和输出:
输入:
  • image: 2D+ image to slide the window on, e.g. RGB or hyperspectral, …
  • step: int stride of the sliding window
  • window_size: int tuple, width and height of the window
输出:
  • int number of windows
代码:

def count_sliding_window(top, step=10, window_size=(20, 20)):
“”" Count the number of windows in an image. # 计算图像中的窗口数

def count_sliding_window(top, step=10, window_size=(20, 20)):
    """ Count the number of windows in an image.        # 计算图像中的窗口数

    Args:
        image: 2D+ image to slide the window on, e.g. RGB or hyperspectral, ...
        step: int stride of the sliding window
        window_size: int tuple, width and height of the window
    Returns:
        int number of windows
    """
    sw = sliding_window(top, step, window_size, with_data=False)
    return sum(1 for _ in sw)
解析:

先通过调用sliding_window()函数得到window的集合sw,然后遍历sw每遍历一次返回值+1


grouper()

功能:

Browse an iterable by grouping n elements by n elements.

输入和输出:
输入:
  • n: int, size of the groups
  • iterable: the iterable to Browse
输出:
  • chunk of n elements from the iterable
代码:
def grouper(n, iterable):       # 分组器?
    """ Browse an iterable by grouping n elements by n elements.        # 通过n个元素对n个元素进行分组来浏览iterable

    Args:
        n: int, size of the groups
        iterable: the iterable to Browse    迭代
    Yields:
        chunk of n elements from the iterable       可迭代的n个元素块

    """
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk
解析:

(暂略)


metrics()

功能:

计算并打印指标,包括准确率混淆矩阵F1分数(Compute and print metrics (accuracy, confusion matrix and F1 scores).)

输入和输出:
输入:
  • prediction: list of predicted labels
  • target: list of target labels
  • ignored_labels (optional): list of labels to ignore, e.g. 0 for undef
  • n_classes (optional): number of classes, max(target) by default
输出:
  • accuracy
  • F1 score by class
  • confusion matrix
代码:
def metrics(prediction, target, ignored_labels=[], n_classes=None):         # 输出指标
    """Compute and print metrics (accuracy, confusion matrix and F1 scores).

    Args:
        prediction: list of predicted labels
        target: list of target labels
        ignored_labels (optional): list of labels to ignore, e.g. 0 for undef
        n_classes (optional): number of classes, max(target) by default
    Returns:
        accuracy, F1 score by class, confusion matrix
    """
    ignored_mask = np.zeros(target.shape[:2], dtype=np.bool)
    for l in ignored_labels:
        ignored_mask[target == l] = True
    ignored_mask = ~ignored_mask
    target = target[ignored_mask]
    prediction = prediction[ignored_mask]

    results = {}

    n_classes = np.max(target) + 1 if n_classes is None else n_classes

    cm = confusion_matrix(
        target,
        prediction,
        labels=range(n_classes))

    results["Confusion matrix"] = cm

    # Compute global accuracy
    total = np.sum(cm)
    accuracy = sum([cm[x][x] for x in range(len(cm))])
    accuracy *= 100 / float(total)

    results["Accuracy"] = accuracy

    # Compute F1 score
    F1scores = np.zeros(len(cm))
    for i in range(len(cm)):
        try:
            F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))
        except ZeroDivisionError:
            F1 = 0.
        F1scores[i] = F1

    results["F1 scores"] = F1scores

    # Compute kappa coefficient
    pa = np.trace(cm) / float(total)
    pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \
        float(total * total)
    kappa = (pa - pe) / (1 - pe)
    results["Kappa"] = kappa

    return results
解析:

(这里要说一下,我对predictiontarget的数据的组织形式不明确,主要是有np.max(target)这个代码存在)

ignored_mask部分是创造了一个蒙版(mask),目的是不再考虑标签为ignored_labels的部分。

results = {}将输出结果定义为字典dictionary类型,作为函数的返回值。一开始的时候只是将results作为空字典,然后逐步增加键值对

计算混淆矩阵

cm = confusion_matrix(target, prediction, labels=range(n_classes))调用confusion_matrix()函数(该函数详见 python sklearn 计算混淆矩阵 confusion_matrix()函数)。简单来说,这句代码通过target, prediction, labels这3个参数,来计算得到array类型的混淆矩阵,并将结果返回给cm

之后results["Confusion matrix"] = cm这句代码,在字典类型的result中加入"Confusion matrix": cm键值对

计算分类准确率

先通过total = np.sum(cm)来计算总样本数total。通过对混淆矩阵cm的每一个元素求和,总和即为总样本数total

再计算分类准确的样本数accuracy。混淆矩阵cm的对角线的元素值的总和,即为分类准确的样本数accuracy

二者比值即为最后的准确率accuracyaccuracy *= 100 / float(total)

import numpy as np
cm = np.array([[1,1,1],[2,2,2],[3,3,3]])
print(cm)
# [[1 1 1]
#  [2 2 2]
#  [3 3 3]]
print(len(cm))
# 3
print(range(len(cm)))
# range(0, 3)
print('\n')
for i in range(len(cm)):
    print(i)
# 0
# 1
# 2

之后results["Accuracy"] = accuracy这句代码,在字典类型的result中加入"Accuracy": accuracy键值对

计算F1 score

F1分数(F1 Score),是统计学中用来衡量分类模型精确度的一种指标。它同时兼顾了分类模型的精确率Accuracy召回率Recall Rate 。F1分数可以看作是模型精确率和召回率的一种调和平均

这部分代码直接调用就行,套公式而已。

计算kappa系数

也是一种指标,套公式。(暂略)

返回结果

将字典dictionary类型的result作为函数的返回值。


show_results()

功能:

在visdom界面以文本形式输出结果。

输入和输出:
输入:
  • results:字典dictionary类型,包含Confusion matrixAccuracyF1 scoresKappa四个key
  • visvisdom可视化服务。
  • label_values:默认为None
  • agregated:默认为False
输出:
  • textvisdom输出的文本。
代码:
def show_results(results, vis, label_values=None, agregated=False):         # 可视化模块
    text = ""
    # if agregated部分没看懂要干啥
    if agregated:
        accuracies = [r["Accuracy"] for r in results]
        kappas = [r["Kappa"] for r in results]
        F1_scores = [r["F1 scores"] for r in results]

        F1_scores_mean = np.mean(F1_scores, axis=0)
        F1_scores_std = np.std(F1_scores, axis=0)
        cm = np.mean([r["Confusion matrix"] for r in results], axis=0)
        text += "Agregated results :\n"
    else:
        cm = results["Confusion matrix"]
        accuracy = results["Accuracy"]
        F1scores = results["F1 scores"]
        kappa = results["Kappa"]

    vis.heatmap(cm, opts={'title': "Confusion matrix", 
                          'marginbottom': 150,
                          'marginleft': 150,
                          'width': 500,
                          'height': 500,
                          'rownames': label_values, 'columnnames': label_values})
    text += "Confusion matrix :\n"
    text += str(cm)
    text += "---\n"

    if agregated:
        text += ("Accuracy: {:.03f} +- {:.03f}\n".format(np.mean(accuracies),
                                                         np.std(accuracies)))
    else:
        text += "Accuracy : {:.03f}%\n".format(accuracy)
    text += "---\n"

    text += "F1 scores :\n"
    if agregated:
        for label, score, std in zip(label_values, F1_scores_mean,
                                     F1_scores_std):
            text += "\t{}: {:.03f} +- {:.03f}\n".format(label, score, std)
    else:
        for label, score in zip(label_values, F1scores):
            text += "\t{}: {:.03f}\n".format(label, score)
    text += "---\n"

    if agregated:
        text += ("Kappa: {:.03f} +- {:.03f}\n".format(np.mean(kappas),
                                                      np.std(kappas)))
    else:
        text += "Kappa: {:.03f}\n".format(kappa)

    vis.text(text.replace('\n', '
'
)) print(text)
解析:

整体思路是遍历result的键值对,然后通过text += XXX来扩充text,最后在visdom上打印result

由于不知道agregated是在干啥,而且默认是False,所以只考虑agregated = false的情况。

首先,通过访问字典result的键,来获取对应键的值。

然后通过vis.heatmap()函数来绘制一个热图,它需要输入NxM张量X来指定热图中每个位置的值,此处为cm。设置title'title': "Confusion matrix",尺寸:'marginbottom': 150, 'marginleft': 150, 'width': 500, 'height': 500,行列标签:'rownames': label_values, 'columnnames': label_values

对于cm,先通过str(cm)将cm转为字符串类型,然后通过+=扩充到text中。

对于Accuracy也是一样扩充到text中,text += "Accuracy : {:.03f}%\n".format(accuracy)

对于F1scores,需要对应的label_values。通过zip(label_values, F1scores)来将可迭代对象label_valuesF1scores的对应元素组成元组,并以对象的形式返回(zip()函数详见:Python zip() 函数)。之后通过for循环遍历zip()返回的对象,同时扩充textfor label, score in zip(label_values, F1scores): text += "\t{}: {:.03f}\n".format(label, score)

对于Kappa就是简单的扩充到text中,text += "Kappa: {:.03f}\n".format(kappa)

vis.text()函数的功能是在一个盒子里打印文本。可以使用它来嵌入任意的HTML。它需要输入一个text字符串。opts目前没有具体的支持。text.replace('\n', '
)
用来实现换行符的替换。vis.text(text.replace('\n', '
'))


sample_gt()

功能:

从标签数组gt中提取固定百分比的样本(Extract a fixed percentage of samples from an array of labels)。

需要强调的是,被分割为训练集和测试集的样本,不包括类别为ignored_labels的sample。被分割的只是有效的sample。

输入和输出:
输入:
  • gt: a 2D array of int labels
  • percentage: [0, 1] float
输出:
  • train_gt:2D arrays of int labels
  • test_gt:2D arrays of int labels
代码:
def sample_gt(gt, train_size, mode='random'):
    """Extract a fixed percentage of samples from an array of labels.   从标签数组中提取固定百分比的样本。

    Args:
        gt: a 2D array of int labels
        percentage: [0, 1] float
    Returns:
        train_gt, test_gt: 2D arrays of int labels

    """
    indices = np.nonzero(gt)
    X = list(zip(*indices)) # x,y features
    y = gt[indices].ravel() # classes
    train_gt = np.zeros_like(gt)
    test_gt = np.zeros_like(gt)
    if train_size > 1:
       train_size = int(train_size)
    
    if mode == 'random':
       train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]
    elif mode == 'fixed':
       print("Sampling {} with train size = {}".format(mode, train_size))
       train_indices, test_indices = [], []
       for c in np.unique(gt):
           if c == 0:
              continue
           indices = np.nonzero(gt == c)
           X = list(zip(*indices)) # x,y features

           train, test = sklearn.model_selection.train_test_split(X, train_size=train_size)
           train_indices += train
           test_indices += test
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]

    elif mode == 'disjoint':
        train_gt = np.copy(gt)
        test_gt = np.copy(gt)
        for c in np.unique(gt):
            mask = gt == c
            for x in range(gt.shape[0]):
                first_half_count = np.count_nonzero(mask[:x, :])
                second_half_count = np.count_nonzero(mask[x:, :])
                try:
                    ratio = first_half_count / second_half_count
                    if ratio > 0.9 * train_size and ratio < 1.1 * train_size:
                        break
                except ZeroDivisionError:
                    continue
            mask[:x, :] = 0
            train_gt[mask] = 0

        test_gt[train_gt > 0] = 0
    else:
        raise ValueError("{} sampling is not implemented yet.".format(mode))
    return train_gt, test_gt
解析:

indices = np.nonzero(gt)获取gt中非零的元素的索引,返回值为两个array数组构成的元组,分别表示xy方向的索引。

X = list(zip(*indices))中,首先通过*来将np.nonzero(gt)的返回值拆解为两个array,然后通过zip()函数来将*indices拆解得到的两个array的对应元素组成元组,并以对象的形式返回,然后通过list()将类型转为列表list类型。y = gt[indices].ravel()这一句先通过gt[indices]根据索引indices取得相应的元素,然后通过ravel()展开成一维数组。

这一部分代码的功能可以通过这个demo表示 :

gt = np.array([[0,0,0,0],[0,1,2,0],[0,3,4,0],[0,0,0,0]])
print(gt)
# [[0 0 0 0]
#  [0 1 2 0]
#  [0 3 4 0]
#  [0 0 0 0]]
indices = np.nonzero(gt)
X = list(zip(*indices))  # x,y features  (x,y)形式的索引
y = gt[indices].ravel()  # classes
print(X)
# [(1, 1), (1, 2), (2, 1), (2, 2)]
print(type(X))
# 
print(y)
# [1 2 3 4]
print(type(y))
# 

train_gt = np.zeros_like(gt)test_gt = np.zeros_like(gt)train_gt和test_gt初始化为全0的,与gt维度相同的数组。

由于默认的moderandom,所以这里只解析mode == random的情况。

train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)这一句代码主要是调用sklearn.model_selection.train_test_split()函数,用来将数据集划分成训练集和测试集。train_indicestest_indices是划分的结果,为元素的索引,形式是这种:[(4, 2), (3, 3), (2, 2), (3, 2), (1, 1), (4, 3)]

train_indices = [list(t) for t in zip(*train_indices)]这一句是转换train_indices的表示形式,变为这种形式:[[4, 3, 2, 3, 1, 4], [2, 3, 2, 2, 1, 3]]

train_gt[train_indices] = gt[train_indices]这一句是从gt中提取训练集的样本。其他非训练集样本保持初始化的0不变,作为ignored_labels

下面附上一个小demo,来帮助更好地理解:

import numpy as np
import sklearn.model_selection
gt = np.array([[0,0,0,0],[0,1,2,0],[0,3,4,0],[1,2,3,4],[4,3,3,2],[0,0,0,0]])
print(gt)
# [[0 0 0 0]
#  [0 1 2 0]
#  [0 3 4 0]
#  [1 2 3 4]
#  [4 3 3 2]
#  [0 0 0 0]]
indices = np.nonzero(gt)
X = list(zip(*indices))  # x,y features  (x,y)形式的索引
y = gt[indices].ravel()  # classes
print(X)
# [(1, 1), (1, 2), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2), (3, 3), (4, 0), (4, 1), (4, 2), (4, 3)]
print(type(X))
# 
print(y)
# [1 2 3 4 1 2 3 4 4 3 3 2]
print(type(y))
# 
train_gt = np.zeros_like(gt)
test_gt = np.zeros_like(gt)
train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=0.5, stratify=y)
print(train_indices)
# [(4, 2), (3, 3), (2, 2), (3, 2), (1, 1), (4, 3)]
print(test_indices)
# [(1, 2), (3, 1), (2, 1), (4, 0), (3, 0), (4, 1)]
print('________________________')
y_train = []
for i, j in train_indices:
    y_train.append(gt[i][j])
y_test = []
for i, j in test_indices:
    y_test.append(gt[i][j])
print(y_train)
# [3, 4, 2, 1, 2, 3]
print(y_test)
# [2, 3, 3, 4, 4, 1]
print('________________________')
train_indices = [list(t) for t in zip(*train_indices)]
print(train_indices)
# [[4, 3, 2, 3, 1, 4], [2, 3, 2, 2, 1, 3]]
train_gt[train_indices] = gt[train_indices]
print(train_gt)
# [[0 0 0 0]
#  [0 1 0 0]
#  [0 0 4 0]
#  [0 0 3 4]
#  [0 0 3 2]
#  [0 0 0 0]]

compute_imf_weights()

(暂略)


camel_to_snake()

(暂略)


module.py

_addindent()

(暂略)


class Module(object)

这部分是所有神经网络模块的基类。

File is read-only。也就是最好不要修改。

所以(暂略)


model.py

class Baseline(nn.Module)

定义一个class,继承nn.Module

属性:

方法:

weight_init()

def weight_init(m):
    if isinstance(m, nn.Linear):        # 判断类型是否相同
        init.kaiming_normal_(m.weight)      # 一种权重初始化方法
        init.zeros_(m.bias)

用来初始化权重weight和偏置bias

首先通过if isinstance(m, nn.Linear)来看输入m是不是和nn.Linear是一类(这里是继承关系)(鲁棒性检验)。

通过init.kaiming_normal_(m.weight)来初始化权重weight,其中kaiming_normal_()是一种初始化权重的方法。

init.zeros_(m.bias)将偏置bias初始化为0。

__ init__()

def __init__(self, input_channels, n_classes, dropout=False):   # 类的属性的初始化
    super(Baseline, self).__init__()
    self.use_dropout = dropout
    if dropout:
        self.dropout = nn.Dropout(p=0.5)

    self.fc1 = nn.Linear(input_channels, 2048)
    self.fc2 = nn.Linear(2048, 4096)
    self.fc3 = nn.Linear(4096, 2048)
    self.fc4 = nn.Linear(2048, n_classes)

    self.apply(self.weight_init)

这一部分是对类的初始化,包括是否使用dropout(True or Flase)、网络的层数和in_channelout_channel。同时对网络的参数(weightbias)进行初始化,通过self.apply(self.weight_init)来实现。

forward(self, x)

def forward(self, x):
    x = F.relu(self.fc1(x))
    if self.use_dropout:
        x = self.dropout(x)
    x = F.relu(self.fc2(x))
    if self.use_dropout:
        x = self.dropout(x)
    x = F.relu(self.fc3(x))
    if self.use_dropout:
        x = self.dropout(x)
    x = self.fc4(x)
    return x 

这个方法定义了前向传播过程,流程如下:

输入
nn.Linear(input_channels, 2048)
relu()
dropout()
self.fc2 = nn.Linear(2048, 4096)
relu()
dropout()
nn.Linear(4096, 2048)
relu()
dropout()
self.fc4 = nn.Linear(2048, n_classes)
输出

class HuEtAl(nn.Module)

属性:

方法:

weight_init()

def weight_init(m):
    # [All the trainable parameters in our CNN should be initialized to
    # be a random value between −0.05 and 0.05.]
    # 我们CNN中的所有可训练参数应初始化为介于-0.05和0.05之间的随机值。
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        init.uniform_(m.weight, -0.05, 0.05)
        init.zeros_(m.bias)

模型权重weight初始化为介于-0.05和0.05之间的随机值,偏置bias初始化为0。

_get_final_flattened_size()

def _get_final_flattened_size(self):    # 得到最终的扁平尺寸
    with torch.no_grad():
        x = torch.zeros(1, 1, self.input_channels)      # 生成一个1×1×input_channels的全0的tensor
        x = self.pool(self.conv(x))         # 先卷积,再池化
    return x.numel()        # numel()返回张量中的元素个数

首先设置with torch.no_grad()。当确定不会调用Tensor.backward()时,禁用 gradient calculation 对于 inference 非常有用。 它将减少计算的内存消耗,否则会有requires_grad = True

x = torch.zeros(1, 1, self.input_channels)生成一个1×1×input_channels的全0的tensor。

x = self.pool(self.conv(x)),先卷积,再池化。

return x.numel()返回张量中的元素个数。

至于为什么这么做,我现在没看很明白。

__ init__()

    def __init__(self, input_channels, n_classes, kernel_size=None, pool_size=None):
        super(HuEtAl, self).__init__()
        if kernel_size is None:
           # [In our experiments, k1 is better to be [ceil](n1/9)]
           kernel_size = math.ceil(input_channels / 9)
        if pool_size is None:
           # The authors recommand that k2's value is chosen so that the pooled features have 30~40 values
           # ceil(kernel_size/5) gives the same values as in the paper so let's assume it's okay
           pool_size = math.ceil(kernel_size / 5)
        self.input_channels = input_channels

        # [The first hidden convolution layer C1 filters the n1 x 1 input data with 20 kernels of size k1 x 1]
        # 第一隐藏卷积层C1用大小为k1×1的20个内核对n1×1输入数据进行滤波
        self.conv = nn.Conv1d(1, 20, kernel_size)
        self.pool = nn.MaxPool1d(pool_size)
        self.features_size = self._get_final_flattened_size()
        # [n4 is set to be 100]
        self.fc1 = nn.Linear(self.features_size, 100)
        self.fc2 = nn.Linear(100, n_classes)
        self.apply(self.weight_init)

这部分是初始化kernel_sizepool_size,同时读取input_channels,以及网络层的结构。

kernel_sizepool_size的选择以论文为依据。kernel_sizemath.ceil(input_channels / 9),即input_channels / 9的向上取整。pool_size = math.ceil(kernel_size / 5),即kernel_size / 5的向上取整。

self.input_channels = input_channels

对于网络层的结构:

第一隐藏卷积层C1用大小为k1×1的20个内核对n1×1输入数据进行滤波,self.conv = nn.Conv1d(1, 20, kernel_size)

池化层统一是self.pool = nn.MaxPool1d(pool_size)

self.features_size = self._get_final_flattened_size()获取展平的size,作为第一个线性层(全连接层)的输入维度。

然后定义两个线性层(全连接层),第一个是self.fc1 = nn.Linear(self.features_size, 100),输入维度self.features_size,输出维度100。第二个是self.fc2 = nn.Linear(100, n_classes),输入维度100,输出维度n_classes

forward()

def forward(self, x):
    # [In our design architecture, we choose the hyperbolic tangent function tanh(u)]
    # 在我们的设计架构中,我们选择双曲正切函数
    x = x.squeeze(dim=-1).squeeze(dim=-1)
    x = x.unsqueeze(1)
    x = self.conv(x)
    x = torch.tanh(self.pool(x))
    x = x.view(-1, self.features_size)
    x = torch.tanh(self.fc1(x))
    x = self.fc2(x)
    return x

x = x.squeeze(dim=-1).squeeze(dim=-1)是数据预处理,将x的倒数第一个维度和倒数第二个维度抹掉。

x = x.unsqueeze(1)再在第2个维度的位置增加一个维度(维度的索引从0开始)。

用一个小demo演示(注意:这里x是tensor类型):

import numpy as np

x = np.array([1,2,3])
print(x.shape)
# (3,)
x = x.reshape(3,1,1)
print(x.shape)
# (3, 1, 1)
x = x.squeeze(-1).squeeze(-1)
print(x.shape)
# (3,)

之后是网络的前向传播过程。

输入
nn.Conv1d(1, 20, kernel_size)
nn.MaxPool1d(pool_size)
tanh()
view(-1, self.features_size)
nn.Linear(self.features_size, 100)
tanh()
nn.Linear(100, n_classes)
输出

(其他网络模型,暂略……)


get_model()

功能:

获取模型模型名称和相应地超参数(实例化并获得具有足够超参数的模型,Instantiate and obtain a model with adequate hyperparameters)

输入和输出:

输入:
  • name:模型的名称,string类型(string of the model name)
  • kwargs:超参数,dictionary类型,**kwargs表示数目不定
输出:
  • model: PyTorch network
  • optimizer: PyTorch optimizer
  • criterion: PyTorch loss Function
  • kwargs: 具有理智默认值的超参数(hyperparameters with sane defaults)

代码和解析:

def get_model(name, **kwargs):
    """
    Instantiate and obtain a model with adequate hyperparameters

    Args:
        name: string of the model name      网络名,string类型
        kwargs: hyperparameters             超参数,dictionary类型,**kwargs表示数目不定
    Returns:
        model: PyTorch network
        optimizer: PyTorch optimizer
        criterion: PyTorch loss Function
        kwargs: hyperparameters with sane defaults  具有理智默认值的超参数
    """
    device = kwargs.setdefault('device', torch.device('cpu'))   # 获取字典kwargs中键device的值,否则返回默认值为cpu。
    n_classes = kwargs['n_classes']                             # 获取字典kwargs中键n_classes的值。
    n_bands = kwargs['n_bands']                                 # 获取字典kwargs中键n_bands的值。
    weights = torch.ones(n_classes)
    weights[torch.LongTensor(kwargs['ignored_labels'])] = 0.
    weights = weights.to(device)            # 放到cpu或gpu
    weights = kwargs.setdefault('weights', weights)

首先要强调的是,超参数以键值对的形式存储在kwargs中。通过访问字典kwargs来获得常用的超参数devicen_classesn_bands,同时通过setdefault()函数来设定默认值。

if name == 'nn':
	……
elif name == 'hamida':
	……

这一部分是,根据模型的选择,用分支语句来设定自己模型的合适的超参数,比如learning_rateoptimizercriterionepochbatch_size等。

model = model.to(device)
epoch = kwargs.setdefault('epoch', 100)
kwargs.setdefault('scheduler', optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=epoch//4, verbose=True))
#kwargs.setdefault('scheduler', None)
kwargs.setdefault('batch_size', 100)
kwargs.setdefault('supervision', 'full')
kwargs.setdefault('flip_augmentation', False)
kwargs.setdefault('radiation_augmentation', False)
kwargs.setdefault('mixture_augmentation', False)
kwargs['center_pixel'] = center_pixel

这部分是用到的函数主要都是setdefault(),当模型参数不全时,将参数设定为默认值。

但其实模型的参数一般都在前面的if-esle设定好了,所以这些其实就是查漏补缺的作用。


val()

功能:

计算 val set 的准确率。

输入和输出:

输入:
  • net
  • data_loader
  • device
  • supervision
输出:
  • accuracy / total:其实就是准确率,accuracy统计的是准确的个数。

代码和解析:

函数定义
def val(net, data_loader, device='cpu', supervision='full'):
# TODO : fix me using metrics()
初始化和预操作
accuracy, total = 0., 0.
ignored_labels = data_loader.dataset.ignored_labels
  • accuracytotal初始化为float类型的0
  • 获取ignored_labels
开始检测
for batch_idx, (data, target) in enumerate(data_loader):

需要注意的是,再检测过程中,只遍历一遍 val set ,即epoch1

保持梯度
with torch.no_grad():

因为这部分只是作为检测,并不进行网络的训练,,所以要设置with torch.no_grad()

选择device
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)

把数据放到device上,没什么好说的。

根据不同方式获取预测值
if supervision == 'full':
    output = net(data)
elif supervision == 'semi':
    outs = net(data)
    output, rec = outs
    
_, output = torch.max(output, dim=1)

一般都是全监督,所以只看全监督的情况。output = net(data),没啥可说的。

_, output = torch.max(output, dim=1)获取预测值,torch.max(output, dim=1)是按行找到最大的元素,并返回最大的元素和索引(values, indices)_, output表示获取返回值里的索引indices,索引是几就代表是第几类。

统计accuracy和total
for out, pred in zip(output.view(-1), target.view(-1)):
    if out.item() in ignored_labels:
        continue
    else:
        accuracy += out.item() == pred.item()
        total += 1

没啥好说的,很容易看懂。

返回 accuracy / total
return accuracy / total

save_model()

def save_model(model, model_name, dataset_name, **kwargs):
     model_dir = './checkpoints/' + model_name + "/" + dataset_name + "/"
     if not os.path.isdir(model_dir):
         os.makedirs(model_dir, exist_ok=True)
     if isinstance(model, torch.nn.Module):
         filename = str('wk') + "_epoch{epoch}_{metric:.2f}".format(**kwargs)
         tqdm.write("Saving neural network weights in {}".format(filename))
         torch.save(model.state_dict(), model_dir + filename + '.pth')
     else:
         filename = str('wk')
         tqdm.write("Saving model params in {}".format(filename))
         joblib.dump(model, model_dir + filename + '.pkl')

比较容易看明白,现在也不太需要细究,暂略。


train()

功能:

封装好的网络训练的函数。(Training loop to optimize a network for several epochs and a specified loss)

输入和输出:

输入:
  • net: a PyTorch model
  • optimizer: a PyTorch optimizer
  • data_loader: a PyTorch dataset loader
  • epoch: int specifying the number of training epochs
  • criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
  • device (optional): torch device to use (defaults to CPU)
  • display_iter (optional): number of iterations before refreshing the display (False/None to switch off).
  • scheduler (optional): PyTorch scheduler,基于epoch调整学习率lr
  • val_loader (optional): validation dataset
  • supervision (optional): ‘full’ or ‘semi’
输出:

代码和解析:

函数的定义和信息
def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
          display_iter=100, device=torch.device('cpu'), display=None,
          val_loader=None, supervision='full'):
    """
    Training loop to optimize a network for several epochs and a specified loss

    Args:
        net: a PyTorch model
        optimizer: a PyTorch optimizer
        data_loader: a PyTorch dataset loader
        epoch: int specifying the number of training epochs
        criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
        device (optional): torch device to use (defaults to CPU)
        display_iter (optional): number of iterations before refreshing the display (False/None to switch off).
        scheduler (optional): PyTorch scheduler
        val_loader (optional): validation dataset
        supervision (optional): 'full' or 'semi'
    """
损失函数的鲁棒性检测
if criterion is None:
    raise Exception("Missing criterion. You must specify a loss function.")

这一部分就是,如果损失函数criterion不存在,则报错并打印报错信息。

这是为了增加程序的鲁棒性,不影响程序的主要功能。

初始化部分变量
net.to(device)

save_epoch = epoch // 20 if epoch > 20 else 1

losses = np.zeros(1000000)
mean_losses = np.zeros(100000000)
iter_ = 1
loss_win, val_win = None, None
val_accuracies = []
训练网络
开始epoch循环
for e in tqdm(range(1, epoch + 1), desc="Training the network"):

range(1, epoch + 1)表示循环次数为epoch次。

tqdm()创建一个进度条,描述信息为desc="Training the network"

设置模型为训练模式
# Set the network to training mode
net.train()
avg_loss = 0.

关于train()函数,PyTorch官方文档如下:

train(mode=True)[SOURCE]

Sets the module in training mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

  • Parameters

    mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

  • Returns

    self

  • Return type

    Module

在训练模式下设置模块。

这仅对某些模块有任何影响。如果它们受到影响(例如Dropout,BatchNorm等),有关其在training / evaluation 模式中的行为的详细信息,请参阅特定模块的文档。

在这里net.train()就是将net转为训练training模式。

avg_loss = 0.avg_loss初始化为float类型的0

epoch中按batch训练
for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):

enumerate(data_loader)data_loader构成一个索引序列enumerate(train_loader):

len(data_loader)的返回值为一个epoch内的batch的值(样本被分为多少个batch)。在这里训练集样本数为np.count_nonzero(train_gt): 4063,对应的len(train_loader): 41。由于设定的batch_size100,所以训练集的4063个样本被分为41batch(最后一个batch的样本数不够batch_size)。

通过遍历enumerate(data_loader)构成的索引序列,索引被赋值给batch_idx,表示这是第几个batch;数据和真实值被赋值给(data, target),表示输入数据和真值。

data target to device
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)

datatarget放到对应的device上,默认为cpu,一般为gpu。

正向传播
optimizer.zero_grad()
if supervision == 'full':
    output = net(data)
    loss = criterion(output, target)
elif supervision == 'semi':
    outs = net(data)
    output, rec = outs
    loss = criterion[0](output, target) + net.aux_loss_weight * criterion[1](rec, data)
else:
    raise ValueError("supervision mode \"{}\" is unknown.".format(supervision))

首先将optimizer的梯度置零:optimizer.zero_grad()

然后根据监督方式的不同选择不同的训练方法,因为一般都是全监督,所以只分析全监督的情况:

if supervision == 'full':
    output = net(data)
    loss = criterion(output, target)
  • 计算预测值outputoutput = net(data)
  • 计算损失函数losscriterion(output, target)
反向传播
loss.backward()
optimizer.step()
  • 反向传播:loss.backward()
  • 优化:optimizer.step()
计算损失
avg_loss += loss.item()
losses[iter_] = loss.item()
mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])
  • avg_loss:初始化为0,累加每个batch计算得到的loss的值。
  • lossesndarray类型,每个位置依次存储每次迭代(每个batch)的损失的值。比如索引为1的位置存放的是第一次迭代(第一个batch)的loss的值。
  • mean_lossesndarray类型,将索引为iter_的位置置为losses[max(0, iter_ - 100):iter_ + 1]的均值。至于为什么这么干,不知道。
绘制Training loss和Validation accuracy曲线
if display_iter and iter_ % display_iter == 0: 

这里说一下为什么要有这样一个if判断。

首先要解决一下变量display_iter的含义:

display_iter (optional): number of iterations before refreshing the display (False/None to switch off).

简单来说,display_iter的意义是,一旦迭代次数iter_display_iter整数倍(比如100,200,……),就刷新显示(refreshing the display)。

所以display_iter不为0,且迭代次数iter_display_iter整数倍(迭代次数iter_display_iter取余为0)的时候,才更新 Training lossValidation accuracy 曲线。

string = 'Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
string = string.format(
	e, epoch, batch_idx *
	len(data), len(data) * len(data_loader),
	100. * batch_idx / len(data_loader), mean_losses[iter_])

这一段打印的示例(我自己的数据)为,有6个值:

Train (epoch 55/100) [31100/32200 (97%)]        Loss: 0.024220
  • e:当前的epoch数,第几次遍历整个训练集。
  • epoch:总的epoch数,一共遍历几次训练集。
  • batch_idx *len(data):这个epoch已经训练结束的样本数。batch_idx 是现在的batch数,len(data)是一个batch的训练样本数。
  • len(data) * len(data_loader):这个epoch要训练的样本数。len(data_loader)是总的的batch数,len(data)是一个batch的训练样本数。
  • 100. * batch_idx / len(data_loader):这个epoch已经训练结束的样本数 / 这个epoch要训练的样本数,的值,一个百分数。
  • mean_losses[iter_]:第iter_次迭代(第iter_batch)的损失的值。
update = None if loss_win is None else 'append'
loss_win = display.line(
	X=np.arange(iter_ - display_iter, iter_),
	Y=mean_losses[iter_ - display_iter:iter_],
	win=loss_win,
	update=update,
	opts={'title': "Training loss",
		'xlabel': "Iterations",
		'ylabel': "Loss"
		}
)
tqdm.write(string)

第一句update = None if loss_win is None else 'append'的意思是,如果loss_winNoneupdateNone,如果loss_win不是Noneupdate'append'

loss_win的定义在第一句的下面,这就导致第一次运行的时候updateNone,之后运行的时候都是'append'

update: None
iter_: 100
display_iter: 100
--------------------------------------------------------
Train (epoch 3/100) [1700/4100 (41%)]   Loss: 0.916481


update: append
iter_: 200
display_iter: 100
--------------------------------------------------------
Train (epoch 5/100) [3500/4100 (85%)]   Loss: 0.695566


update: append
iter_: 300
display_iter: 100
--------------------------------------------------------
Train (epoch 8/100) [1200/4100 (29%)]   Loss: 0.599200
loss_win = display.line(
	X=np.arange(iter_ - display_iter, iter_),
	Y=mean_losses[iter_ - display_iter:iter_],
	win=loss_win,
	update=update,
	opts={'title': "Training loss",
		'xlabel': "Iterations",
		'ylabel': "Loss"
		}
)
tqdm.write(string)

这里的displayvis,采用visdom可视化,用到的函数是vis.line()。关于vis.line()

vis.line

这个函数绘制一个线条图。它需要输入一个N或NxM张量 Y来指定要绘制的M线(连接N点)的值。它还采用可选的X张量来指定相应的x轴值; X可以是一个N张量(在这种情况下,所有的线将共享相同的x轴值)或具有相同的大小Y。

以下opts是支持的:

  • opts.fillarea :填充行(boolean)以下的区域
  • opts.colormap :colormap(string; default = ‘Viridis’)
  • opts.markers :show markers(boolean; default = false)
  • opts.markersymbol:标志符号(string;默认= ‘dot’)
  • opts.markersize :标记大小(number;默认= ‘10’)
  • opts.legend :table包含图例名称

win=loss_win据我猜测,应该是将操作的window设置为loss_win,否则这么多window不知道操作哪一个。

前3次的结果图如图:

[外链图片转存失败(img-6scjDfp7-1568206440201)(file:///E:\734167802\Image\Group\BC~H{8N6WZ9FQWXTFVFRJ]6.png)]

简单来说就是第一次运行的时候,恰好有iter_等于display_iter这时候创建新窗口并绘制一次,之后iter_等于display_iter的整数倍的时候再绘制一次,并将后绘制的部分更新到第一次绘制的图像(窗口)上。

最后由tqdm.write(string)打印进度条。

if len(val_accuracies) > 0:
val_win = display.line(Y=np.array(val_accuracies),
                           X=np.arange(len(val_accuracies)),
                           win=val_win,
                           opts={'title': "Validation accuracy",
                                 'xlabel': "Epochs",
                                 'ylabel': "Accuracy"
                                })

这部分是打印 Validation accuracy,还有指定窗口这点东西,上面都有,不再赘述。

迭代变量加一
iter_ += 1
回收无用变量
del(data, target, loss, output)

对于del方法:

init() 方法对应的是 del() 方法,init() 方法用于初始化 Python 对象,而 del() 则用于销毁 Python 对象,即在任何 Python 对象将要被系统回收之时,系统都会自动调用该对象的 del() 方法。 当程序不再需要一个 Python 对象时,系统必须把该对象所占用的内存空间释放出来,这个过程被称为垃圾回收(GC,Garbage Collector),Python 会自动回收所有对象所占用的内存空间,因此开发者无须关心对象垃圾回收的过程。

简单来说,在运行完这一个batch,这一个batchdatatargetlossoutput都不被需要了,可以被回收来释放内存。

因为在下一个batch又会又新的datatarget,产生新的lossoutput

到此为止,一个epoch完毕

计算 avg_loss,val_accuracies,metric
avg_loss /= len(data_loader)

if val_loader is not None:
    val_acc = val(net, val_loader, device=device, supervision=supervision)
    val_accuracies.append(val_acc)
    metric = -val_acc
else:
    metric = avg_loss

avg_loss是对一个epoch内的所有的batch的损失loss取均值。

下面的一个判断语句是,如果val_loaderNone(第一次执行到这里),metric = avg_loss;之后执行到这里,就调用val()函数计算 val set 的准确率,并加入到val_accuracies中,再将指标metric设置为-val_acc(不知道为什么,暂略)。

Save the weights
# Save the weights
if e % save_epoch == 0:
    save_model(net, camel_to_snake(str(net.__class__.__name__)), data_loader.dataset.name, epoch=e, metric=abs(metric))

存储的文件名的示例为:

wk_epoch60_0.99.pth

test()

功能:

Test a model on a specific image(在特定图像上测试模型)

输入和输出:

输入:
  • net
  • img:用来做test的图像
  • hyperparams:超参数的字典
输出:
  • probsW × H × n_classes

代码和解析:

函数定义
def test(net, img, hyperparams):
    """
    Test a model on a specific image
    """
模型设置为test模式
net.eval()
提取超参数
patch_size = hyperparams['patch_size']
center_pixel = hyperparams['center_pixel']
batch_size, device = hyperparams['batch_size'], hyperparams['device']
n_classes = hyperparams['n_classes']

kwargs = {'step': hyperparams['test_stride'], 'window_size': (patch_size, patch_size)}
  • patch_size:窗口的大小。窗口可以包含上下文信息。
  • center_pixel:为True的时候,只看中间的样本,不考虑上下文信息。
  • batch_size
  • device
  • kwargs:一个字典,step为步长,window_size为窗口大小(元组类型)
初始化返回结果 probs
probs = np.zeros(img.shape[:2] + (n_classes,))

img的维度是W × H × channelimg.shape[:2] + (n_classes,)获得img的前两个维度W × H,并将n_classes作为第三个维度,即W × H × n_classes

用一个小demo演示:

shape = (340,680,103)
print(shape[:2] + (10,))
# (340, 680, 10)

probs的初始化结果为W × H × n_classes的全0数组。

计算迭代总数 iterations
iterations = count_sliding_window(img, **kwargs) // batch_size

count_sliding_window()计算整个图像可以产生多少个window,每batch_size最为一批(batch),二者相除就是最大迭代次数iterations

开始迭代
for batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),
                  total=(iterations),
                  desc="Inference on the image"
                  ):

tqdm()是生成进度条,total表示进度条的上限,desc为描述信息的字符串。

grouper()是分组器,返回 chunk of n elements from the iterable ,应该是从迭代器sliding_window()获取n个elements

提取数据:
with torch.no_grad():
    if patch_size == 1:
        data = [b[0][0, 0] for b in batch]
        data = np.copy(data)
        data = torch.from_numpy(data)
    else:
        data = [b[0] for b in batch]
        data = np.copy(data)
        data = data.transpose(0, 3, 1, 2)
        data = torch.from_numpy(data)
        data = data.unsqueeze(1)

通过这部分自加的代码,来查看batchb的相关信息:

# ----------------------------自加-------------------------
print('batch',batch)
for b in batch:
    print('b:',b)
    print('b[0]:',b[0])
    print('b[0][0, 0]:',b[0][0, 0])
os.system('pause')
# ----------------------------自加-------------------------

这里只详细解释patch_size的情况。

首先是 b 的形式,b的类型是tuple(因为元素类型不同):

b: (array([[[0.077625, 0.09325 , 0.0695  , 0.045   , 0.035625, 0.0375  ,
         0.03425 , 0.0345  , 0.0415  , 0.039875, 0.03475 , 0.031875,
         0.029   , 0.025875, 0.02625 , 0.026125, 0.021   , 0.017375,
         0.017125, 0.01925 , 0.021   , 0.02525 , 0.028125, 0.028875,
         0.0305  , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345  ,
         0.035625, 0.036375, 0.035625, 0.034   , 0.033875, 0.030125,
         0.026   , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215  ,
         0.02125 , 0.0195  , 0.018875, 0.017625, 0.016875, 0.015875,
         0.017125, 0.017875, 0.018   , 0.016125, 0.012875, 0.013   ,
         0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
         0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145  , 0.019125,
         0.0235  , 0.030375, 0.04025 , 0.051625, 0.0615  , 0.073875,
         0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
         0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
         0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
         0.281125, 0.279875, 0.279875, 0.28525 , 0.286   , 0.28025 ,
         0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885  , 0.293125,
         0.295125]]], dtype=float32), 0, 2, 1, 1)

那么batch应该就是由好多b组成的列表或元组(我更倾向是列表)。

那么b[0]就是数据的部分,即:

array([[[0.077625, 0.09325 , 0.0695  , 0.045   , 0.035625, 0.0375  ,
         0.03425 , 0.0345  , 0.0415  , 0.039875, 0.03475 , 0.031875,
         0.029   , 0.025875, 0.02625 , 0.026125, 0.021   , 0.017375,
         0.017125, 0.01925 , 0.021   , 0.02525 , 0.028125, 0.028875,
         0.0305  , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345  ,
         0.035625, 0.036375, 0.035625, 0.034   , 0.033875, 0.030125,
         0.026   , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215  ,
         0.02125 , 0.0195  , 0.018875, 0.017625, 0.016875, 0.015875,
         0.017125, 0.017875, 0.018   , 0.016125, 0.012875, 0.013   ,
         0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
         0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145  , 0.019125,
         0.0235  , 0.030375, 0.04025 , 0.051625, 0.0615  , 0.073875,
         0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
         0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
         0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
         0.281125, 0.279875, 0.279875, 0.28525 , 0.286   , 0.28025 ,
         0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885  , 0.293125,
         0.295125]]], dtype=float32)

b[0]shape为:

(1, 1, 103)

那么b[0][0,0]和它的shape为:

[0.077625, 0.09325 , 0.0695  , 0.045   , 0.035625, 0.0375  ,
 0.03425 , 0.0345  , 0.0415  , 0.039875, 0.03475 , 0.031875,
 0.029   , 0.025875, 0.02625 , 0.026125, 0.021   , 0.017375,
 0.017125, 0.01925 , 0.021   , 0.02525 , 0.028125, 0.028875,
 0.0305  , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345  ,
 0.035625, 0.036375, 0.035625, 0.034   , 0.033875, 0.030125,
 0.026   , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215  ,
 0.02125 , 0.0195  , 0.018875, 0.017625, 0.016875, 0.015875,
 0.017125, 0.017875, 0.018   , 0.016125, 0.012875, 0.013   ,
 0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
 0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145  , 0.019125,
 0.0235  , 0.030375, 0.04025 , 0.051625, 0.0615  , 0.073875,
 0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
 0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
 0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
 0.281125, 0.279875, 0.279875, 0.28525 , 0.286   , 0.28025 ,
 0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885  , 0.293125,
 0.295125]
 
(103,)

了解了bb[0]b[0][0,1]的组成形式和类型后,再回来看这一句:

data = [b[0][0, 0] for b in batch]

等号右边先是一个中括号[],表示是一个列表,列表里面是一个for循环,将每次循环得到的b进行b[0][0, 0]的操作,放到列表里作为列表的一个元素。

用以和小demo演示一下功能:

a1 = np.array([[[0.079625, 0.074   , 0.06025 , 0.0695  , 0.0635  , 0.0355  ,
         0.02225 , 0.02475 , 0.024125, 0.028   , 0.027125, 0.026875,
         0.023375, 0.020125, 0.019   , 0.017   , 0.0155  , 0.01525 ,
         0.015875, 0.01575 , 0.015625, 0.015375, 0.018375, 0.0235  ,
         0.026   , 0.025375, 0.02525 , 0.02575 , 0.027375, 0.029375,
         0.02975 , 0.028375, 0.027125, 0.026875, 0.027   , 0.025125,
         0.02375 , 0.020875, 0.018625, 0.02025 , 0.02175 , 0.0225  ,
         0.022125, 0.02125 , 0.0205  , 0.021625, 0.0235  , 0.02275 ,
         0.020125, 0.018375, 0.017625, 0.01925 , 0.021875, 0.02075 ,
         0.0175  , 0.01725 , 0.01825 , 0.017375, 0.016125, 0.018125,
         0.01925 , 0.017125, 0.016625, 0.016   , 0.016125, 0.02175 ,
         0.030625, 0.04225 , 0.056875, 0.073125, 0.09    , 0.10625 ,
         0.126625, 0.153125, 0.1825  , 0.21275 , 0.24225 , 0.269625,
         0.289625, 0.304125, 0.315625, 0.319   , 0.311625, 0.31925 ,
         0.341625, 0.347625, 0.3435  , 0.3435  , 0.342125, 0.33875 ,
         0.335125, 0.33025 , 0.330625, 0.3355  , 0.334375, 0.326125,
         0.317625, 0.318875, 0.321375, 0.321125, 0.321625, 0.3275  ,
         0.3305  ]]])
a2 = np.array([[[0.115   , 0.0825  , 0.058125, 0.038875, 0.04375 , 0.046875,
         0.047   , 0.041375, 0.032625, 0.024875, 0.022375, 0.021125,
         0.02075 , 0.022375, 0.021625, 0.019125, 0.017375, 0.015   ,
         0.013375, 0.016625, 0.019875, 0.0235  , 0.028375, 0.0305  ,
         0.030875, 0.032375, 0.035   , 0.03725 , 0.04    , 0.042375,
         0.042625, 0.043375, 0.042375, 0.039875, 0.0385  , 0.036375,
         0.033875, 0.032375, 0.03075 , 0.029   , 0.0295  , 0.029625,
         0.030625, 0.027625, 0.024875, 0.02525 , 0.025375, 0.02625 ,
         0.026375, 0.026   , 0.0265  , 0.026875, 0.026875, 0.027125,
         0.02675 , 0.024125, 0.022875, 0.021625, 0.01975 , 0.019125,
         0.0195  , 0.019875, 0.019625, 0.02025 , 0.023625, 0.0295  ,
         0.03625 , 0.046625, 0.06175 , 0.076875, 0.09225 , 0.108625,
         0.130875, 0.157375, 0.182625, 0.209   , 0.2285  , 0.246   ,
         0.262   , 0.272625, 0.27975 , 0.276875, 0.269125, 0.279   ,
         0.299875, 0.3005  , 0.294   , 0.291625, 0.293   , 0.290375,
         0.289   , 0.292875, 0.2935  , 0.290875, 0.287375, 0.278625,
         0.27325 , 0.278375, 0.281   , 0.27725 , 0.282375, 0.294   ,
         0.298125]]])
batch = [a1, a2]
data = [b[0, 0] for b in batch]
data = np.copy(data)
print('data:',data)
# data: [[0.079625 0.074    0.06025  0.0695   0.0635   0.0355   0.02225  0.02475
#   0.024125 0.028    0.027125 0.026875 0.023375 0.020125 0.019    0.017
#   0.0155   0.01525  0.015875 0.01575  0.015625 0.015375 0.018375 0.0235
#   0.026    0.025375 0.02525  0.02575  0.027375 0.029375 0.02975  0.028375
#   0.027125 0.026875 0.027    0.025125 0.02375  0.020875 0.018625 0.02025
#   0.02175  0.0225   0.022125 0.02125  0.0205   0.021625 0.0235   0.02275
#   0.020125 0.018375 0.017625 0.01925  0.021875 0.02075  0.0175   0.01725
#   0.01825  0.017375 0.016125 0.018125 0.01925  0.017125 0.016625 0.016
#   0.016125 0.02175  0.030625 0.04225  0.056875 0.073125 0.09     0.10625
#   0.126625 0.153125 0.1825   0.21275  0.24225  0.269625 0.289625 0.304125
#   0.315625 0.319    0.311625 0.31925  0.341625 0.347625 0.3435   0.3435
#   0.342125 0.33875  0.335125 0.33025  0.330625 0.3355   0.334375 0.326125
#   0.317625 0.318875 0.321375 0.321125 0.321625 0.3275   0.3305  ]
#  [0.115    0.0825   0.058125 0.038875 0.04375  0.046875 0.047    0.041375
#   0.032625 0.024875 0.022375 0.021125 0.02075  0.022375 0.021625 0.019125
#   0.017375 0.015    0.013375 0.016625 0.019875 0.0235   0.028375 0.0305
#   0.030875 0.032375 0.035    0.03725  0.04     0.042375 0.042625 0.043375
#   0.042375 0.039875 0.0385   0.036375 0.033875 0.032375 0.03075  0.029
#   0.0295   0.029625 0.030625 0.027625 0.024875 0.02525  0.025375 0.02625
#   0.026375 0.026    0.0265   0.026875 0.026875 0.027125 0.02675  0.024125
#   0.022875 0.021625 0.01975  0.019125 0.0195   0.019875 0.019625 0.02025
#   0.023625 0.0295   0.03625  0.046625 0.06175  0.076875 0.09225  0.108625
#   0.130875 0.157375 0.182625 0.209    0.2285   0.246    0.262    0.272625
#   0.27975  0.276875 0.269125 0.279    0.299875 0.3005   0.294    0.291625
#   0.293    0.290375 0.289    0.292875 0.2935   0.290875 0.287375 0.278625
#   0.27325  0.278375 0.281    0.27725  0.282375 0.294    0.298125]]
print('data.shape:',data.shape)
# data.shape: (2, 103)

这里的batch只设定了两个元素,可以看到最终的返回值的shape(2, 103),应该是每个sample作为一行。

如果元组batch中有100个元素,datashape就是 (100,103),为(batch_size, channel)

所以再看if patch_size == 1的这部分代码:

with torch.no_grad():
    if patch_size == 1:
        data = [b[0][0, 0] for b in batch]
        data = np.copy(data)
        data = torch.from_numpy(data)

首先是test,所以设定在with torch.no_grad():下执行。

data = [b[0][0, 0] for b in batch]元组batch中的每个元素的第一个元素(每一个sample的数据)提取出来,组成一个叫data的列表。

然后通过np.copy()datalist转成array

data = torch.from_numpy(data)dataarray转成tensor

获取预测值 output
indices = [b[1:] for b in batch]
data = data.to(device)
output = net(data)
if isinstance(output, tuple):
    output = output[0]
output = output.to('cpu')

if patch_size == 1 or center_pixel:
    output = output.numpy()
else:
    output = np.transpose(output.numpy(), (0, 2, 3, 1))

indices = [b[1:] for b in batch]获取索引信息。

data = data.to(device)data放到相应的device上。

output = net(data)获取data的预测值。

if isinstance(output, tuple): output = output[0]这一句不知道,暂略。

output = output.to('cpu')output转到cpu。

然后在patch_size == 1 or center_pixel的情况下,将output转成array类型output = output.numpy()

统计结果
for (x, y, w, h), out in zip(indices, output):
    if center_pixel:
        probs[x + w // 2, y + h // 2] += out
    else:
        probs[x:x + w, y:y + h] += out

首先强调一下,返回结果probs初始化的时候是初始化为全0的(对应的一般只设置一个ignored_label并将其对应的label作为0)。

整个test的大致意思是,每次通过grouper()获取一个batch的sample(个数为batch_size),然后将它们的预测结果output更新到probs中。这样一个batch一个batch地进行完,就得到了全部训练集样本的预测结果(训练样本以外的样本,对应位置的值为全零)。

其它暂略。

inference.py

这一部分的代码,跟main.py有重复。

(暂略)

datasets.py

此文件包含用于高光谱图像和相关助手的PyTorch数据集。

DATASETS_CONFIG + 更新

数据集配置

DATASETS_CONFIG是数据集配置字典,是字典dictionary类型。键值对dataset_name是数据集的urlsimggt

DATASETS_CONFIG = {
        'PaviaC': {
            'urls': ['http://www.ehu.eus/ccwintco/uploads/e/e3/Pavia.mat',      # urls是链接
                     'http://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat'],
            'img': 'Pavia.mat',
            'gt': 'Pavia_gt.mat'
            },
        'PaviaU': {
            'urls': ['http://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat',
                     'http://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat'],
            'img': 'PaviaU.mat',
            'gt': 'PaviaU_gt.mat'
            },
        'KSC': {
            'urls': ['http://www.ehu.es/ccwintco/uploads/2/26/KSC.mat',
                     'http://www.ehu.es/ccwintco/uploads/a/a6/KSC_gt.mat'],
            'img': 'KSC.mat',
            'gt': 'KSC_gt.mat'
            },
        'IndianPines': {
            'urls': ['http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat',
                     'http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat'],
            'img': 'Indian_pines_corrected.mat',
            'gt': 'Indian_pines_gt.mat'
            },
        'Botswana': {
            'urls': ['http://www.ehu.es/ccwintco/uploads/7/72/Botswana.mat',
                     'http://www.ehu.es/ccwintco/uploads/5/58/Botswana_gt.mat'],
            'img': 'Botswana.mat',
            'gt': 'Botswana_gt.mat',
            }
    }
更新数据集配置
try:
    from custom_datasets import CUSTOM_DATASETS_CONFIG
    DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG)
except ImportError:
    pass

本质上是将字典CUSTOM_DATASETS_CONFIG更新到字典DATASETS_CONFIG中,用到的是字典操作的update()函数。


class TqdmUpTo(tqdm)

一个class 进度条功能。

(暂略)


get_dataset()

####功能:

下载并读取数据集。

####输入和输出:

输入:
  • dataset_name: string with the name of the dataset
  • target_folder (optional): folder to store the datasets, defaults to ./ 。当然我一般是指定位置的。
  • datasets (optional): dataset configuration dictionary, defaults to prebuilt one。一般设定为DATASETS_CONFIG
输出:
  • img: 3D hyperspectral image (WxHxB),B为波段。
  • gt: 2D int array of labels
  • label_values: list of class names
  • ignored_labels: list of int classes to ignore
  • rgb_bands: int元组,对应红色、绿色和蓝色波段(int tuple that correspond to red, green and blue bands)

代码和解析:

初始化参数:
def get_dataset(dataset_name, target_folder="./", datasets=DATASETS_CONFIG):
# def get_dataset(dataset_name, target_folder="C:\\Users\\73416\\PycharmProjects\\HSIproject\\Datasets\\", datasets=DATASETS_CONFIG):
    """ Gets the dataset specified by name and return the related components.
    Args:
        dataset_name: string with the name of the dataset
        target_folder (optional): folder to store the datasets, defaults to ./  
        datasets (optional): dataset configuration dictionary, defaults to prebuilt one
    Returns:
        img: 3D hyperspectral image (WxHxB)
        gt: 2D int array of labels                      # 标签array
        label_values: list of class names               # 类的名单
        ignored_labels: list of int classes to ignore
        rgb_bands: int tuple that correspond to red, green and blue bands           # int元组,对应红色、绿色和蓝色波段
    """
    target_folder = "C:\\Datasets\\"       # 自加,修改数据集的路径
    # print(target_folder)  # 自加

    palette = None

    # 当输入的数据集的名字没有在数据集字典datasets=DATASETS_CONFIG中,则报错dataset is unknown
    if dataset_name not in datasets.keys():
        raise ValueError("{} dataset is unknown.".format(dataset_name))

    # 字典操作,取得数据集字典datasets中,键(key)为dataset_name的值(urls、img和gt)
    dataset = datasets[dataset_name]

    folder = target_folder + datasets[dataset_name].get('folder', dataset_name + '/')
    # folder为:C:\Datasets\PaviaU/

这部分是初始的一些参数:

  • target_folder:数据集文件夹Datasets的存放路径。比如target_folder = "C:\\Datasets\\"
  • palette:调色板,初始化为None
  • dataset:取得数据集字典datasets中,键(key)为dataset_name的值(urlsimggt
  • folder:特定数据集文件夹的存放路径。比如C:\Datasets\PaviaU/
下载数据集:
# Download the dataset if is not present
if dataset.get('download', True):
    # 如果没有folder(C:\Datasets\PaviaU/)文件夹,则创建该文件夹
    if not os.path.isdir(folder):
        os.mkdir(folder)
    # 下载数据集(暂pass)
    for url in datasets[dataset_name]['urls']:
        # download the files
        filename = url.split('/')[-1]
        if not os.path.exists(folder + filename):
            with TqdmUpTo(unit='B', unit_scale=True, miniters=1,
                      desc="Downloading {}".format(filename)) as t:
                urlretrieve(url, filename=folder + filename,
                                 reporthook=t.update_to)
elif not os.path.isdir(folder):
   print("WARNING: {} is not downloadable.".format(dataset_name))

if dataset.get('download', True):,则下载指定数据集。

首先检查指定路径folder下,文件夹是否存在 ,os.path.isdir(folder)。如果不存在则在指定路径folder下,创建文件夹。

之后是下载数据集的代码(包括获取urls、创建进度条等),暂略。

当然还有对于dataset_name的鲁棒性检查,也是暂略。

读取数据集+预处理:
数据集读取:
# 读取数据集
if dataset_name == 'PaviaC':
    # Load the image
    # 通过自己写的open_file()函数打开C:\Datasets\PaviaU/Pavia.mat文件,返回值为字典类型,通过['pavia']来提取键值对的中的值
    img = open_file(folder + 'Pavia.mat')['pavia']

    # 取RGB波段,为什么这么取不知道
    rgb_bands = (55, 41, 12)

    # 通过自己写的open_file()函数打开C:\Datasets\PaviaU/Pavia_gt.mat文件,返回值为字典类型,通过['pavia_gt']来提取键值对的中的值
    gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']

    # ???label_values有什么用,如何和gt链接
    label_values = ["Undefined", "Water", "Trees", "Asphalt",
                    "Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows",
                    "Meadows", "Bare Soil"]

    ignored_labels = [0]

elif dataset_name == 'PaviaU':
    # Load the image
    img = open_file(folder + 'PaviaU.mat')['paviaU']

    rgb_bands = (55, 41, 12)

    gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt']

    label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
                    'Painted metal sheets', 'Bare Soil', 'Bitumen',
                    'Self-Blocking Bricks', 'Shadows']

    ignored_labels = [0]

elif dataset_name == 'IndianPines':
    # Load the image
    img = open_file(folder + 'Indian_pines_corrected.mat')
    img = img['indian_pines_corrected']

    rgb_bands = (43, 21, 11)  # AVIRIS sensor

    gt = open_file(folder + 'Indian_pines_gt.mat')['indian_pines_gt']

    label_values = ["Undefined", "Alfalfa", "Corn-notill", "Corn-mintill",
                    "Corn", "Grass-pasture", "Grass-trees",
                    "Grass-pasture-mowed", "Hay-windrowed", "Oats",
                    "Soybean-notill", "Soybean-mintill", "Soybean-clean",
                    "Wheat", "Woods", "Buildings-Grass-Trees-Drives",
                    "Stone-Steel-Towers"]

    ignored_labels = [0]

elif dataset_name == 'Botswana':
    # Load the image
    img = open_file(folder + 'Botswana.mat')['Botswana']

    rgb_bands = (75, 33, 15)

    gt = open_file(folder + 'Botswana_gt.mat')['Botswana_gt']
    label_values = ["Undefined", "Water", "Hippo grass",
                    "Floodplain grasses 1", "Floodplain grasses 2",
                    "Reeds", "Riparian", "Firescar", "Island interior",
                    "Acacia woodlands", "Acacia shrublands",
                    "Acacia grasslands", "Short mopane", "Mixed mopane",
                    "Exposed soils"]

    ignored_labels = [0]

elif dataset_name == 'KSC':
    # Load the image
    img = open_file(folder + 'KSC.mat')['KSC']

    rgb_bands = (43, 21, 11)  # AVIRIS sensor

    gt = open_file(folder + 'KSC_gt.mat')['KSC_gt']

    label_values = ["Undefined", "Scrub", "Willow swamp",
                    "Cabbage palm hammock", "Cabbage palm/oak hammock",
                    "Slash pine", "Oak/broadleaf hammock",
                    "Hardwood swamp", "Graminoid marsh", "Spartina marsh",
                    "Cattail marsh", "Salt marsh", "Mud flats", "Wate"]

    ignored_labels = [0]
else:
    # 详细见自定义数据集模块
    # Custom dataset
    img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)

这部分是读取下载好的数据集文件,包括3D image2D label。读取不同的数据集文件的操作也有很大的重复性。

每次读取数据集文件的时候,都是干了这几件事情(以dataset_name == 'PaviaC'为例):

  • 读取数据。通过自己写的open_file()函数打开C:\Datasets\PaviaU/Pavia.mat文件,返回值为字典类型,通过['pavia']来提取键值对的中的值。代码img = open_file(folder + 'Pavia.mat')['pavia']
  • 取RGB波段,但怎么取不知道。代码rgb_bands = (55, 41, 12)
  • 读取gt。通过自己写的open_file()函数打开C:\Datasets\PaviaU/Pavia_gt.mat文件,返回值为字典类型,通过['pavia_gt']来提取键值对的中的值。代码gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']
  • 确定label_values代码label_values = ["Undefined", "Water", "Trees", "Asphalt", "Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows", "Meadows", "Bare Soil"]
  • 确定ignored_labels,一般为0。代码ignored_labels = [0]

**需要注意的是:**当要处理的数据集不是项目预先定义的数据集的时候(即处理的是用户自己的数据集),会在最后的else来返回值。详见CUSTOM_DATASETS_CONFIG

else:
    # 详细见自定义数据集模块
    # Custom dataset
    img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)
处理NaN的情况:
# 处理NaN的情况
# Filter NaN out
nan_mask = np.isnan(img.sum(axis=-1))
if np.count_nonzero(nan_mask) > 0:
   print("Warning: NaN have been found in the data. It is preferable to remove them beforehand. Learning on NaN data is disabled.")
img[nan_mask] = 0
gt[nan_mask] = 0
ignored_labels.append(0)
ignored_labels = list(set(ignored_labels))

这个情况不常见,暂略。

Normalization 归一化:
# Normalization 归一化
img = np.asarray(img, dtype='float32')
img = (img - np.min(img)) / (np.max(img) - np.min(img))

首先把img的每个元素的类型变为float32img = np.asarray(img, dtype='float32')),然后归一化操作(img = (img - np.min(img)) / (np.max(img) - np.min(img)))。

返回值:
return img, gt, label_values, ignored_labels, rgb_bands, palette
  • img: 3D hyperspectral image (WxHxB),B为波段。
  • gt: 2D int array of labels
  • label_values: list of class names
  • ignored_labels: list of int classes to ignore
  • rgb_bands: int元组,对应红色、绿色和蓝色波段(int tuple that correspond to red, green and blue bands)
  • palette:默认返回为None。

class HyperX(torch.utils.data.Dataset)

这是高光谱场景的通用的

class HyperX(torch.utils.data.Dataset):

类名为HyperX,继承的父类是torch.utils.data.Dataset

__ init__(self, data, gt, **hyperparams):

功能:

对类的属性进行初始化。

输入和输出:
输入:
  • data: 3D hyperspectral image 图形
  • gt: 2D array of labels 标签
  • **hyperparamshyperparams是包含超参数的字典dictionary**表示这个位置接收任意多个关键字参数(比如a=1,b=2,c=3,d=4,e=5等类似于键值对)。 **将多输入的变量,存储为字典dictionary类型
输出:

无。

代码和解析:
读取img、gt和超参数
class HyperX(torch.utils.data.Dataset):
    """ Generic class for a hyperspectral scene """

    def __init__(self, data, gt, **hyperparams):    #??? **hyperparams表示接受不定数量的超参数?
        """
        Args:
            data: 3D hyperspectral image    图形
            gt: 2D array of labels          标签
            patch_size: int, size of the spatial neighbourhood  (int,空间邻域的大小)
            center_pixel: bool, set to True to consider only the label of the
                          center pixel  (bool类型,设置为True仅考虑中心像素的标签)
            data_augmentation: bool, set to True to perform random flips    (数据增强)
            supervision: 'full' or 'semi' supervised algorithms             (监督方式:监督或半监督)
        """
        super(HyperX, self).__init__()
        # 读取img
        self.data = data
        # 读取gt
        self.label = gt
        # 读取超参数
        self.name = hyperparams['dataset']
        self.patch_size = hyperparams['patch_size']
        self.ignored_labels = set(hyperparams['ignored_labels'])
        self.flip_augmentation = hyperparams['flip_augmentation']
        self.radiation_augmentation = hyperparams['radiation_augmentation'] 
        self.mixture_augmentation = hyperparams['mixture_augmentation'] 
        self.center_pixel = hyperparams['center_pixel']
        supervision = hyperparams['supervision']

读取ignored_labels的这行代码(self.ignored_labels = set(hyperparams['ignored_labels']))有一个set()函数,是把一个可迭代对象的元素变为集合类型。

一个set()函数的小demo:

x = set('runoob')
print(x)
# {'r', 'b', 'o', 'u', 'n'}
print(type(x))
# 

其他没什么好说的。

监督方式:
# 监督方式
# Fully supervised : use all pixels with label not ignored. 全监督:使用标签未被忽略的所有像素
if supervision == 'full':
    mask = np.ones_like(gt)
    for l in self.ignored_labels:
        mask[gt == l] = 0
# Semi-supervised : use all pixels, except padding. 半监督:使用除填充之外的所有像素
elif supervision == 'semi':
    mask = np.ones_like(gt)

全监督是使用标签未被忽略的所有像素,mask是将gt中类别为ignored_labels的对应位置置零,其他位置置一。

半监督是使用除padding之外的所有像素,mask是全部置一。

获取索引:
x_pos, y_pos = np.nonzero(mask)
p = self.patch_size // 2
self.indices = np.array([(x,y) for x,y in zip(x_pos, y_pos) if x > p and x < data.shape[0] - p and y > p and y < data.shape[1] - p])
self.labels = [self.label[x,y] for x,y in self.indices]
np.random.shuffle(self.indices)

由于maskignored_labels为0,所以通过np.nonzero()获取mask中的非零元素的索引,返回值为两个array组成的元组,一个array是 x 轴的索引,另一个array是 y 轴的索引。

p = self.patch_size // 2(" / “就表示浮点数除法,返回浮点结果;” // "表示整数除法),但为什么这样取并不知道。但我猜测是将图片分块成不同的block。而且我觉得是把整个图片分成了4个block,类似“田”的形状。

下一步是获取指定范围的元素的索引,范围是 x ∈ (p, data.shape[0] - p), y ∈ (p, data.shape[1] - p)。方法是通过zip()函数来使索引变为(x, y)的形式,然后遍历这个索引筛选出在x ∈ (p, data.shape[0] - p), y ∈ (p, data.shape[1] - p)中的索引,把值赋给indices

然后由self.labels = [self.label[x,y] for x,y in self.indices]获取相应索引indices下的标签。

np.random.shuffle(self.indices)self.indices打乱。(但是标签并没有跟着被打乱呀?)


flip(*arrays)

功能:

对输入的多个数组arrays进行水平或垂直翻转。(我怀疑是数据增强的一种方法)

输入和输出:
输入:
  • arrays:多个数组。类型不明。

*表示这个位置接收任意多个非关键字参数(比如1,2,3,4,5等值) ,*将多输入的变量,存储为元组tuple类型

输出:
  • arrays:多个数组。类型不明。
代码和解析:
def flip(*arrays):
    horizontal = np.random.random() > 0.5
    vertical = np.random.random() > 0.5
    if horizontal:
        arrays = [np.fliplr(arr) for arr in arrays]
    if vertical:
        arrays = [np.flipud(arr) for arr in arrays]
    return arrays

水平(左右)翻转和垂直(上下)翻转,都是随机,通过生成随机数来实现。代码中默认的概率为0.5。对于p = 0.5如何实现,就是先生成一个随机数,若比0.5大则为True,反之则为False

horizontalTrue时,通过遍历arrays使得每一个arr都水平(左右)翻转。

verticalTrue时,通过遍历arrays使得每一个arr都垂直(上下)翻转。


radiation_noise()

功能:

给数据data加上噪音。

输入和输出:
输入:
  • data:带处理的数据。
  • alpha_rangedata的保留范围,默认为(0.9, 1.1)
  • beta:噪音noise的保留比例,默认为1/25
输出:
  • alpha * data + beta * noisedatanoise的组合
代码和解析:
@staticmethod
def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1/25):
    alpha = np.random.uniform(*alpha_range)
    noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
    return alpha * data + beta * noise

首先通过alpha_range确定alpha的值,用np.random.uniform(a, b)函数随机生成(a, b)之间的随机数,*alpha_range表示接收的是非关键字类型的变量,并将变量拆解。alpha = np.random.uniform(*alpha_range)

随机数noise通过正态分布产生,noise = np.random.normal(loc=0., scale=1.0, size=data.shape)表示产生均值为0,标准差为1.0,与data同size的服从正态分布的随机数。

最后返回datanoise的组合:alpha * data + beta * noise


mixture_noise()

看不懂,暂略。


__ len__()

返回对象的indices属性的长度。

def __len__(self):
    return len(self.indices)

__ getitem__()

功能:

得到以指定位置i为中心,size为patch_size × patch_size图像块block

同时附加数据增强效果。

这里的item指的是”图像块“。

输入和输出:
输入:
  • i:所取位置的索引
输出:
  • data(Batch x) Planes x Channels x Width x Height
  • label:标签
代码和解析:
获取图像块:
def __getitem__(self, i):
    x, y = self.indices[i]
    x1, y1 = x - self.patch_size // 2, y - self.patch_size // 2
    x2, y2 = x1 + self.patch_size, y1 + self.patch_size
    
    data = self.data[x1:x2, y1:y2]
    label = self.label[x1:x2, y1:y2]

这部分是得到以指定位置i为中心,size为patch_size × patch_size图像块block

大概是这样的原理:

x2, y2
x, y
x1, y1

由x, y分别减去self.patch_size // 2得到x1, y1,然后由x1, y1加上self.patch_size得到x2, y2。

然后通过[x1:x2, y1:y2]来获取对应的datalabel图像块

数据增强:
if self.flip_augmentation and self.patch_size > 1:
    # Perform data augmentation (only on 2D patches)
    data, label = self.flip(data, label)
if self.radiation_augmentation and np.random.random() < 0.1:
        data = self.radiation_noise(data)
if self.mixture_augmentation and np.random.random() < 0.2:
        data = self.mixture_noise(data, label)

这里数据增强并不是默认执行,而是需要self.flip_augmentation == True或者self.radiation_augmentation== True或者self.mixture_augmentation== True

对于self.flip_augmentation == True的情况,需要self.patch_size > 1,而且仅仅对2D的data执行。

对于self.radiation_augmentation== True的情况,需要np.random.random() < 0.1,即只有10%的概率执行操作。

对于self.mixture_augmentation== True的情况,需要np.random.random() < 0.2,即只有20%的概率执行操作。

datalabel转为ndarray类型:
# Copy the data into numpy arrays (PyTorch doesn't like numpy views)
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
label = np.asarray(np.copy(label), dtype='int64')

这部分将datalabel转为ndarray类型。来源类型暂不明确。

PyTorch doesn’t like numpy views。numpy的view是channel × row × column,即C × W × H。所以要通过转置将第三个维度提到第一个维度的位置。

numpy的view是C × W × H,而PyTorchW × H × C

data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')这一句先拷贝data,作为ndarray类型,再调用转置函数transpose()将维度顺序从W × H × C变为C × W × H。最后由dtype='float32'将每个元素的类型设置为float32

对于label,就直接拷贝变ndarray,再设置元素类型为int64,就Ok了。

ndarraytensor
# Load the data into PyTorch tensors
data = torch.from_numpy(data)
label = torch.from_numpy(label)

简单的应用torch.from_numpy(),没什么好讲的。

Extract the center label if needed:
# Extract the center label if needed
if self.center_pixel and self.patch_size > 1:
    label = label[self.patch_size // 2, self.patch_size // 2]
# Remove unused dimensions when we work with invidual spectrums
elif self.patch_size == 1:
    data = data[:, 0, 0]
    label = label[0, 0]

对于self.center_pixel == Trueself.patch_size > 1,对于size为self.patch_size × self.patch_sizelabel只取[self.patch_size // 2, self.patch_size // 2]的位置。

好叭这部分暂略,看不懂。

# Add a fourth dimension for 3D CNN
if self.patch_size > 1:
    # Make 4D data ((Batch x) Planes x Channels x Width x Height)
    data = data.unsqueeze(0)

这部分是给data在第一个维度的位置增加一个维度,作为Batch,即最后的维度顺序是Batch x Channels x Width x Height

增加维度的是unsqueeze()函数,axis = 0 表示在第一个维度的位置增加维度。

返回值:
return data, label

返回值为data, label。类型为tensor

  |      | x2, y2 |

| :----: | :–: | :----: |
| | x, y | |
| x1, y1 | | |

由x, y分别减去self.patch_size // 2得到x1, y1,然后由x1, y1加上self.patch_size得到x2, y2。

然后通过[x1:x2, y1:y2]来获取对应的datalabel图像块

数据增强:
if self.flip_augmentation and self.patch_size > 1:
    # Perform data augmentation (only on 2D patches)
    data, label = self.flip(data, label)
if self.radiation_augmentation and np.random.random() < 0.1:
        data = self.radiation_noise(data)
if self.mixture_augmentation and np.random.random() < 0.2:
        data = self.mixture_noise(data, label)

这里数据增强并不是默认执行,而是需要self.flip_augmentation == True或者self.radiation_augmentation== True或者self.mixture_augmentation== True

对于self.flip_augmentation == True的情况,需要self.patch_size > 1,而且仅仅对2D的data执行。

对于self.radiation_augmentation== True的情况,需要np.random.random() < 0.1,即只有10%的概率执行操作。

对于self.mixture_augmentation== True的情况,需要np.random.random() < 0.2,即只有20%的概率执行操作。

datalabel转为ndarray类型:
# Copy the data into numpy arrays (PyTorch doesn't like numpy views)
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
label = np.asarray(np.copy(label), dtype='int64')

这部分将datalabel转为ndarray类型。来源类型暂不明确。

PyTorch doesn’t like numpy views。numpy的view是channel × row × column,即C × W × H。所以要通过转置将第三个维度提到第一个维度的位置。

numpy的view是C × W × H,而PyTorchW × H × C

data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')这一句先拷贝data,作为ndarray类型,再调用转置函数transpose()将维度顺序从W × H × C变为C × W × H。最后由dtype='float32'将每个元素的类型设置为float32

对于label,就直接拷贝变ndarray,再设置元素类型为int64,就Ok了。

ndarraytensor
# Load the data into PyTorch tensors
data = torch.from_numpy(data)
label = torch.from_numpy(label)

简单的应用torch.from_numpy(),没什么好讲的。

Extract the center label if needed:
# Extract the center label if needed
if self.center_pixel and self.patch_size > 1:
    label = label[self.patch_size // 2, self.patch_size // 2]
# Remove unused dimensions when we work with invidual spectrums
elif self.patch_size == 1:
    data = data[:, 0, 0]
    label = label[0, 0]

对于self.center_pixel == Trueself.patch_size > 1,对于size为self.patch_size × self.patch_sizelabel只取[self.patch_size // 2, self.patch_size // 2]的位置。

好叭这部分暂略,看不懂。

# Add a fourth dimension for 3D CNN
if self.patch_size > 1:
    # Make 4D data ((Batch x) Planes x Channels x Width x Height)
    data = data.unsqueeze(0)

这部分是给data在第一个维度的位置增加一个维度,作为Batch,即最后的维度顺序是Batch x Channels x Width x Height

增加维度的是unsqueeze()函数,axis = 0 表示在第一个维度的位置增加维度。

返回值:
return data, label

返回值为data, label。类型为tensor

你可能感兴趣的:(开源项目使用,高光谱图像分类,GitHub,开源项目)