[Hyperspectral] Dissecting the open-source Hyperspectral-Classification PyTorch project: main.py

Encoding declaration:

# -*- coding: utf-8 -*-

Project description (the module docstring):

"""
DEEP LEARNING FOR HYPERSPECTRAL DATA.

This script allows the user to run several deep models (and SVM baselines)
against various hyperspectral datasets. It is designed to quickly benchmark
state-of-the-art CNNs on various public hyperspectral datasets.

This code is released under the GPLv3 license for non-commercial and research
purposes only.
For commercial use, please contact the authors.
"""


Python 2/3 compatibility:

# Python 2/3 compatiblity
from __future__ import print_function
from __future__ import division

These two imports smooth over the differences between Python 2 and Python 3 (print as a function, true division).
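Under Python 3 these imports are no-ops, so the script stays compatible both ways; a quick illustration of what `from __future__ import division` changes on Python 2:

```python
from __future__ import division, print_function

# On Python 2 without the import, 3 / 2 == 1 (floor division).
# With the import (and always on Python 3):
print(3 / 2)   # 1.5 -- true division
print(3 // 2)  # 1   -- explicit floor division is still available
```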


Adding to the module search path:

# Added by me: inspect and extend the module search path
import sys,os
curPath = os.path.abspath(os.path.dirname(__file__))
print(os.path.dirname(__file__))
# C:/Users/73416/PycharmProjects/HSIproject
print(curPath)
# C:\Users\73416\PycharmProjects\HSIproject
rootPath = os.path.split(curPath)[0]
print(os.path.split(curPath))
# ('C:\\Users\\73416\\PycharmProjects', 'HSIproject')
print(rootPath)
# C:\Users\73416\PycharmProjects
print(sys.path)
# ['C:\\Users\\73416\\PycharmProjects\\HSIproject',
#  'E:\\PyCharm 2018.3.4\\helpers\\pydev',
#  'C:\\Users\\73416\\PycharmProjects\\HSIproject',
#  'E:\\PyCharm 2018.3.4\\helpers\\third_party\\thriftpy',
#  'E:\\PyCharm 2018.3.4\\helpers\\pydev',
#  'E:\\Anaconda\\python37.zip',
#  'E:\\Anaconda\\DLLs',
#  'E:\\Anaconda\\lib',
#  'E:\\Anaconda',
#  'E:\\Anaconda\\lib\\site-packages',
#  'E:\\Anaconda\\lib\\site-packages\\win32',
#  'E:\\Anaconda\\lib\\site-packages\\win32\\lib',
#  'E:\\Anaconda\\lib\\site-packages\\Pythonwin',
#  'E:\\PyCharm 2018.3.4\\helpers\\pycharm_matplotlib_backend',
#  'E:\\Anaconda\\lib\\site-packages\\IPython\\extensions',
#  'C:\\Users\\73416\\PycharmProjects\\HSIproject',
#  'C:/Users/73416/PycharmProjects/HSIproject']
sys.path.append(rootPath)
sys.path.append('E:\\Anaconda\\lib\\site-packages\\')
print(sys.path)
# ['C:\\Users\\73416\\PycharmProjects\\HSIproject',
#  'E:\\PyCharm 2018.3.4\\helpers\\pydev',
#  'C:\\Users\\73416\\PycharmProjects\\HSIproject',
#  'E:\\PyCharm 2018.3.4\\helpers\\third_party\\thriftpy',
#  'E:\\PyCharm 2018.3.4\\helpers\\pydev',
#  'E:\\Anaconda\\python37.zip',
#  'E:\\Anaconda\\DLLs',
#  'E:\\Anaconda\\lib',
#  'E:\\Anaconda',
#  'E:\\Anaconda\\lib\\site-packages',
#  'E:\\Anaconda\\lib\\site-packages\\win32',
#  'E:\\Anaconda\\lib\\site-packages\\win32\\lib',
#  'E:\\Anaconda\\lib\\site-packages\\Pythonwin',
#  'E:\\PyCharm 2018.3.4\\helpers\\pycharm_matplotlib_backend',
#  'E:\\Anaconda\\lib\\site-packages\\IPython\\extensions',
#  'C:\\Users\\73416\\PycharmProjects\\HSIproject',
#  'C:/Users/73416/PycharmProjects/HSIproject',
#  'C:\\Users\\73416\\PycharmProjects',
#  'E:\\Anaconda\\lib\\site-packages\\']
# End of my path-modification code
The commented output embedded in the code block above is from running under PyCharm; the output below is from running the same script in a command-line window:

C:\Users\73416\PycharmProjects\HSIproject

C:\Users\73416\PycharmProjects\HSIproject

('C:\\Users\\73416\\PycharmProjects', 'HSIproject')

C:\Users\73416\PycharmProjects

['C:\\Users\\73416\\PycharmProjects\\HSIproject', 'E:\\Python37\\python37.zip', 'E:\\Python37\\DLLs', 'E:\\Python37\\lib', 'E:\\Python37', 'E:\\Python37\\lib\\site-packages', 'E:\\Python37\\lib\\site-packages\\win32', 'E:\\Python37\\lib\\site-packages\\win32\\lib', 'E:\\Python37\\lib\\site-packages\\Pythonwin']

['C:\\Users\\73416\\PycharmProjects\\HSIproject', 'E:\\Python37\\python37.zip', 'E:\\Python37\\DLLs', 'E:\\Python37\\lib', 'E:\\Python37', 'E:\\Python37\\lib\\site-packages', 'E:\\Python37\\lib\\site-packages\\win32', 'E:\\Python37\\lib\\site-packages\\win32\\lib', 'E:\\Python37\\lib\\site-packages\\Pythonwin', 'C:\\Users\\73416\\PycharmProjects', 'E:\\Anaconda\\lib\\site-packages\\']

The intended way to run this project is from a command-line window.

Note that although neither environment raises import errors, the two runs resolve packages from different installations, one under E:\Anaconda and one under E:\Python37, so the package versions can differ.

To import the other .py files under the project path, rootPath (C:\Users\73416\PycharmProjects) must be added to the default search path sys.path.

The extra sys.path.append('E:\\Anaconda\\lib\\site-packages\\') is there because the command-line environment has no torch installed, so the torch package under the Anaconda path is used instead.

(Where a package lands depends on how it was installed: with the default pip it goes under E:\Python37; with Anaconda's pip or the conda command it goes under E:\Anaconda.)


Imported modules:

# Torch
import torch
import torch.utils.data as data
from torchsummary import summary

# Numpy, scipy, scikit-image, spectral
import numpy as np
import sklearn.svm
import sklearn.model_selection
from skimage import io

# Visualization
import seaborn as sns
import visdom

import os

# Import project-local modules
from utils import metrics, convert_to_color_, convert_from_color_,\
    display_dataset, display_predictions, explore_spectrums, plot_spectrums,\
    sample_gt, build_dataset, show_results, compute_imf_weights, get_device
from datasets import get_dataset, HyperX, open_file, DATASETS_CONFIG
from models import get_model, train, test, save_model

# Command-line argument parsing
import argparse

A word on how the project files are organized (forgive me, I'm a beginner).

main.py usually holds the project's main program; reading it top to bottom roughly follows the execution flow.

utils.py typically collects self-written helper functions, which main.py calls.

datasets.py holds the dataset-related functions, classes, and so on.

models.py contains the models used and related code.

The overall idea: main.py defines no functions of its own, it only calls them; the custom functions live, grouped by purpose, in the other .py files. This keeps the project well organized.


Getting the dataset names:

# Build dataset_names from DATASETS_CONFIG (a list comprehension over the dict's items)
dataset_names = [v['name'] if 'name' in v.keys() else k for k, v in DATASETS_CONFIG.items()]

This produces the dataset names, stored as a list.

The names cover both the predefined datasets and any custom ones, because datasets.py wraps DATASETS_CONFIG in a try-except block that merges the custom dataset dict CUSTOM_DATASETS_CONFIG into DATASETS_CONFIG via the dict update() method. A toy version of the comprehension follows.
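Here is a small illustration of that comprehension; the config entries below are made up (a real entry holds download URLs, loaders, etc.), but the name-or-key fallback is the point:

```python
DATASETS_CONFIG = {
    'PaviaU': {},                             # no 'name' key: fall back to the dict key
    'IndianPines': {'name': 'Indian Pines'},  # explicit display name
}

dataset_names = [v['name'] if 'name' in v.keys() else k
                 for k, v in DATASETS_CONFIG.items()]
print(dataset_names)  # ['PaviaU', 'Indian Pines']
```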


Command-line argument parser:

For details, see my other post on using the argparse command-line parser.


Creating the parser:
parser = argparse.ArgumentParser(description="Run deep learning experiments on"
                                             " various hyperspectral datasets")

argparse.ArgumentParser() creates the parser; description is the help text shown at the top of the usage message.


Declaring the arguments:
Ungrouped:
parser.add_argument('--dataset', type=str, default=None, choices=dataset_names,
                    help="Dataset to use.")
  • Name: dataset
  • Type: string
  • Default: None
  • Choices: dataset_names
  • Help: "Dataset to use."
parser.add_argument('--model', type=str, default=None,
                    help="Model to train. Available:\n"
                    "SVM (linear),\n "
                    "SVM_grid (grid search on linear, poly and RBF kernels), \n"
                    "baseline (fully connected NN), \n"
                    "hu (1D CNN), \n"
                    "hamida (3D CNN + 1D classifier), \n"
                    "lee (3D FCN), \n"
                    "chen (3D CNN), \n"
                    "li (3D CNN), \n"
                    "he (3D CNN), \n"
                    "luo (3D CNN), \n"
                    "sharma (2D CNN), \n"
                    "boulch (1D semi-supervised CNN), \n"
                    "liu (3D semi-supervised CNN), \n"
                    "mou (1D RNN)")
  • Name: model
  • Type: string
  • Default: None
  • Help: "Model to train. Available:\n", ……
parser.add_argument('--folder', type=str, help="Folder where to store the "
                    "datasets (defaults to the current working directory).",
                    default="./Datasets/")
  • Name: folder
  • Type: string
  • Default: "./Datasets/"
  • Help: "Folder where to store the " ……
parser.add_argument('--cuda', type=int, default=-1,
                    help="Specify CUDA device (defaults to -1, which learns on CPU)")
  • Name: cuda
  • Type: int
  • Default: -1
  • Help: "Specify CUDA device (defaults to -1, which learns on CPU)"
parser.add_argument('--runs', type=int, default=1, help="Number of runs (default: 1)")
  • Name: runs
  • Type: int
  • Default: 1
  • Help: "Number of runs (default: 1)"
parser.add_argument('--restore', type=str, default=None,
                    help="Weights to use for initialization, e.g. a checkpoint")
  • Name: restore
  • Type: string
  • Default: None
  • Help: "Weights to use for initialization, e.g. a checkpoint"
The Dataset group:
group_dataset = parser.add_argument_group('Dataset')

This line groups the arguments added below; in the --help output the effect looks roughly like:

Dataset:
  --training_sample ……
  --sampling_mode ……
  --train_set ……
  --test_set ……
group_dataset.add_argument('--training_sample', type=float, default=10,
                    help="Percentage of samples to use for training (default: 10%%)")
  • Name: training_sample
  • Type: float
  • Default: 10
  • Help: "Percentage of samples to use for training (default: 10%%)"
group_dataset.add_argument('--sampling_mode', type=str, help="Sampling mode"
                    " (random sampling or disjoint, default: random)",
                    default='random')
  • Name: sampling_mode
  • Type: string
  • Default: 'random'
  • Help: "Sampling mode (random sampling or disjoint, default: random)"
group_dataset.add_argument('--train_set', type=str, default=None,
                    help="Path to the train ground truth (optional, this "
                    "supersedes the --sampling_mode option)")
  • Name: train_set
  • Type: string
  • Default: None
  • Help: "Path to the train ground truth (optional, this supersedes the --sampling_mode option)"

Note that --train_set, when given, supersedes the --sampling_mode option.

group_dataset.add_argument('--test_set', type=str, default=None,
                    help="Path to the test set (optional, by default "
                    "the test_set is the entire ground truth minus the training)")
  • Name: test_set
  • Type: string
  • Default: None
  • Help: "Path to the test set (optional, by default the test_set is the entire ground truth minus the training)"

Note that by default the test set is the entire ground truth minus the training samples.

The Training group:
group_train = parser.add_argument_group('Training')

This line groups the arguments added below; in the --help output the effect looks roughly like:

Training:
  --epoch EPOCH         ……
  --patch_size PATCH_SIZE	……
  --lr LR               ……
  --class_balancing     ……
  --batch_size BATCH_SIZE	……                       
  --test_stride TEST_STRIDE	……
group_train.add_argument('--epoch', type=int, help="Training epochs (optional, if"
                    " absent will be set by the model)")
  • Name: epoch
  • Type: int
  • Help: "Training epochs (optional, if absent will be set by the model)"

Note that --epoch is usually left unspecified, so the value predefined for the model is used.

group_train.add_argument('--patch_size', type=int,
                    help="Size of the spatial neighbourhood (optional, if "
                    "absent will be set by the model)")
  • Name: patch_size
  • Type: int
  • Help: "Size of the spatial neighbourhood (optional, if absent will be set by the model)"

Note that --patch_size is likewise usually left unspecified in favor of the model's predefined value.

A bit more on patches: the idea comes from block-wise image processing. In practice the patch size is small (3 × 3 or 5 × 5). From what I had read (as of 2019-09-01), one first picks a center position and then takes the patch centered on it, so a patch is spatial context information!

group_train.add_argument('--lr', type=float,
                    help="Learning rate, set by the model if not specified.")
  • Name: lr
  • Type: float
  • Help: "Learning rate, set by the model if not specified."
group_train.add_argument('--class_balancing', action='store_true',
                    help="Inverse median frequency class balancing (default = False)")
  • Name: class_balancing
  • Action: 'store_true'
  • Help: "Inverse median frequency class balancing (default = False)"
group_train.add_argument('--batch_size', type=int,
                    help="Batch size (optional, if absent will be set by the model")
  • Name: batch_size
  • Type: int
  • Help: "Batch size (optional, if absent will be set by the model"

Note that --batch_size is also usually left unspecified in favor of the model's predefined value.

group_train.add_argument('--test_stride', type=int, default=1,
                     help="Sliding window step stride during inference (default = 1)")
  • Name: test_stride
  • Type: int
  • Default: 1
  • Help: "Sliding window step stride during inference (default = 1)"

That is, the step stride of the sliding window during inference.

The Data augmentation group:

group_da = parser.add_argument_group('Data augmentation')
group_da.add_argument('--flip_augmentation', action='store_true',
                    help="Random flips (if patch_size > 1)")
  • Name: flip_augmentation
  • Action: 'store_true'
  • Help: "Random flips (if patch_size > 1)"
group_da.add_argument('--radiation_augmentation', action='store_true',
                    help="Random radiation noise (illumination)")
  • Name: radiation_augmentation
  • Action: 'store_true'
  • Help: "Random radiation noise (illumination)"
group_da.add_argument('--mixture_augmentation', action='store_true',
                    help="Random mixes between spectra")
  • Name: mixture_augmentation
  • Action: 'store_true'
  • Help: "Random mixes between spectra"
parser.add_argument('--with_exploration', action='store_true',
                    help="See data exploration visualization")
  • Name: with_exploration
  • Action: 'store_true'
  • Help: "See data exploration visualization"
parser.add_argument('--download', type=str, default=None, nargs='+',
                    choices=dataset_names,
                    help="Download the specified datasets and quits.")
  • Name: download
  • Type: string
  • Default: None
  • nargs: '+'
  • Choices: dataset_names
  • Help: "Download the specified datasets and quits."

nargs: by default an ArgumentParser pairs exactly one command-line token with each action; specifying nargs lets a single action consume several tokens ('+' means one or more). A short demo follows.
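A minimal demo of nargs='+' (the choices list here is shortened and assumed, not the project's full dataset list):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--download', type=str, default=None, nargs='+',
                    choices=['PaviaU', 'IndianPines', 'Botswana'])

args = parser.parse_args(['--download', 'PaviaU', 'Botswana'])
print(args.download)  # ['PaviaU', 'Botswana'] -- one action, several tokens
```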


Parsing the arguments:
args = parser.parse_args()

ArgumentParser parses the arguments through parse_args(). It inspects the command line, converts each argument to the appropriate type, and invokes the corresponding action. In most cases this means building a simple Namespace object from the attributes parsed off the command line. Note that the function returns an object: the argument values are attributes of that object!

In a script, parse_args() is typically called with no arguments.

Taking this command as an example:

python C:\Users\73416\PycharmProjects\HSIproject\main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda 0

the resulting args is:

Namespace(batch_size=None, class_balancing=False, cuda=0, dataset='PaviaU', download=None, epoch=None, flip_augmentation=False, folder='./Datasets/', lr=None, mixture_augmentation=False, model='nn', patch_size=None, radiation_augmentation=False, restore=None, runs=1, sampling_mode='random', test_set=None, test_stride=1, train_set=None, training_sample=0.1, with_exploration=False)

Strictly speaking, args is not a dictionary but a Namespace object whose attributes hold the hyperparameter values; vars(args) converts it into a plain dict of key-value pairs (which is exactly how hyperparams is initialized later). A short demo follows.
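A self-contained demo of attribute access versus the dict view (only two of the options are declared here, for brevity):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--training_sample', type=float, default=10)

args = parser.parse_args(['--dataset', 'PaviaU', '--training_sample', '0.1'])
print(args.dataset)                    # 'PaviaU' -- attribute access on the Namespace
hyperparams = vars(args)               # Namespace -> plain dict
print(hyperparams['training_sample'])  # 0.1
```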


Consuming the arguments:
CUDA_DEVICE

This part reads the individual values out of args, assigns them to the corresponding variables, and performs the follow-up operations.

CUDA_DEVICE = get_device(args.cuda)

get_device() is a custom function in utils.py that picks the computation device according to the value passed in.
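A plausible sketch of get_device() (the real implementation lives in utils.py; this is inferred from how it is called, not copied from the project):

```python
import torch

def get_device(ordinal):
    # ordinal < 0 selects the CPU; otherwise the CUDA device with that index (assumed behavior)
    if ordinal < 0:
        print("Computation on CPU")
        return torch.device('cpu')
    if torch.cuda.is_available():
        print("Computation on CUDA GPU device {}".format(ordinal))
        return torch.device('cuda', ordinal)
    print("CUDA was requested but is not available! Falling back on CPU.")
    return torch.device('cpu')
```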

SAMPLE_PERCENTAGE
SAMPLE_PERCENTAGE = args.training_sample

This sets the share of training samples among the entire labeled samples.

Data augmentation
FLIP_AUGMENTATION = args.flip_augmentation
RADIATION_AUGMENTATION = args.radiation_augmentation
MIXTURE_AUGMENTATION = args.mixture_augmentation

Depending on the boolean values of FLIP_AUGMENTATION, RADIATION_AUGMENTATION, and MIXTURE_AUGMENTATION, the corresponding data augmentation methods are applied or not.

All three default to False.

Dataset name
DATASET = args.dataset

The dataset name, stored in the variable DATASET.

Model name
MODEL = args.model

The model name, stored in the variable MODEL.

Number of runs
N_RUNS = args.runs

The number of runs, stored in the variable N_RUNS.

Spatial context size

Spatial context size (the number of neighbors in each spatial direction).

PATCH_SIZE = args.patch_size

As noted earlier, the patch idea comes from block-wise image processing, with small sizes in practice (3 × 3 or 5 × 5): pick a center position, then take the patch centered on it. So one patch is the central target sample plus its spatial context information.

Add spectra visualization
DATAVIZ = args.with_exploration

Depending on the boolean DATAVIZ, the spectra visualization is enabled or not. Defaults to False.

Target folder
FOLDER = args.folder

Target folder to store/download/load the datasets.

Number of epochs
EPOCH = args.epoch

One epoch is one full pass over the entire dataset.

Sampling mode

The sampling mode (random or disjoint).

SAMPLING_MODE = args.sampling_mode
Pre-computed weights

Pre-computed weights to restore, i.e., a checkpoint.

CHECKPOINT = args.restore
Learning rate

Learning rate for the SGD (stochastic gradient descent).

LEARNING_RATE = args.lr
class balancing

I didn't understand this at first; it refers to inverse median frequency class balancing: each class's loss weight is the median class frequency divided by that class's own frequency, so rare classes are up-weighted and frequent ones down-weighted. A sketch follows the assignment below.

CLASS_BALANCING = args.class_balancing
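The weights themselves come from compute_imf_weights() in utils.py (used in the class-balancing branch later). The sketch below only illustrates the inverse-median-frequency formula under my reading of it; it is not the project's code:

```python
import numpy as np

def imf_weights_sketch(ground_truth, n_classes, ignored_classes=()):
    # class frequencies over the non-ignored labeled pixels
    mask = np.ones_like(ground_truth, dtype=bool)
    for c in ignored_classes:
        mask[ground_truth == c] = False
    labels = ground_truth[mask]
    freq = np.bincount(labels, minlength=n_classes) / labels.size

    # weight = median(frequency) / frequency; ignored or absent classes keep weight 0
    weights = np.zeros(n_classes)
    present = freq > 0
    weights[present] = np.median(freq[present]) / freq[present]
    return weights

gt = np.array([[0, 1, 1, 1], [2, 2, 3, 3]])
print(imf_weights_sketch(gt, 4, ignored_classes=[0]))
# [0. 0.66666667 1. 1.] -- the frequent class 1 is down-weighted
```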
Training ground truth file
TRAIN_GT = args.train_set

Earlier I wondered why train_set seemed absent from args; in fact it is there, just None (see the Namespace above), because --train_set was not passed. TRAIN_GT therefore ends up None and the random split below is used.

Testing ground truth file
TEST_GT = args.test_set

Likewise args.test_set is None here, so TEST_GT is None.

test stride

Sliding window step stride during inference.

TEST_STRIDE = args.test_stride

Download datasets and quit: if --download was given, fetch each listed dataset into FOLDER and exit immediately.

if args.download is not None and len(args.download) > 0:
    for dataset in args.download:
        get_dataset(dataset, target_folder=FOLDER)
    quit()

Setting up the visdom environment

viz = visdom.Visdom(env=DATASET + ' ' + MODEL)  # set up the visdom environment
if not viz.check_connection():                  # check the connection to the visdom server
    print("Visdom is not connected. Did you run 'python -m visdom.server' ?")

This sets up the visdom environment, naming it DATASET + ' ' + MODEL, and checks the connection to the visdom server, printing a warning if the connection fails.

(Note: the project writes viz.check_connection without parentheses; that is a bound method and always truthy, so the warning can never fire. The method needs to be called, as fixed above.)


Loading the dataset

# Load the dataset
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
"""
img: 3D hyperspectral image (WxHxB)
gt: 2D int array of labels
label_values: list of class names
ignored_labels: list of int classes to ignore
rgb_bands: int tuple that correspond to red, green and blue bands
"""

get_dataset() is defined in datasets.py and covers downloading, reading, and preprocessing the dataset.

Given DATASET and FOLDER, it returns img (W × H × bands), gt (a 2D int array of labels), LABEL_VALUES (the list of class names), IGNORED_LABELS (the list of int classes to ignore), RGB_BANDS (the int tuple of the red, green, and blue bands), and the palette.


Getting the numbers of classes and bands

# Number of classes
N_CLASSES = len(LABEL_VALUES)
# Number of bands (last dimension of the image tensor)
N_BANDS = img.shape[-1]

The number of classes, N_CLASSES, is simply the length of LABEL_VALUES, which here is:

['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees', 'Painted metal sheets', 'Bare Soil', 'Bitumen', 'Self-Blocking Bricks', 'Shadows']

The number of bands, N_BANDS, is the last dimension of img, i.e., shape[-1].


SVM grid-search parameters

# Parameters for the SVM grid search
SVM_GRID_PARAMS = [{'kernel': ['rbf'], 'gamma': [1e-1, 1e-2, 1e-3],
                                       'C': [1, 10, 100, 1000]},
                   {'kernel': ['linear'], 'C': [0.1, 1, 10, 100, 1000]},
                   {'kernel': ['poly'], 'degree': [3], 'gamma': [1e-1, 1e-2, 1e-3]}]

Since I don't use the SVM, I'll skip this part.


Setting up the palette

# import numpy as np
# import seaborn as sns

# palette = None
# LABEL_VALUES = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees', 'Painted metal sheets', # 'Bare Soil', 'Bitumen', 'Self-Blocking Bricks', 'Shadows']

if palette is None:     # color palette
    # Generate color palette
    palette = {0: (0, 0, 0)}
    for k, color in enumerate(sns.color_palette("hls", len(LABEL_VALUES) - 1)):
        palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8'))
invert_palette = {v: k for k, v in palette.items()}

# print(palette)
# # {0: (0, 0, 0), 1: (219, 94, 86), 2: (219, 183, 86), 3: (167, 219, 86), 4: (86, 219, 94), 5: (86, 219, 183), 6: (86, 167, 219), 7: (94, 86, 219), 8: (183, 86, 219), 9: (219, 86, 167)}
# print(invert_palette)
# # {(0, 0, 0): 0, (219, 94, 86): 1, (219, 183, 86): 2, (167, 219, 86): 3, (86, 219, 94): 4, (86, 219, 183): 5, (86, 167, 219): 6, (94, 86, 219): 7, (183, 86, 219): 8, (219, 86, 167): 9}

The commented-out lines are my test code.

First, the counts: palette = {0: (0, 0, 0)} maps class 0 (the ignored_labels class) straight to (0, 0, 0) as background; colors are then generated for the remaining 9 (non-ignored) classes, hence the argument len(LABEL_VALUES) - 1.

hls is a color space, a simple transformation of RGB values.

The overall flow: when palette is None, generate one. Class 0 (ignored_labels) is set to (0, 0, 0) as background; then enumerate(sns.color_palette("hls", len(LABEL_VALUES) - 1)) yields an indexed sequence of len(LABEL_VALUES) - 1 three-element tuples in the hls color space, and the for loop assigns palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8')) for each remaining class.

Finally, iterating over palette with keys and values swapped gives invert_palette: invert_palette = {v: k for k, v in palette.items()}.


Defining the class/color conversion functions

def convert_to_color(x):
    return convert_to_color_(x, palette=palette)
def convert_from_color(x):
    return convert_from_color_(x, palette=invert_palette)

xint 2D array of labels,标签的二维矩阵。

paletteinvert_palette在上面定义过。

这两个函数就是类别和RGB颜色相互转换的函数。其实说到底就是类别显示颜色的对应。
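The underlying helpers live in utils.py; this is a plausible sketch inferred from the call signature, not the project's verbatim code:

```python
import numpy as np

def convert_to_color_sketch(x, palette):
    # x: 2D int array of labels -> H x W x 3 uint8 RGB image
    out = np.zeros(x.shape + (3,), dtype='uint8')
    for label, color in palette.items():
        out[x == label] = color
    return out
```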


Updating class count, band count, etc. in the hyperparameters

# Instantiate the experiment based on predefined networks
hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS, 'device': CUDA_DEVICE})
hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None)     # rebuild the dict, dropping keys whose value is None

N_CLASSES, N_BANDS, and IGNORED_LABELS became available when the dataset was loaded (and in the follow-up steps).

CUDA_DEVICE was set back in the argument-handling part.

Here those four values are merged into the hyperparameter dict hyperparams via update().

The dict is then rebuilt by iteration, filtering out every key-value pair whose value is None.


visdom中展示img + gt --> color

# Show the image and the ground truth
display_dataset(img, gt, RGB_BANDS, LABEL_VALUES, palette, viz)
color_gt = convert_to_color(gt)

This shows the image and the ground truth, with class membership in the ground truth encoded as color.

[Figure: the img displayed in visdom; the image link in the original post is broken.]


Data exploration

if DATAVIZ:
    # Data exploration : compute and show the mean spectrums
    mean_spectrums = explore_spectrums(img, gt, LABEL_VALUES, viz,
                                       ignored_labels=IGNORED_LABELS)
    plot_spectrums(mean_spectrums, viz, title='Mean spectrum/class')

DATAVIZ defaults to False, so by default this block does not run.

Skipped for now.


Initializing the results list

results = []

Here results, the variable that will store the final results, is initialized as an empty list.

Results are later appended to it with the append() method.


run the experiment

# run the experiment several times
for run in range(N_RUNS):
    ……

Run the experiment several times; the default is once.


Obtaining train_gt and test_gt

if TRAIN_GT is not None and TEST_GT is not None:
    train_gt = open_file(TRAIN_GT)
    test_gt = open_file(TEST_GT)
elif TRAIN_GT is not None:
    train_gt = open_file(TRAIN_GT)
    test_gt = np.copy(gt)
    w, h = test_gt.shape
    test_gt[(train_gt > 0)[:w,:h]] = 0
elif TEST_GT is not None:
    test_gt = open_file(TEST_GT)
else:
# Sample random training spectra (yields both a training set and a test set)
    train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)
print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                             np.count_nonzero(gt)))

In this run, TRAIN_GT and TEST_GT are both None before this code executes.

So by default the last branch is taken:

else:
# Sample random training spectra (yields both a training set and a test set)
    train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)

sample_gt() extracts a fixed percentage SAMPLE_PERCENTAGE of samples from the array of labels gt.

Worth stressing: the samples split into training and test sets never include samples whose class is in ignored_labels. A sketch of the random mode follows.
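A simplified sketch of a random-mode sample_gt() (the project's version in utils.py also supports a 'disjoint' mode; this is only an illustration under my assumptions, including treating a value above 1 as a percentage):

```python
import numpy as np

def sample_gt_sketch(gt, percentage, rng=np.random.default_rng(0)):
    # split the labeled (nonzero) pixels of gt into train/test masks of the same shape
    train_gt = np.zeros_like(gt)
    test_gt = np.copy(gt)
    xs, ys = np.nonzero(gt)                     # label 0 (ignored) is never sampled
    frac = percentage / 100 if percentage > 1 else percentage
    picked = rng.choice(len(xs), size=int(len(xs) * frac), replace=False)
    train_gt[xs[picked], ys[picked]] = gt[xs[picked], ys[picked]]
    test_gt[xs[picked], ys[picked]] = 0         # training pixels are removed from the test gt
    return train_gt, test_gt
```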

print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                             np.count_nonzero(gt)))
# 4277 samples selected (over 42776)

This line prints the result of the train_gt/test_gt split.

As shown, 4277 of the 42776 labeled samples were selected for train_gt, matching SAMPLE_PERCENTAGE = 0.1.


Printing run {}/{}

print("Running an experiment with the {} model".format(MODEL),
      "run {}/{}".format(run + 1, N_RUNS))

After entering the loop at the "run the experiment" step, this prints which run we are on.

run + 1 is needed because in for run in range(N_RUNS): the counter run starts at 0, while runs are reported starting from 1.


Selecting the model and starting training

Non-neural models such as SVM
if MODEL == 'SVM_grid':
    print("Running a grid search SVM")
    # Grid search SVM (linear and RBF)
    X_train, y_train = build_dataset(img, train_gt,
                                     ignored_labels=IGNORED_LABELS)
    class_weight = 'balanced' if CLASS_BALANCING else None
    clf = sklearn.svm.SVC(class_weight=class_weight)
    clf = sklearn.model_selection.GridSearchCV(clf, SVM_GRID_PARAMS, verbose=5, n_jobs=4)
    clf.fit(X_train, y_train)
    print("SVM best parameters : {}".format(clf.best_params_))
    prediction = clf.predict(img.reshape(-1, N_BANDS))
    save_model(clf, MODEL, DATASET)
    prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'SVM':
    X_train, y_train = build_dataset(img, train_gt,
                                     ignored_labels=IGNORED_LABELS)
    class_weight = 'balanced' if CLASS_BALANCING else None
    clf = sklearn.svm.SVC(class_weight=class_weight)
    clf.fit(X_train, y_train)
    save_model(clf, MODEL, DATASET)
    prediction = clf.predict(img.reshape(-1, N_BANDS))
    prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'SGD':
    X_train, y_train = build_dataset(img, train_gt,
                                     ignored_labels=IGNORED_LABELS)
    X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
    scaler = sklearn.preprocessing.StandardScaler()
    X_train = scaler.fit_transform(X_train)
    class_weight = 'balanced' if CLASS_BALANCING else None
    clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight, learning_rate='optimal', tol=1e-3, average=10)
    clf.fit(X_train, y_train)
    save_model(clf, MODEL, DATASET)
    prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS)))
    prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'nearest':
    X_train, y_train = build_dataset(img, train_gt,
                                     ignored_labels=IGNORED_LABELS)
    X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
    class_weight = 'balanced' if CLASS_BALANCING else None
    clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
    clf = sklearn.model_selection.GridSearchCV(clf, {'n_neighbors': [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4)
    clf.fit(X_train, y_train)
    save_model(clf, MODEL, DATASET)
    prediction = clf.predict(img.reshape(-1, N_BANDS))
    prediction = prediction.reshape(img.shape[:2])

These branches are all non-neural models, which aren't my direction, so I'll skip them.


The neural-network branch
else:
    # Neural network
    ……
Getting the selected model's basic objects
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)

get_model() returns the model model, the optimizer optimizer, the loss function loss, and the updated hyperparams.

hyperparams is the dict of hyperparameters, including epoch, batch_size, patch_size, and so on.

Printing model, optimizer, loss, and hyperparams gives:

Baseline(
  (fc1): Linear(in_features=103, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=2048, bias=True)
  (fc4): Linear(in_features=2048, out_features=10, bias=True)
)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)

CrossEntropyLoss()

{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0), 'weights': tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'), 'patch_size': 1, 'dropout': False, 'learning_rate': 0.0001, 'epoch': 100, 'batch_size': 100, 'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x...>, 'supervision': 'full', 'center_pixel': True}

We can see each layer's type and parameters; the optimizer is the ever-popular Adam, the loss is cross-entropy (CrossEntropyLoss), and the hyperparams contents follow.


Class balancing (optional)
if CLASS_BALANCING:
    weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
    hyperparams['weights'] = torch.from_numpy(weights)

Since CLASS_BALANCING defaults to False, I'll skip this part for now.

Splitting val_gt out of train_gt
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')

For the meaning of and differences among train, val, and test, see my other post on the distinction between train, val, and test during training.

In short, val is split off from train and is used to check whether the network is overfitting the training set (the telltale pattern: as training goes on, the train loss keeps shrinking while the val loss grows).

What this line does: take 95% of train_gt as the new train_gt and the other 5% as val_gt.

Output:

4063 samples selected (over 4277)

Also note how the selection works: chosen samples keep their values and unchosen ones are zeroed out (the ignored_label), so the arrays keep their full shape, still (610, 340).

print(train_gt)   
# [[0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# ...
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]]
print(train_gt.shape)
# (610, 340)
print(val_gt)
# [[0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# ...
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]]
print(val_gt.shape)
# (610, 340)

Generate the dataset
train_dataset = HyperX(img, train_gt, **hyperparams)
train_loader = data.DataLoader(train_dataset,
                                       batch_size=hyperparams['batch_size'],
                                       #pin_memory=hyperparams['device'],
                                       shuffle=True)
val_dataset = HyperX(img, val_gt, **hyperparams)
val_loader = data.DataLoader(val_dataset,
                                     #pin_memory=hyperparams['device'],
                                     batch_size=hyperparams['batch_size'])

The first line, train_dataset = HyperX(img, train_gt, **hyperparams), makes train_dataset an instance of the HyperX class:

<datasets.HyperX object at 0x000001403EB62CF8>

I then inspected the object's attributes with this snippet (my own addition, not in the project code):

for attr in dir(train_dataset):
    print(attr)
    print(getattr(train_dataset, attr))

The attributes seen include:

  • Dataset information:
    • name
    • data
    • label
    • labels
    • indices
    • ignored_labels
name
# PaviaU
data
# [[[0.080875 0.062375 0.058    ... 0.402625 0.40475  0.40625 ]
#   [0.0755   0.06825  0.065875 ... 0.30525  0.308    0.316   ]
#   [0.077625 0.09325  0.0695   ... 0.2885   0.293125 0.295125]
#   ...
#   [0.074125 0.048375 0.0535   ... 0.29775  0.300875 0.302875]
#   [0.074125 0.093875 0.081875 ... 0.289    0.2885   0.286125]
#   [0.111125 0.09     0.056125 ... 0.302    0.305875 0.310625]]]
label
# [[0 0 0 ... 0 0 0]
#  [0 0 0 ... 0 0 0]
#  [0 0 0 ... 0 0 0]
#  ...
#  [0 0 0 ... 0 0 0]
#  [2 0 2 ... 0 0 0]
#  [0 0 0 ... 0 0 0]]
labels
# [1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 4, 4, 1, 4, 4, 4, 4, 1, 4, 4, 1, 4, 1, 4, 4, 1, 1, 4, 1, 1, 4, 4, 4, 4, 1, 1, 4, 1, 4, 4, 4, 1, 1, 1, 1, 4, 4, 4, 4, 1, 4, 1, 1, 1, 1, 4, 4, 4, …… 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
indices
# [[511 237]
#  [551  94]
#  [595 176]
#  ...
#  [142 173]
#  [403  75]
#  [575 135]]
ignored_labels
# {0}
  • Dataset processing options:
    • patch_size
    • center_pixel
    • mixture_augmentation
    • flip
    • radiation_augmentation
    • mixture_noise
patch_size
# 1
center_pixel
# True
mixture_augmentation
# False
flip_augmentation
# False
radiation_augmentation
# (output lost in the original post)
mixture_noise
# <bound method …>  (the angle-bracketed repr was swallowed by the blog's HTML rendering)
flip
# <bound method …>  (likewise swallowed)

The line train_loader = data.DataLoader(train_dataset, ……) creates a DataLoader over that dataset.

An important data-reading interface in PyTorch is torch.utils.data.DataLoader, defined in dataloader.py. Nearly every PyTorch training script uses it: it takes the output of a custom dataset interface (or one of PyTorch's built-in ones) and packs it into Tensors of batch_size samples each, ready to be fed to the model. It therefore bridges data reading and model input, which is what makes it important. A toy demo follows.
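A self-contained toy demo of that batching (fake data, not the project's):

```python
import torch
import torch.utils.data as data

# 12 fake 'spectra' with 103 bands each, plus integer class labels
X = torch.randn(12, 103)
y = torch.randint(0, 10, (12,))
loader = data.DataLoader(data.TensorDataset(X, y), batch_size=5, shuffle=True)

for batch, labels in loader:
    print(batch.shape, labels.shape)
# torch.Size([5, 103]) torch.Size([5])
# torch.Size([5, 103]) torch.Size([5])
# torch.Size([2, 103]) torch.Size([2])  <- the last, smaller batch
```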

This snippet prints information about train_loader:

print(train_loader)
print('Over')
for attr in dir(train_loader):
    if '_' not in attr:
        print(attr)
        print(getattr(train_loader, attr))
print('Done!')

The result:

<torch.utils.data.dataloader.DataLoader object at 0x000001BFC5A714E0>
Over
dataset
<datasets.HyperX object at 0x000001BFC5A71F98>
sampler
<torch.utils.data.sampler.RandomSampler object at 0x000001BFC5BBCC50>
timeout
0
Done!

The next lines do the same thing for val, so I won't repeat the explanation.


Printing hyperparams

The printed result:

{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0), 'weights': tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'), 'patch_size': 1, 'dropout': False, 'learning_rate': 0.0001, 'epoch': 100, 'batch_size': 100, 'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x000001AC80528208>, 'supervision': 'full', 'center_pixel': True}
Printing the network's information
print("Network :")
with torch.no_grad():
    for input, _ in train_loader:
        break
    # summary(model.to(hyperparams['device']), input.size()[1:], device=hyperparams['device'])
    summary(model.to(hyperparams['device']), input.size()[1:])
    # -------------------------------------------
    # my own addition (not in the project code)
    print('begin')
    print('input: ',input)
    print('input.shape: ',input.shape)
    print('input.size()[1:]: ',input.size()[1:])
    os.system('pause')
    # --------------------------------------------

This prints the network's information, including:

  • Layer (type)
  • Output Shape
  • Param # and its details
  • Estimated Total Size
  • ……

This uses the summary() function from the torchsummary package, whose signature is:

summary(your_model, input_size=(channels, H, W))
  • your_model: the model to summarize.
  • input_size: the shape of one input sample, ordered channels × H × W.

The idea here:

First, everything runs under with torch.no_grad(): (this turns off gradient tracking; only a forward pass is needed to size the layers, so there is no point building the autograd graph).

Then for input, _ in train_loader: grabs one input batch from train_loader, breaking out of the loop on the first iteration. In short, train_loader is an iterable: each pass over it yields one batch of samples.

Finally, summary() is called to print the model information: summary(model.to(hyperparams['device']), input.size()[1:]).

To make this concrete, I wrote a snippet that prints the relevant information:

# --------------------------------------------
print('begin')
print('input: ',input)
print('input.shape: ',input.shape)
print('input.size()[1:]: ',input.size()[1:])
os.system('pause')
# --------------------------------------------

It prints:

Network :
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 2048]         212,992
            Linear-2                 [-1, 4096]       8,392,704
            Linear-3                 [-1, 2048]       8,390,656
            Linear-4                   [-1, 10]          20,490
================================================================
Total params: 17,016,842
Trainable params: 17,016,842
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 64.91
Estimated Total Size (MB): 64.98
----------------------------------------------------------------
begin
input:  tensor([[0.1695, 0.1893, 0.1727,  ..., 0.1466, 0.1446, 0.1451],
        [0.0589, 0.0333, 0.0351,  ..., 0.4984, 0.5005, 0.5049],
        [0.1375, 0.1199, 0.0784,  ..., 0.2763, 0.2820, 0.2906],
        ...,
        [0.0237, 0.0671, 0.0636,  ..., 0.4619, 0.4638, 0.4703],
        [0.1209, 0.0480, 0.0318,  ..., 0.4645, 0.4672, 0.4769],
        [0.0845, 0.0955, 0.1168,  ..., 0.2921, 0.2962, 0.3014]])
input.shape:  torch.Size([100, 103])
input.size()[1:]:  torch.Size([103])

A note on why input's shape is torch.Size([100, 103]): data.DataLoader() packs the dataset interface's output into Tensors of batch_size samples each, and batch_size here is 100, hence torch.Size([100, 103]).

One lingering question: summary()'s documented input is (channels, H, W), yet the script passes input.size()[1:], which prints as torch.Size([103]), not a (channels, H, W) triple. In fact torchsummary only needs the shape of a single sample with the batch dimension stripped off, which is exactly what input.size()[1:] is; for this four-linear-layer network each sample is effectively 1 × 1 × 103, so the leading size-1 dimensions can simply be omitted. (My guess at the time: the 100 in front is the batch_size and the 103 is the channel count.)

My verification process and results are written up in: summary(): notes from a baffling debug session; I won't repeat them here.

Whether to load a CHECKPOINT
if CHECKPOINT is not None:
    model.load_state_dict(torch.load(CHECKPOINT))

The documentation for load_state_dict() reads as follows (note: this quote is from the optimizer API; the call above is nn.Module.load_state_dict, which loads the module's parameters and buffers):

load_state_dict(state_dict)[SOURCE]

Loads the optimizer state.

  • Parameters

    state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

And state_dict():

state_dict()[SOURCE]

Returns the state of the optimizer as a dict.

It contains two entries:

  • state - a dict holding current optimization state. Its content differs between optimizer classes.

  • param_groups - a dict containing all parameter groups

What this code does: load previously saved model parameters, where "parameters" means the model's trainable parameters.

Supporting evidence: summary() printed:

----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 64.91
Estimated Total Size (MB): 64.98
----------------------------------------------------------------

and checking the saved checkpoint (C:\Users\73416\checkpoints\baseline\OwnData), its size is 65.0 MB (68,191,571 bytes), an almost perfect match for the 64.91 MB of parameters.

[Figure: screenshot of the checkpoint file's size; the image link in the original post is broken.]

But since CHECKPOINT is normally None, this code usually does not run.

The training process
try:
    train(model, optimizer, loss, train_loader, hyperparams['epoch'],
          scheduler=hyperparams['scheduler'], device=hyperparams['device'],
          supervision=hyperparams['supervision'], val_loader=val_loader,
          display=viz)
except KeyboardInterrupt:
    # Allow the user to stop the training
    pass

The core of this part is the call:

train(model, optimizer, loss, train_loader, hyperparams['epoch'],
          scheduler=hyperparams['scheduler'], device=hyperparams['device'],
          supervision=hyperparams['supervision'], val_loader=val_loader,
          display=viz)

For the training details, see the train() function in models.py.

The try-except just makes the program more robust: the user can stop training early with Ctrl-C (KeyboardInterrupt) and the script carries on to evaluation.


Testing on the test set and obtaining the results

probabilities = test(model, img, hyperparams)
prediction = np.argmax(probabilities, axis=-1)

This obtains the test result over the image img.

Note that test() returns an array of shape W × H × n_classes; the last dimension, n_classes, holds one score per class.

np.argmax() "returns the indices of the maximum values along an axis". The return value is an index, but the index directly names the class: whichever position holds the largest value is the predicted class.

Also note np.argmax()'s return value:

Returns:

index_array : ndarray of ints

Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.

In short: the specified axis is removed, the maximum along it is found, and the index of that element is returned.

A small demo of these two lines:

# the output of one batch (100 samples) after a pass through the network
output = np.array([[-3.4172e+00,  4.7474e-01,  3.6548e+00,  9.0230e-01,  7.8329e-01,
         -1.0330e+00,  2.2331e+00, -6.2564e-01,  4.0340e-01, -3.0324e+00],
        [-3.1507e+00,  2.2456e+00,  2.9226e-01,  1.8920e+00, -8.8580e-01,
          7.2265e-01,  1.0055e+00, -4.8283e-02,  1.5603e+00, -2.6959e+00],
        [-5.2190e+00,  3.3996e+00, -1.6052e+00,  2.5934e+00, -1.1674e+00,
          5.8359e+00,  4.2171e-01,  2.4511e-01,  2.2210e+00, -3.8763e+00],
        [-2.6106e+00, -7.5499e-01,  4.3269e+00, -1.6255e-01,  2.8252e+00,
         -1.4511e+00,  1.7670e+00, -8.3349e-01, -5.4468e-01, -2.2654e+00],
        [-3.0543e+00,  9.9883e-01,  2.3457e+00,  1.1807e+00,  6.2201e-02,
         -6.4643e-01,  1.7674e+00, -4.1443e-01,  7.6956e-01, -2.7224e+00],
        [-2.3868e+00, -4.9510e-01,  3.7331e+00, -8.3138e-03,  2.1344e+00,
         -1.2079e+00,  1.6056e+00, -6.8800e-01, -3.8621e-01, -2.0586e+00],
        [-3.2416e+00, -4.5884e-02,  4.2205e+00,  5.1736e-01,  1.4380e+00,
         -1.2833e+00,  2.2156e+00, -7.2556e-01,  2.3806e-02, -2.8677e+00],
        [-2.9053e+00,  1.6935e-01,  3.4237e+00,  6.0061e-01,  1.0369e+00,
         -9.5306e-01,  1.9077e+00, -5.9553e-01,  1.6974e-01, -2.5703e+00],
        [-7.1136e+00,  4.7759e+00, -2.2519e+00,  3.6735e+00, -1.7759e+00,
          7.8180e+00,  6.6654e-01,  3.6286e-01,  3.1942e+00, -5.3241e+00],
        [-3.0338e+00, -6.3117e-02,  3.9705e+00,  4.5521e-01,  1.3632e+00,
         -1.1764e+00,  2.0957e+00, -6.7216e-01,  1.7762e-02, -2.6780e+00],
        [-2.4539e+00,  1.5082e+00,  6.6919e-01,  1.3515e+00, -4.8787e-01,
          2.2641e-01,  9.3917e-01, -1.2048e-01,  1.0607e+00, -2.1299e+00],
        [-3.2204e+00,  1.9349e+00,  9.7008e-01,  1.7553e+00, -6.2425e-01,
          2.0920e-01,  1.2744e+00, -1.7639e-01,  1.3713e+00, -2.7858e+00],
        [-2.4182e+00, -8.2613e-01,  4.0879e+00, -2.6228e-01,  2.9733e+00,
         -1.4046e+00,  1.5851e+00, -8.0221e-01, -5.8830e-01, -2.0608e+00],
        [-3.2013e+00,  1.9282e+00,  9.4064e-01,  1.7319e+00, -5.9678e-01,
          2.6542e-01,  1.2547e+00, -1.6712e-01,  1.3512e+00, -2.7726e+00],
        [-2.5719e+00, -5.3754e-01,  4.0272e+00, -4.3249e-03,  2.3080e+00,
         -1.3214e+00,  1.7518e+00, -7.5447e-01, -4.1301e-01, -2.2544e+00],
        [-3.1710e+00,  1.8300e+00,  1.0955e+00,  1.7135e+00, -5.8126e-01,
          8.5049e-02,  1.3449e+00, -2.0962e-01,  1.3227e+00, -2.7708e+00],
        [-2.1461e+00,  1.4345e+00,  3.2470e-01,  1.1824e+00, -5.0001e-01,
          5.2405e-01,  7.0300e-01, -6.3770e-02,  9.5943e-01, -1.8195e+00],
        [-2.4559e+00,  1.6129e+00,  4.1325e-01,  1.3720e+00, -5.4098e-01,
          5.1285e-01,  8.0745e-01, -5.2784e-02,  1.0950e+00, -2.0901e+00],
        [-2.5694e+00,  1.3131e+00,  1.1399e+00,  1.2092e+00, -2.4525e-01,
          6.9929e-02,  1.0451e+00, -1.8160e-01,  8.5916e-01, -2.2093e+00],
        [-2.0823e+00,  1.3539e+00,  4.1539e-01,  1.1599e+00, -4.1376e-01,
          3.9781e-01,  6.8572e-01, -6.6757e-02,  9.0060e-01, -1.7785e+00],
        [-2.2904e+00,  1.0222e-01,  2.7664e+00,  4.1632e-01,  9.4062e-01,
         -8.1115e-01,  1.3992e+00, -4.5315e-01,  4.9409e-02, -1.9998e+00],
        [-2.4030e+00, -1.2090e+00,  3.9664e+00, -6.5414e-01,  4.6159e+00,
         -1.4920e+00,  1.2595e+00, -8.8912e-01, -8.5802e-01, -1.9434e+00],
        [-3.0300e+00, -7.2083e-01,  4.8153e+00, -9.8287e-02,  2.8917e+00,
         -1.5489e+00,  2.0796e+00, -8.7492e-01, -5.1128e-01, -2.5959e+00],
        [-2.3685e+00,  1.6391e+00,  2.4802e-01,  1.3489e+00, -5.3888e-01,
          6.4715e-01,  6.5692e-01, -4.1977e-02,  1.0578e+00, -2.0024e+00],
        [-1.9444e+00, -5.2009e-01,  3.1849e+00, -9.3790e-02,  1.9984e+00,
         -1.0692e+00,  1.3111e+00, -6.1101e-01, -3.8485e-01, -1.6877e+00],
        [-3.0943e+00,  2.3930e+00, -3.5446e-01,  1.8568e+00, -9.6759e-01,
          1.6489e+00,  5.9468e-01,  7.7725e-02,  1.5249e+00, -2.5396e+00],
        [-2.6556e+00, -3.6977e-01,  3.9245e+00,  1.4978e-01,  1.8240e+00,
         -1.2115e+00,  1.8738e+00, -6.9119e-01, -2.3623e-01, -2.3246e+00],
        [-3.1395e+00,  4.1125e-02,  3.9727e+00,  5.3787e-01,  1.2732e+00,
         -1.1591e+00,  2.1246e+00, -6.8051e-01,  6.9807e-02, -2.7764e+00],
        [-2.9746e+00,  1.8802e-01,  3.5310e+00,  6.3683e-01,  9.6213e-01,
         -1.0437e+00,  2.0196e+00, -6.0621e-01,  2.0326e-01, -2.6465e+00],
        [-2.8328e+00,  1.6977e+00,  8.0364e-01,  1.4708e+00, -4.7245e-01,
          3.5556e-01,  1.0646e+00, -1.3904e-01,  1.1448e+00, -2.4328e+00],
        [-3.1625e+00,  2.0841e+00,  6.1937e-01,  1.7831e+00, -7.0946e-01,
          5.2520e-01,  1.1075e+00, -1.1551e-01,  1.4229e+00, -2.7167e+00],
        [-4.0167e+00,  1.7574e+00,  2.2763e+00,  1.7559e+00, -1.8198e-01,
         -1.7509e-01,  1.9423e+00, -4.0833e-01,  1.2436e+00, -3.5162e+00],
        [-2.4236e+00, -7.8535e-01,  4.0795e+00, -2.2951e-01,  2.8136e+00,
         -1.3446e+00,  1.6017e+00, -7.9643e-01, -5.4431e-01, -2.0819e+00],
        [-2.2582e+00, -4.6274e-01,  3.5520e+00,  1.3249e-02,  1.8428e+00,
         -1.1697e+00,  1.5907e+00, -6.3526e-01, -3.2695e-01, -1.9612e+00],
        [-3.4212e+00,  2.2788e+00,  6.3480e-01,  1.9573e+00, -7.8700e-01,
          5.6299e-01,  1.1635e+00, -1.1935e-01,  1.5660e+00, -2.9292e+00],
        [-2.5760e+00, -3.0569e-01,  3.7433e+00,  1.4583e-01,  1.8048e+00,
         -1.1692e+00,  1.7010e+00, -6.4062e-01, -2.5369e-01, -2.2543e+00],
        [-3.0344e+00, -1.4159e-01,  4.1256e+00,  3.8586e-01,  1.5449e+00,
         -1.2349e+00,  2.0604e+00, -7.1976e-01, -7.2623e-02, -2.6631e+00],
        [-2.8772e+00, -1.1097e+00,  4.8425e+00, -4.5681e-01,  4.0864e+00,
         -1.6900e+00,  1.7676e+00, -9.4947e-01, -8.0563e-01, -2.4084e+00],
        [-3.0412e+00, -1.5252e+00,  4.8663e+00, -8.4481e-01,  5.9819e+00,
         -1.8587e+00,  1.5335e+00, -1.0996e+00, -1.0606e+00, -2.4177e+00],
        [-2.2495e+00, -6.2242e-01,  3.6935e+00, -1.0642e-01,  2.3495e+00,
         -1.2321e+00,  1.4774e+00, -7.2164e-01, -4.5571e-01, -1.9474e+00],
        [-2.1554e+00,  1.3898e-01,  2.5519e+00,  4.1213e-01,  8.7232e-01,
         -7.1161e-01,  1.2969e+00, -4.2464e-01,  5.3723e-02, -1.8957e+00],
        [-2.9576e+00, -1.3114e+00,  5.1378e+00, -5.9229e-01,  4.5756e+00,
         -1.7936e+00,  1.7785e+00, -1.0670e+00, -9.2556e-01, -2.4614e+00],
        [-3.2555e+00,  1.5876e+00,  1.6341e+00,  1.5986e+00, -3.8576e-01,
         -2.5096e-01,  1.5805e+00, -2.9209e-01,  1.1768e+00, -2.8661e+00],
        [-3.4841e+00,  2.1306e+00,  9.4273e-01,  1.9013e+00, -6.5666e-01,
          3.6733e-01,  1.3094e+00, -1.7591e-01,  1.4909e+00, -3.0080e+00],
        [-2.6257e+00, -3.5297e-01,  3.8823e+00,  1.4276e-01,  1.8069e+00,
         -1.1846e+00,  1.8600e+00, -6.9508e-01, -2.3632e-01, -2.3041e+00],
        [-5.7164e+00,  4.1493e+00,  1.6602e-01,  3.4875e+00, -1.6735e+00,
          1.7104e+00,  1.6943e+00, -1.0589e-02,  2.8905e+00, -4.8421e+00],
        [-2.2660e+00, -1.0308e+00,  3.8669e+00, -4.9326e-01,  3.7366e+00,
         -1.4057e+00,  1.3182e+00, -8.2074e-01, -7.1355e-01, -1.8841e+00],
        [-2.5752e+00, -5.6564e-01,  4.0870e+00, -1.0507e-02,  2.2154e+00,
         -1.4014e+00,  1.8387e+00, -7.4375e-01, -3.7901e-01, -2.2514e+00],
        [-2.9679e+00, -1.0837e+00,  4.9977e+00, -3.9962e-01,  3.9723e+00,
         -1.7014e+00,  1.8724e+00, -9.8039e-01, -7.8818e-01, -2.5105e+00],
        [-3.2278e+00,  2.0735e+00,  7.2812e-01,  1.7949e+00, -6.9313e-01,
          4.3198e-01,  1.1599e+00, -1.4001e-01,  1.4349e+00, -2.7687e+00],
        [-2.9973e+00, -1.3724e+00,  5.0783e+00, -6.5903e-01,  4.9946e+00,
         -1.8330e+00,  1.6992e+00, -1.0746e+00, -9.6564e-01, -2.4660e+00],
        [-2.4994e+00,  1.7602e+00,  1.6353e-01,  1.4317e+00, -6.5011e-01,
          7.9777e-01,  7.3128e-01, -2.0274e-02,  1.1598e+00, -2.1139e+00],
        [-2.1210e+00, -4.5631e-01,  3.3404e+00, -8.2309e-04,  1.8792e+00,
         -1.0719e+00,  1.4665e+00, -6.0914e-01, -3.1663e-01, -1.8488e+00],
        [-2.0178e+00,  7.5944e-01,  1.3621e+00,  8.4683e-01, -8.1003e-02,
         -3.3616e-01,  1.1177e+00, -2.3672e-01,  5.8855e-01, -1.7959e+00],
        [-4.0473e+00, -1.5565e+00,  6.4261e+00, -6.8970e-01,  6.1694e+00,
         -2.2335e+00,  2.2404e+00, -1.3308e+00, -1.1467e+00, -3.3744e+00],
        [-3.7822e+00, -1.1448e+00,  6.0267e+00, -3.5539e-01,  4.6618e+00,
         -1.9564e+00,  2.2661e+00, -1.1758e+00, -8.8636e-01, -3.1916e+00],
        [-3.1418e-01,  1.5967e-01,  4.5192e-02,  1.1821e-01,  5.7125e-02,
          1.3193e-01, -1.1263e-02, -2.5072e-02,  3.2257e-02, -1.6350e-01],
        [-2.5006e+00, -3.7688e-01,  3.7182e+00,  9.5787e-02,  1.8344e+00,
         -1.1744e+00,  1.7026e+00, -6.4518e-01, -2.8520e-01, -2.1683e+00],
        [-2.7033e+00, -1.2633e-01,  3.6465e+00,  3.3531e-01,  1.3966e+00,
         -1.0790e+00,  1.8082e+00, -6.1470e-01, -6.8538e-02, -2.3573e+00],
        [-3.1769e+00,  2.5477e-01,  3.6971e+00,  7.0723e-01,  9.9888e-01,
         -1.0660e+00,  2.1025e+00, -6.3563e-01,  2.2431e-01, -2.8429e+00],
        [-3.1090e+00, -2.1315e-01,  4.2975e+00,  3.8161e-01,  1.6554e+00,
         -1.3550e+00,  2.1961e+00, -7.5499e-01, -8.2681e-02, -2.7584e+00],
        [-3.3324e+00,  1.0784e+00,  2.5192e+00,  1.2437e+00,  1.9829e-01,
         -4.5998e-01,  1.7754e+00, -4.3879e-01,  7.8876e-01, -2.9155e+00],
        [-1.9215e+00, -2.6999e-01,  2.8816e+00,  1.0510e-01,  1.3734e+00,
         -8.9469e-01,  1.3081e+00, -5.1852e-01, -2.0649e-01, -1.6914e+00],
        [-2.7723e+00, -3.9382e-02,  3.5956e+00,  4.6801e-01,  1.2204e+00,
         -1.1168e+00,  1.9174e+00, -6.2143e-01,  3.5122e-02, -2.4784e+00],
        [-3.1408e+00,  8.7559e-01,  2.7258e+00,  1.1185e+00,  1.7175e-01,
         -8.6532e-01,  1.9308e+00, -4.5202e-01,  7.0314e-01, -2.8272e+00],
        [-2.8434e+00, -1.1880e+00,  4.9346e+00, -4.9937e-01,  4.2047e+00,
         -1.7423e+00,  1.7886e+00, -9.9772e-01, -8.4582e-01, -2.3973e+00],
        [-2.6930e+00,  3.8193e-01,  2.8897e+00,  6.8692e-01,  6.8390e-01,
         -7.9716e-01,  1.6867e+00, -4.8306e-01,  2.7250e-01, -2.3930e+00],
        [-2.7698e+00,  1.6703e+00,  8.2543e-01,  1.4979e+00, -5.2919e-01,
          1.8657e-01,  1.0821e+00, -1.5562e-01,  1.1658e+00, -2.3999e+00],
        [-2.9499e+00,  2.0103e+00,  4.1318e-01,  1.6914e+00, -7.0104e-01,
          6.3629e-01,  9.3164e-01, -5.5044e-02,  1.3423e+00, -2.5152e+00],
        [-2.9679e+00,  8.5536e-01,  2.4318e+00,  1.0781e+00,  1.7901e-01,
         -6.4314e-01,  1.7377e+00, -4.0450e-01,  6.6847e-01, -2.6346e+00],
        [-2.1822e+00, -2.1775e-01,  3.0966e+00,  1.6233e-01,  1.4289e+00,
         -9.0650e-01,  1.5137e+00, -5.4407e-01, -1.4786e-01, -1.9107e+00],
        [-2.5802e+00,  1.7874e+00,  2.8635e-01,  1.4753e+00, -6.2779e-01,
          6.7400e-01,  7.6219e-01, -2.9079e-02,  1.1687e+00, -2.1778e+00],
        [-3.1480e+00,  1.5729e+00,  1.4711e+00,  1.5068e+00, -3.1229e-01,
         -2.4183e-02,  1.3990e+00, -2.6736e-01,  1.1071e+00, -2.7398e+00],
        [-2.4344e+00, -6.0649e-01,  3.9408e+00, -7.0979e-02,  2.3058e+00,
         -1.2803e+00,  1.6756e+00, -7.3830e-01, -4.2828e-01, -2.1105e+00],
        [-2.3972e+00, -8.4519e-01,  4.0643e+00, -3.0006e-01,  3.0803e+00,
         -1.3756e+00,  1.5393e+00, -8.0675e-01, -6.0961e-01, -2.0428e+00],
        [-3.2675e+00,  2.1267e+00,  6.8864e-01,  1.8853e+00, -7.4988e-01,
          4.1254e-01,  1.1888e+00, -1.2328e-01,  1.5016e+00, -2.8328e+00],
        [-2.2283e+00, -5.6476e-01,  3.6542e+00, -7.9373e-02,  2.1571e+00,
         -1.1448e+00,  1.5443e+00, -6.7334e-01, -3.8690e-01, -1.9310e+00],
        [-2.5529e+00, -1.6107e-01,  3.5172e+00,  2.6857e-01,  1.4543e+00,
         -1.0681e+00,  1.7158e+00, -5.9504e-01, -1.2687e-01, -2.2421e+00],
        [-2.9471e+00,  1.3661e+00,  1.5572e+00,  1.3502e+00, -1.8738e-01,
         -1.0739e-01,  1.3368e+00, -2.8337e-01,  9.6195e-01, -2.5632e+00],
        [-4.0490e+00,  2.6916e+00, -1.2811e+00,  2.0386e+00, -9.5085e-01,
          4.5006e+00,  2.7572e-01,  1.7540e-01,  1.7477e+00, -3.0254e+00],
        [-2.2811e+00, -6.6988e-01,  3.8223e+00, -1.5893e-01,  2.4005e+00,
         -1.2246e+00,  1.5716e+00, -7.0980e-01, -4.5187e-01, -1.9646e+00],
        [-2.6464e+00,  1.8344e+00,  2.7921e-01,  1.5202e+00, -6.4925e-01,
          7.0848e-01,  8.1471e-01, -3.6928e-02,  1.2366e+00, -2.2512e+00],
        [-2.2836e+00,  1.5426e+00,  3.0953e-01,  1.2679e+00, -5.0542e-01,
          5.8676e-01,  6.8070e-01, -4.7086e-02,  9.9322e-01, -1.9262e+00],
        [-2.1354e+00,  2.6699e-02,  2.6765e+00,  3.7227e-01,  9.4378e-01,
         -7.5686e-01,  1.3870e+00, -4.7164e-01,  2.3713e-02, -1.8903e+00],
        [-3.2814e+00,  1.3032e-01,  3.9823e+00,  6.4409e-01,  1.2050e+00,
         -1.1468e+00,  2.1940e+00, -6.6835e-01,  1.5291e-01, -2.9050e+00],
        [-2.5112e+00,  1.5445e+00,  7.0261e-01,  1.3935e+00, -5.2214e-01,
          1.9404e-01,  9.8879e-01, -1.4678e-01,  1.1036e+00, -2.1830e+00],
        [-2.3466e+00, -7.3760e-01,  3.8941e+00, -1.9764e-01,  2.7026e+00,
         -1.3215e+00,  1.5114e+00, -7.5438e-01, -5.3157e-01, -1.9959e+00],
        [-2.9707e+00, -1.3565e+00,  5.1351e+00, -6.3653e-01,  4.8045e+00,
         -1.8174e+00,  1.7231e+00, -1.0866e+00, -9.6608e-01, -2.4615e+00],
        [-4.7498e-01,  1.2949e-01,  3.7521e-01,  1.2195e-01,  1.4310e-01,
         -2.9742e-02,  1.8941e-01, -6.1047e-02,  4.5253e-02, -3.7414e-01],
        [-2.6389e+00, -2.8462e-01,  3.7923e+00,  2.1464e-01,  1.6867e+00,
         -1.1546e+00,  1.8157e+00, -6.7523e-01, -1.9182e-01, -2.3107e+00],
        [-3.0642e+00,  1.6007e-02,  3.9244e+00,  5.0872e-01,  1.3313e+00,
         -1.1347e+00,  2.0441e+00, -6.8004e-01,  3.6863e-02, -2.7183e+00],
        [-2.4678e+00, -4.8809e-01,  3.8563e+00,  2.0195e-02,  2.0485e+00,
         -1.2082e+00,  1.7291e+00, -7.0810e-01, -3.3191e-01, -2.1626e+00],
        [-3.5040e+00,  2.4570e+00,  3.6188e-01,  2.0714e+00, -9.0356e-01,
          8.4894e-01,  1.0637e+00, -6.7259e-02,  1.6715e+00, -2.9999e+00],
        [-2.2794e+00, -9.1547e-01,  3.9683e+00, -3.7716e-01,  3.1500e+00,
         -1.3683e+00,  1.4711e+00, -7.7084e-01, -6.3944e-01, -1.8953e+00],
        [-1.7820e+00,  1.2340e+00,  1.4878e-01,  9.9099e-01, -4.0080e-01,
          5.7725e-01,  4.6312e-01, -2.5199e-02,  7.7039e-01, -1.4899e+00],
        [-2.1745e+00, -1.7468e-01,  3.0339e+00,  1.9426e-01,  1.3234e+00,
         -8.8056e-01,  1.4736e+00, -5.1925e-01, -1.3082e-01, -1.8905e+00],
        [-2.8420e+00,  1.6299e+00,  9.3658e-01,  1.4826e+00, -4.5678e-01,
          1.9835e-01,  1.1390e+00, -1.4821e-01,  1.1446e+00, -2.4459e+00],
        [-3.5596e+00,  2.2245e+00,  8.8818e-01,  1.9778e+00, -7.1741e-01,
          4.0682e-01,  1.3327e+00, -1.5880e-01,  1.5614e+00, -3.0715e+00],
        [-2.8122e+00, -9.3059e-01,  4.7239e+00, -2.9683e-01,  3.4055e+00,
         -1.5805e+00,  1.8424e+00, -9.1344e-01, -6.7323e-01, -2.4044e+00],
        [-3.1333e+00,  6.4306e-01,  2.9880e+00,  9.6653e-01,  4.6051e-01,
         -8.1380e-01,  2.0050e+00, -4.9463e-01,  5.4800e-01, -2.7895e+00]])
print(output)
print(output.shape)
# (100, 10)
prediction = np.argmax(output, axis=-1)
print(prediction)
# [2 1 5 2 2 2 2 2 5 2 1 1 2 1 2 1 1 1 1 1 2 4 2 1 2 1 2 2 2 1 1 2 2 2 1 2 2
#  2 4 2 2 2 2 1 2 1 2 2 2 1 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 1 1 2 2 1 1 2
#  2 1 2 2 2 5 2 1 1 2 2 1 2 2 2 2 2 2 1 2 1 2 1 1 2 2]
print(prediction.shape)
# (100,)

As shown, we end up with the class each sample is predicted to belong to.


Computing the run's metrics on the test set

run_results = metrics(prediction, test_gt, ignored_labels=hyperparams['ignored_labels'], n_classes=N_CLASSES)

The run results include:

  • accuracy
  • F1 score by class
  • confusion matrix

See the metrics() function in utils.py for details; a rough sketch of the same computation follows.
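The project computes these in its own metrics(); the sketch below reproduces the same quantities with scikit-learn under my assumptions (ignored labels masked out, accuracy as a percentage) and is not the project's code:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

def metrics_sketch(prediction, target, ignored_labels=(0,), n_classes=10):
    # keep only pixels whose ground truth is not an ignored label
    mask = ~np.isin(target, ignored_labels)
    pred, tgt = prediction[mask], target[mask]
    labels = list(range(n_classes))
    return {
        'accuracy': 100 * accuracy_score(tgt, pred),
        'F1 scores': f1_score(tgt, pred, labels=labels, average=None),
        'confusion matrix': confusion_matrix(tgt, pred, labels=labels),
    }
```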


Getting the test-set predicted labels

mask = np.zeros(gt.shape, dtype='bool')
for l in IGNORED_LABELS:
    mask[gt == l] = True
prediction[mask] = 0

This picks up from the "testing on the test set" part; it operates on the prediction result.

A boolean mask is built so that every position whose ground-truth class is an ignored_label has its prediction zeroed out (the ignored_label class corresponds to 0).

The behavior is demonstrated by the small demo below:

IGNORED_LABELS = [0]
gt = np.array([[0,1,0],[2,3,2],[0,3,0]])
prediction = np.array([[10,1,10],[2,3,2],[10,3,10]])
print('former prediction:\n',prediction)
# former prediction:
#  [[10  1 10]
#  [ 2  3  2]
#  [10  3 10]]
mask = np.zeros(gt.shape, dtype='bool')
for l in IGNORED_LABELS:
    mask[gt == l] = True
prediction[mask] = 0

print('gt:\n',gt)
# gt:
#  [[0 1 0]
#  [2 3 2]
#  [0 3 0]]
print('mask:\n',mask)
# mask:
#  [[ True False  True]
#  [False False False]
#  [ True False  True]]
print('processed prediction:\n',prediction)
# processed prediction:
#  [[0 1 0]
#  [2 3 2]
#  [0 3 0]]

Visualizing the prediction against test_gt in visdom

color_prediction = convert_to_color(prediction)
display_predictions(color_prediction, viz, gt=convert_to_color(test_gt), caption="Prediction vs. test ground truth")

color_prediction = convert_to_color(prediction) uses the palette to turn prediction's classes into RGB colors.

display_predictions() then visualizes prediction and test_gt together in one figure.

[Figure: prediction vs. test ground truth shown in visdom; the image link in the original post is broken.]


Reporting the metrics

results.append(run_results)
show_results(run_results, viz, label_values=LABEL_VALUES)

results.append(run_results) adds run_results to results. With a single run (N_RUNS == 1), results holds just that one entry, so showing run_results directly gives the same output; with N_RUNS > 1 a single run_results no longer suffices.

show_results() visualizes the results through visdom.

At this point we are out of the N_RUNS loop, though there is usually only one run anyway.


Reporting the results over N_RUNS

if N_RUNS > 1:
    show_results(results, viz, label_values=LABEL_VALUES, agregated=True)

When N_RUNS > 1, this shows the aggregated results across runs.

P.S.

How hyperparams evolves

Initializing hyperparams:

hyperparams is initialized with hyperparams = vars(args), which converts the Namespace object:

Namespace(batch_size=None, class_balancing=False, cuda=0, dataset='PaviaU', download=None, epoch=None, flip_augmentation=False, folder='./Datasets/', lr=None, mixture_augmentation=False, model='nn', patch_size=None, radiation_augmentation=False, restore=None, runs=1, sampling_mode='random', test_set=None, test_stride=1, train_set=None, training_sample=0.1, with_exploration=False)

into a plain dictionary:

{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'restore': None, 'training_sample': 0.1, 'sampling_mode': 'random', 'train_set': None, 'test_set': None, 'epoch': None, 'patch_size': None, 'lr': None, 'class_balancing': False, 'batch_size': None, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'download': None}

At this point only the four options given on the command line (cuda, dataset, model, training_sample) carry user-set values; the remaining hyperparameters (for data processing and for training the model) are still at their defaults in hyperparams.

Updating the data-processing hyperparameters:
# consume the parsed arguments
CUDA_DEVICE = get_device(args.cuda)
……
# Number of classes
N_CLASSES = len(LABEL_VALUES)
# Number of bands (last dimension of the image tensor)
N_BANDS = img.shape[-1]
……
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
……
hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS, 'device': CUDA_DEVICE})
hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None)     # rebuild the dict, dropping keys whose value is None

This updates the data-processing-related hyperparameters and removes from hyperparams every key-value pair whose value is None.

The update is done through the dict update() method.

The filtering is just a for loop plus an if test.

The resulting hyperparams:

hyperparams:  {'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0)}
Updating the model-related hyperparameters:
# Neural network
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
if CLASS_BALANCING:
    weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
    hyperparams['weights'] = torch.from_numpy(weights)
# Split train set in train/val
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
# Generate the dataset
train_dataset = HyperX(img, train_gt, **hyperparams)
train_loader = data.DataLoader(train_dataset,
                               batch_size=hyperparams['batch_size'],
                               #pin_memory=hyperparams['device'],
                               shuffle=True)
val_dataset = HyperX(img, val_gt, **hyperparams)
val_loader = data.DataLoader(val_dataset,
                             #pin_memory=hyperparams['device'],
                             batch_size=hyperparams['batch_size'])

This code merges the model's hyperparameters into hyperparams (I haven't dug into the details yet).

The final hyperparams:

{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0), 'weights': tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'), 'patch_size': 1, 'dropout': False, 'learning_rate': 0.0001, 'epoch': 100, 'batch_size': 100, 'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x...>, 'supervision': 'full', 'center_pixel': True}
