# -*- coding: utf-8 -*-
"""
DEEP LEARNING FOR HYPERSPECTRAL DATA.
This script allows the user to run several deep models (and SVM baselines)
against various hyperspectral datasets. It is designed to quickly benchmark
state-of-the-art CNNs on various public hyperspectral datasets.
This code is released under the GPLv3 license for non-commercial and research
purposes only.
For commercial use, please contact the authors.
"""
### Resolving compatibility:
# Python 2/3 compatibility
from __future__ import print_function
from __future__ import division
These imports resolve the compatibility issues between Python 2 and Python 3.
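As a concrete illustration (a minimal sketch, not project code), this is what the two imports change under Python 2:

```python
# Without the __future__ imports, Python 2 gives:
#   3 / 2  == 1     (integer division)
#   print "x"       (print is a statement, not a function)
# With them, Python 2 matches Python 3 behavior:
from __future__ import print_function, division
print(3 / 2)   # 1.5 on both Python 2 and Python 3
```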
# My own addition: adjust the module search path
import sys,os
curPath = os.path.abspath(os.path.dirname(__file__))
print(os.path.dirname(__file__))
# C:/Users/73416/PycharmProjects/HSIproject
print(curPath)
# C:\Users\73416\PycharmProjects\HSIproject
rootPath = os.path.split(curPath)[0]
print(os.path.split(curPath))
# ('C:\\Users\\73416\\PycharmProjects', 'HSIproject')
print(rootPath)
# C:\Users\73416\PycharmProjects
print(sys.path)
# ['C:\\Users\\73416\\PycharmProjects\\HSIproject',
# 'E:\\PyCharm 2018.3.4\\helpers\\pydev',
# 'C:\\Users\\73416\\PycharmProjects\\HSIproject',
# 'E:\\PyCharm 2018.3.4\\helpers\\third_party\\thriftpy',
# 'E:\\PyCharm 2018.3.4\\helpers\\pydev',
# 'E:\\Anaconda\\python37.zip',
# 'E:\\Anaconda\\DLLs',
# 'E:\\Anaconda\\lib',
# 'E:\\Anaconda',
# 'E:\\Anaconda\\lib\\site-packages',
# 'E:\\Anaconda\\lib\\site-packages\\win32',
# 'E:\\Anaconda\\lib\\site-packages\\win32\\lib',
# 'E:\\Anaconda\\lib\\site-packages\\Pythonwin',
# 'E:\\PyCharm 2018.3.4\\helpers\\pycharm_matplotlib_backend',
# 'E:\\Anaconda\\lib\\site-packages\\IPython\\extensions',
# 'C:\\Users\\73416\\PycharmProjects\\HSIproject',
# 'C:/Users/73416/PycharmProjects/HSIproject']
sys.path.append(rootPath)
sys.path.append('E:\\Anaconda\\lib\\site-packages\\')
print(sys.path)
# ['C:\\Users\\73416\\PycharmProjects\\HSIproject',
# 'E:\\PyCharm 2018.3.4\\helpers\\pydev',
# 'C:\\Users\\73416\\PycharmProjects\\HSIproject',
# 'E:\\PyCharm 2018.3.4\\helpers\\third_party\\thriftpy',
# 'E:\\PyCharm 2018.3.4\\helpers\\pydev',
# 'E:\\Anaconda\\python37.zip',
# 'E:\\Anaconda\\DLLs',
# 'E:\\Anaconda\\lib',
# 'E:\\Anaconda',
# 'E:\\Anaconda\\lib\\site-packages',
# 'E:\\Anaconda\\lib\\site-packages\\win32',
# 'E:\\Anaconda\\lib\\site-packages\\win32\\lib',
# 'E:\\Anaconda\\lib\\site-packages\\Pythonwin',
# 'E:\\PyCharm 2018.3.4\\helpers\\pycharm_matplotlib_backend',
# 'E:\\Anaconda\\lib\\site-packages\\IPython\\extensions',
# 'C:\\Users\\73416\\PycharmProjects\\HSIproject',
# 'C:/Users/73416/PycharmProjects/HSIproject',
# 'C:\\Users\\73416\\PycharmProjects',
# 'E:\\Anaconda\\lib\\site-packages\\']
# (end of my added path-modification code)
C:\Users\73416\PycharmProjects\HSIproject
C:\Users\73416\PycharmProjects\HSIproject
('C:\Users\73416\PycharmProjects', 'HSIproject')
C:\Users\73416\PycharmProjects
['C:\Users\73416\PycharmProjects\HSIproject', 'E:\Python37\python37.zip', 'E:\Python37\DLLs', 'E:\Python37\lib', 'E:\Python37', 'E:\Python37\lib\site-packages', 'E:\Python37\lib\site-packages\win32', 'E:\Python37\lib\site-packages\win32\lib', 'E:\Python37\lib\site-packages\Pythonwin']
['C:\Users\73416\PycharmProjects\HSIproject', 'E:\Python37\python37.zip', 'E:\Python37\DLLs', 'E:\Python37\lib', 'E:\Python37', 'E:\Python37\lib\site-packages', 'E:\Python37\lib\site-packages\win32', 'E:\Python37\lib\site-packages\win32\lib', 'E:\Python37\lib\site-packages\Pythonwin', 'C:\Users\73416\PycharmProjects', 'E:\Anaconda\lib\site-packages\']
The commented output earlier was produced by running under PyCharm; the plain output just above is from running on the command line.
The correct way to run this project is from a command-line window.
Note that although neither run raises any library-import errors, the two runs use different package versions: one resolves packages under E:\Anaconda, the other under E:\Python37.
To import other .py files under the project path, rootPath (C:\Users\73416\PycharmProjects) has to be appended to the default search path sys.path.
The line sys.path.append('E:\\Anaconda\\lib\\site-packages\\') is there because the command-line search path has no torch library, so the copy of torch under the Anaconda path is used instead.
(Where a library gets installed depends on how you install it: with the default pip, it lands under E:\Python37; with Anaconda's pip or the conda command, it lands under E:\Anaconda.)
# Torch
import torch
import torch.utils.data as data
from torchsummary import summary
# Numpy, scipy, scikit-image, spectral
import numpy as np
import sklearn.svm
import sklearn.model_selection
import sklearn.linear_model     # needed by the 'SGD' branch below
import sklearn.neighbors        # needed by the 'nearest' branch below
import sklearn.preprocessing    # StandardScaler in the 'SGD' branch
import sklearn.utils            # sklearn.utils.shuffle
from skimage import io
# Visualization
import seaborn as sns
import visdom
import os
# Import custom modules
from utils import metrics, convert_to_color_, convert_from_color_,\
display_dataset, display_predictions, explore_spectrums, plot_spectrums,\
sample_gt, build_dataset, show_results, compute_imf_weights, get_device
from datasets import get_dataset, HyperX, open_file, DATASETS_CONFIG
from models import get_model, train, test, save_model
# Command-line argument parser
import argparse
A note on how the project files are organized (forgive me, I'm a beginner).
main.py usually holds the project's main program; reading it top to bottom roughly follows the execution flow.
utils.py usually holds self-written helper functions, which main.py calls.
datasets.py holds the functions, classes, and so on related to the datasets.
models.py holds the models and related code.
The overall idea is that main.py defines no functions itself and only calls them; the custom functions it needs are written, grouped by purpose, in the other .py files. This gives the project a clean structure.
# Collect dataset_names from DATASETS_CONFIG (a dictionary operation)
dataset_names = [v['name'] if 'name' in v.keys() else k for k, v in DATASETS_CONFIG.items()]
This part collects the dataset names, stored as a list.
The names include both the predefined datasets and any custom ones: DATASETS_CONFIG in datasets.py contains a try/except block that merges the custom-dataset dictionary CUSTOM_DATASETS_CONFIG into DATASETS_CONFIG via the dict update() method (sketched below).
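A minimal sketch of that merge, assuming the module name custom_datasets used by the project layout:

```python
# Sketch of the merge in datasets.py (module name assumed):
try:
    from custom_datasets import CUSTOM_DATASETS_CONFIG
    DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG)  # fold custom entries into the predefined dict
except ImportError:
    # no custom datasets module available; keep only the predefined datasets
    pass
```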
For details, see my other post: using the argparse command-line parser.
parser = argparse.ArgumentParser(description="Run deep learning experiments on"
" various hyperspectral datasets")
argparse.ArgumentParser() creates the parser; description is the help text.
parser.add_argument('--dataset', type=str, default=None, choices=dataset_names,
help="Dataset to use.")
--dataset: default None; choices: dataset_names; help: "Dataset to use."
parser.add_argument('--model', type=str, default=None,
help="Model to train. Available:\n"
"SVM (linear),\n "
"SVM_grid (grid search on linear, poly and RBF kernels), \n"
"baseline (fully connected NN), \n"
"hu (1D CNN), \n"
"hamida (3D CNN + 1D classifier), \n"
"lee (3D FCN), \n"
"chen (3D CNN), \n"
"li (3D CNN), \n"
"he (3D CNN), \n"
"luo (3D CNN), \n"
"sharma (2D CNN), \n"
"boulch (1D semi-supervised CNN), \n"
"liu (3D semi-supervised CNN), \n"
"mou (1D RNN)")
--model: default None; help: "Model to train. Available: ..." (the list shown above).
parser.add_argument('--folder', type=str, help="Folder where to store the "
"datasets (defaults to the current working directory).",
default="./Datasets/")
--folder: default "./Datasets/"; help: "Folder where to store the datasets (defaults to the current working directory)."
parser.add_argument('--cuda', type=int, default=-1,
help="Specify CUDA device (defaults to -1, which learns on CPU)")
--cuda: default -1; help: "Specify CUDA device (defaults to -1, which learns on CPU)"
parser.add_argument('--runs', type=int, default=1, help="Number of runs (default: 1)")
--runs: default 1; help: "Number of runs (default: 1)"
parser.add_argument('--restore', type=str, default=None,
help="Weights to use for initialization, e.g. a checkpoint")
--restore: default None; help: "Weights to use for initialization, e.g. a checkpoint"
group_dataset = parser.add_argument_group('Dataset')
This line groups the arguments added below into one group; in the --help output the effect looks roughly like this:
Dataset:
--training_sample ……
--sampling_mode ……
--train_set ……
--test_set ……
group_dataset.add_argument('--training_sample', type=float, default=10,
help="Percentage of samples to use for training (default: 10%%)")
--training_sample: default 10; help: "Percentage of samples to use for training (default: 10%)"
group_dataset.add_argument('--sampling_mode', type=str, help="Sampling mode"
" (random sampling or disjoint, default: random)",
default='random')
--sampling_mode: default 'random'; help: "Sampling mode (random sampling or disjoint, default: random)"
group_dataset.add_argument('--train_set', type=str, default=None,
                    help="Path to the train ground truth (optional, this "
                    "supersedes the --sampling_mode option)")
--train_set: default None; help: "Path to the train ground truth (optional, this supersedes the --sampling_mode option)"
Note additionally that a given train_set supersedes the sampling_mode argument.
group_dataset.add_argument('--test_set', type=str, default=None,
help="Path to the test set (optional, by default "
"the test_set is the entire ground truth minus the training)")
--test_set: default None; help: "Path to the test set (optional, by default the test_set is the entire ground truth minus the training)"
Note additionally that test_set defaults to the entire ground truth minus the training set.
group_train = parser.add_argument_group('Training')
This line groups the arguments added below into one group; in the --help output the effect looks roughly like this:
Training:
--epoch EPOCH ……
--patch_size PATCH_SIZE ……
--lr LR ……
--class_balancing ……
--batch_size BATCH_SIZE ……
--test_stride TEST_STRIDE ……
group_train.add_argument('--epoch', type=int, help="Training epochs (optional, if"
" absent will be set by the model)")
--epoch: no default; help: "Training epochs (optional, if absent will be set by the model)"
Note additionally that epoch is usually not specified explicitly; the value predefined in the model is used by default.
group_train.add_argument('--patch_size', type=int,
                    help="Size of the spatial neighbourhood (optional, if "
                    "absent will be set by the model)")
--patch_size: no default; help: "Size of the spatial neighbourhood (optional, if absent will be set by the model)"
Note additionally that patch_size is usually not specified explicitly; the value predefined in the model is used by default.
Here I want to add a few words about patch. The term comes from block-wise image processing; in practice patches are usually of a small size (3 × 3 or 5 × 5). From what I had seen so far (2019-09-01), one generally first picks a center position and then takes the patch centered on that position. So this is spatial context information!!!
group_train.add_argument('--lr', type=float,
help="Learning rate, set by the model if not specified.")
--lr: no default; help: "Learning rate, set by the model if not specified."
group_train.add_argument('--class_balancing', action='store_true',
help="Inverse median frequency class balancing (default = False)")
--class_balancing: action 'store_true' (so it defaults to False); help: "Inverse median frequency class balancing (default = False)" (inverse median frequency class balancing?)
group_train.add_argument('--batch_size', type=int,
                    help="Batch size (optional, if absent will be set by the model)")
--batch_size: no default; help: "Batch size (optional, if absent will be set by the model)"
Note additionally that batch_size is usually not specified explicitly; the value predefined in the model is used by default.
group_train.add_argument('--test_stride', type=int, default=1,
help="Sliding window step stride during inference (default = 1)")
--test_stride: default 1; help: "Sliding window step stride during inference (default = 1)", i.e. the step stride of the sliding window during inference.
group_da = parser.add_argument_group('Data augmentation')  # group definition (missing from the excerpt above)
group_da.add_argument('--flip_augmentation', action='store_true',
                    help="Random flips (if patch_size > 1)")
--flip_augmentation: action 'store_true' (defaults to False); help: "Random flips (if patch_size > 1)"
group_da.add_argument('--radiation_augmentation', action='store_true',
help="Random radiation noise (illumination)")
--radiation_augmentation: action 'store_true' (defaults to False); help: "Random radiation noise (illumination)"
group_da.add_argument('--mixture_augmentation', action='store_true',
help="Random mixes between spectra")
--mixture_augmentation: action 'store_true' (defaults to False); help: "Random mixes between spectra"
parser.add_argument('--with_exploration', action='store_true',
help="See data exploration visualization")
--with_exploration: action 'store_true' (defaults to False); help: "See data exploration visualization"
parser.add_argument('--download', type=str, default=None, nargs='+',
choices=dataset_names,
help="Download the specified datasets and quits.")
--download: default None; nargs: '+'; choices: dataset_names; help: "Download the specified datasets and quits."
About nargs: by default an ArgumentParser associates command-line arguments with actions one-to-one; specifying nargs associates several arguments with a single action (the example below illustrates this).
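A tiny standalone illustration of nargs='+' (a hypothetical parser, not the project's):

```python
import argparse

# nargs='+' collects one or more values into a list
p = argparse.ArgumentParser()
p.add_argument('--download', nargs='+')
args = p.parse_args(['--download', 'PaviaU', 'IndianPines'])
print(args.download)   # ['PaviaU', 'IndianPines']
```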
args = parser.parse_args()
ArgumentParser parses the arguments via the parse_args() method. This inspects the command line, converts each argument to the appropriate type, and then invokes the corresponding action. In most cases this means building a simple Namespace object from the attributes parsed off the command line. Note that the function returns an object, and the argument values are attributes of that object!
In scripts, parse_args() is usually called with no arguments.
Taking this command as an example:
python C:\Users\73416\PycharmProjects\HSIproject\main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda 0
The resulting args is as follows:
Namespace(batch_size=None, class_balancing=False, cuda=0, dataset='PaviaU', download=None, epoch=None, flip_augmentation=False, folder='./Datasets/', lr=None, mixture_augmentation=False, model='nn', patch_size=None, radiation_augmentation=False, restore=None, runs=1, sampling_mode='random', test_set=None, test_stride=1, train_set=None, training_sample=0.1, with_exploration=False)
As you can see, args is a Namespace object whose attributes are the hyperparameter key/value pairs (not a dictionary; vars(args) converts it into one, as the P.S. below shows).
The next part reads the attributes of args, assigns their values to the corresponding variables, and carries out the subsequent operations.
CUDA_DEVICE = get_device(args.cuda)
get_device() is a custom function defined in utils.py that selects the corresponding CUDA_DEVICE based on its input.
SAMPLE_PERCENTAGE = args.training_sample
This determines the proportion of training samples within the entire sample set.
FLIP_AUGMENTATION = args.flip_augmentation
RADIATION_AUGMENTATION = args.radiation_augmentation
MIXTURE_AUGMENTATION = args.mixture_augmentation
Depending on the boolean truth values of FLIP_AUGMENTATION, RADIATION_AUGMENTATION, and MIXTURE_AUGMENTATION, the corresponding data-augmentation methods are enabled or not.
All three default to False.
DATASET = args.dataset
Gets the dataset name and stores it in the variable DATASET.
MODEL = args.model
Gets the model name and stores it in the variable MODEL.
N_RUNS = args.runs
Gets the number of runs and stores it in the variable N_RUNS.
Spatial context size (number of neighbors in each spatial direction):
PATCH_SIZE = args.patch_size
As explained above, a patch comes from block-wise image processing and is usually small in practice (3 × 3 or 5 × 5): first pick a center position, then take the patch centered on it. So one patch is the central target sample plus its spatial context information (see the sketch below).
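A toy sketch of what taking a patch means (my own illustration; the real extraction lives in HyperX in datasets.py and also handles image borders and patch_size = 1):

```python
import numpy as np

def get_patch(img, x, y, patch_size=5):
    """Return the patch_size x patch_size spatial neighbourhood centered on (x, y)."""
    r = patch_size // 2
    return img[x - r:x + r + 1, y - r:y + r + 1, :]   # center pixel + spatial context, all bands

img = np.random.rand(610, 340, 103)                   # a PaviaU-sized cube
print(get_patch(img, 100, 100).shape)                 # (5, 5, 103)
```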
DATAVIZ = args.with_exploration
Depending on the boolean truth value of DATAVIZ, the spectral exploration visualization is enabled or not. Defaults to False.
FOLDER = args.folder
Target folder to store/download/load the datasets.
EPOCH = args.epoch
An epoch is one full pass over the entire dataset.
Sampling mode:
SAMPLING_MODE = args.sampling_mode
Pre-computed weights to restore:
CHECKPOINT = args.restore
Learning rate for the SGD:
LEARNING_RATE = args.lr
I am not sure what this one means yet; skipping for now.
CLASS_BALANCING = args.class_balancing
TRAIN_GT = args.train_set
Note that args does contain train_set; it just defaults to None, so TRAIN_GT is None unless --train_set is passed on the command line.
TEST_GT = args.test_set
Likewise, test_set exists in args with a default of None, so TEST_GT is None here.
Sliding window step stride during inference.
TEST_STRIDE = args.test_stride
if args.download is not None and len(args.download) > 0:
for dataset in args.download:
get_dataset(dataset, target_folder=FOLDER)
quit()
viz = visdom.Visdom(env=DATASET + ' ' + MODEL)  # set up the visdom environment
if not viz.check_connection():  # check the connection to the visdom server (it must be called; a bare viz.check_connection is always truthy)
    print("Visdom is not connected. Did you run 'python -m visdom.server' ?")
This sets up the visdom environment, with the environment name env set to DATASET + ' ' + MODEL.
It also checks the connection to the visdom server; if the connection fails, an error message is printed.
# Load the dataset
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
"""
img: 3D hyperspectral image (WxHxB)
gt: 2D int array of labels
label_values: list of class names
ignored_labels: list of int classes to ignore
rgb_bands: int tuple that correspond to red, green and blue bands
"""
get_dataset() is defined in datasets.py and covers downloading, reading, and preprocessing the datasets.
Given the specified DATASET and FOLDER, it returns img (W×H×Bands), gt (a 2D int array of labels), LABEL_VALUES (the list of class names), IGNORED_LABELS (the list of int classes to ignore), RGB_BANDS (an int tuple corresponding to the red, green, and blue bands), and the palette.
# Number of classes
N_CLASSES = len(LABEL_VALUES)
# Number of bands (last dimension of the image tensor)
N_BANDS = img.shape[-1]
The number of classes N_CLASSES is the length of LABEL_VALUES, where LABEL_VALUES is:
['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees', 'Painted metal sheets', 'Bare Soil', 'Bitumen', 'Self-Blocking Bricks', 'Shadows']
The number of bands N_BANDS is the last dimension of img, i.e. shape[-1].
# Parameters for the SVM grid search
SVM_GRID_PARAMS = [{'kernel': ['rbf'], 'gamma': [1e-1, 1e-2, 1e-3],
'C': [1, 10, 100, 1000]},
{'kernel': ['linear'], 'C': [0.1, 1, 10, 100, 1000]},
{'kernel': ['poly'], 'degree': [3], 'gamma': [1e-1, 1e-2, 1e-3]}]
Since I am not using SVM, I skip this part for now.
# import numpy as np
# import seaborn as sns
# palette = None
# LABEL_VALUES = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees', 'Painted metal sheets', # 'Bare Soil', 'Bitumen', 'Self-Blocking Bricks', 'Shadows']
if palette is None:  # color palette
# Generate color palette
palette = {0: (0, 0, 0)}
for k, color in enumerate(sns.color_palette("hls", len(LABEL_VALUES) - 1)):
palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8'))
invert_palette = {v: k for k, v in palette.items()}
# print(palette)
# # {0: (0, 0, 0), 1: (219, 94, 86), 2: (219, 183, 86), 3: (167, 219, 86), 4: (86, 219, 94), 5: (86, 219, 183), 6: (86, 167, 219), 7: (94, 86, 219), 8: (183, 86, 219), 9: (219, 86, 167)}
# print(invert_palette)
# # {(0, 0, 0): 0, (219, 94, 86): 1, (219, 183, 86): 2, (167, 219, 86): 3, (86, 219, 94): 4, (86, 219, 183): 5, (86, 167, 219): 6, (94, 86, 219): 7, (183, 86, 219): 8, (219, 86, 167): 9}
The commented-out lines are test code.
First, the counts. The code first sets the palette entry of class 0 (ignored_labels) directly to (0, 0, 0) via palette = {0: (0, 0, 0)}, as the background. It then generates a color for each of the remaining 9 classes (the non-ignored_labels ones), which is why the argument is len(LABEL_VALUES) - 1.
hls is a color space; it is a simple transformation of RGB values.
The overall flow of this code: when palette is None, run the palette-generation steps. First set the palette of class 0 (ignored_labels) directly to (0, 0, 0), as the background. Then enumerate(sns.color_palette("hls", len(LABEL_VALUES) - 1)) builds an indexed sequence of len(LABEL_VALUES) - 1 three-element tuples in the hls color space. A for loop then iterates over this sequence, and palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8')) assigns each class its palette value.
Afterwards, iterating over palette with the key/value pairs swapped yields invert_palette: invert_palette = {v: k for k, v in palette.items()}.
def convert_to_color(x):
return convert_to_color_(x, palette=palette)
def convert_from_color(x):
return convert_from_color_(x, palette=invert_palette)
x is an int 2D array of labels, i.e. a 2D matrix of labels.
palette and invert_palette were defined above.
These two functions convert between classes and RGB colors in both directions; at bottom it is simply the correspondence between a class and its display color (a sketch of the underlying helper follows).
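A minimal sketch of what the underlying helper presumably does (the real convert_to_color_ lives in utils.py; this is my assumption of its behavior, not the project's exact code):

```python
import numpy as np

def convert_to_color_sketch(arr_2d, palette):
    """Map a 2D int label array to a 3D uint8 RGB image using the palette."""
    arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)
    for label, color in palette.items():
        arr_3d[arr_2d == label] = color   # paint every pixel of this class
    return arr_3d
```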
# Instantiate the experiment based on predefined networks
hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS, 'device': CUDA_DEVICE})
hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None)  # rebuild the dict, keeping only the non-None entries
The class count N_CLASSES, the band count N_BANDS, and IGNORED_LABELS were obtained while loading the dataset (and in the steps right after); CUDA_DEVICE was read back in the argument-parsing section.
Here these four hyperparameters (N_CLASSES, N_BANDS, IGNORED_LABELS, and CUDA_DEVICE) are merged into the hyperparameter dict hyperparams via update().
Then a pass over hyperparams filters out every key/value pair whose value is None.
# Show the image and the ground truth
display_dataset(img, gt, RGB_BANDS, LABEL_VALUES, palette, viz)
color_gt = convert_to_color(gt)
This part shows the image and the ground truth, where different colors in the ground truth indicate different classes.
The img displayed in visdom:
[figure: visdom screenshot of the dataset image]
if DATAVIZ: # ???
# Data exploration : compute and show the mean spectrums
mean_spectrums = explore_spectrums(img, gt, LABEL_VALUES, viz,
ignored_labels=IGNORED_LABELS)
plot_spectrums(mean_spectrums, viz, title='Mean spectrum/class')
But DATAVIZ defaults to False, so this block is not executed by default. Skipping it for now.
results = []
Here the variable results, which will store the final results, is initialized to an empty list; later the append() method adds each run's results to results.
# run the experiment several times
for run in range(N_RUNS):
……
Run the experiment several times; the default is 1.
if TRAIN_GT is not None and TEST_GT is not None:
train_gt = open_file(TRAIN_GT)
test_gt = open_file(TEST_GT)
elif TRAIN_GT is not None:
train_gt = open_file(TRAIN_GT)
test_gt = np.copy(gt)
w, h = test_gt.shape
test_gt[(train_gt > 0)[:w,:h]] = 0
elif TEST_GT is not None:
test_gt = open_file(TEST_GT)
else:
# Sample random training spectra (produces both a train and a test set)
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)
print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
np.count_nonzero(gt)))
In this run, before this code executes, TRAIN_GT and TEST_GT are both None, so by default the last branch is the one executed:
else:
# Sample random training spectra (produces both a train and a test set)
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)
sample_gt() extracts a fixed percentage SAMPLE_PERCENTAGE of samples from the label array gt.
It should be stressed that the samples split into the train and test sets never include samples whose class is in ignored_labels. A rough sketch of the 'random' mode follows.
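This is my reading of the 'random' mode as a stratified random selection of labeled pixels; the real function in utils.py also implements the 'disjoint' mode:

```python
import numpy as np
from sklearn.model_selection import train_test_split

def sample_gt_sketch(gt, percentage):
    """Keep `percentage` of the labeled (non-zero) pixels as train, the rest as test."""
    indices = np.nonzero(gt)                  # label 0 (ignored) is excluded automatically
    coords = list(zip(*indices))              # (row, col) of every labeled pixel
    labels = gt[indices].ravel()
    train_c, test_c = train_test_split(coords, train_size=percentage, stratify=labels)
    train_gt, test_gt = np.zeros_like(gt), np.zeros_like(gt)
    train_gt[tuple(zip(*train_c))] = gt[tuple(zip(*train_c))]   # shape stays (W, H)
    test_gt[tuple(zip(*test_c))] = gt[tuple(zip(*test_c))]
    return train_gt, test_gt
```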
print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
np.count_nonzero(gt)))
# 4277 samples selected (over 42776)
This statement prints the result of the split into train_gt and test_gt.
As you can see, 4277 samples were selected out of the entire set of 42776 as train_gt, which matches SAMPLE_PERCENTAGE being set to 0.1 here.
print("Running an experiment with the {} model".format(MODEL),
"run {}/{}".format(run + 1, N_RUNS))
After entering the loop via "run the experiment", this prints which run this is.
The run + 1 is needed because in for run in range(N_RUNS): the variable run counts from 0, whereas the run number should start from 1.
if MODEL == 'SVM_grid':
print("Running a grid search SVM")
# Grid search SVM (linear and RBF)
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.svm.SVC(class_weight=class_weight)
clf = sklearn.model_selection.GridSearchCV(clf, SVM_GRID_PARAMS, verbose=5, n_jobs=4)
clf.fit(X_train, y_train)
print("SVM best parameters : {}".format(clf.best_params_))
prediction = clf.predict(img.reshape(-1, N_BANDS))
save_model(clf, MODEL, DATASET)
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'SVM':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.svm.SVC(class_weight=class_weight)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(img.reshape(-1, N_BANDS))
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'SGD':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
scaler = sklearn.preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight, learning_rate='optimal', tol=1e-3, average=10)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS)))
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'nearest':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
clf = sklearn.model_selection.GridSearchCV(clf, {'n_neighbors': [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(img.reshape(-1, N_BANDS))
prediction = prediction.reshape(img.shape[:2])
These branches are all non-neural-network models, which are not my direction, so I skip them for now.
else:
# Neural network
……
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
get_model(MODEL, **hyperparams) returns the model, the optimizer optimizer, the loss function loss, and the updated hyperparameter dict hyperparams, which now holds hyperparameters such as epoch, batch_size, and patch_size.
The model, optimizer, loss, and hyperparams information printed at the end looks like this:
Baseline(
(fc1): Linear(in_features=103, out_features=2048, bias=True)
(fc2): Linear(in_features=2048, out_features=4096, bias=True)
(fc3): Linear(in_features=4096, out_features=2048, bias=True)
(fc4): Linear(in_features=2048, out_features=10, bias=True)
)
Adam (
Parameter Group 0
amsgrad: False
betas: (0.9, 0.999)
eps: 1e-08
lr: 0.0001
weight_decay: 0
)
CrossEntropyLoss()
{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0), 'weights': tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'), 'patch_size': 1, 'dropout': False, 'learning_rate': 0.0001, 'epoch': 100, 'batch_size': 100, 'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x...>, 'supervision': 'full', 'center_pixel': True}
You can see the type and parameters of each network layer; the optimizer is the ever-popular Adam, the loss function is the cross-entropy CrossEntropyLoss, and then comes the hyperparams information.
if CLASS_BALANCING:
weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
hyperparams['weights'] = torch.from_numpy(weights)
Since CLASS_BALANCING defaults to False, I skip this part for now.
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
For the meaning of, and differences between, train, val, and test, see: the difference between train, val, and test during training.
Briefly, val is split off from train and is used to check whether the network overfits the training set (which shows up as the train loss shrinking while the val loss grows as training proceeds).
This line of code takes 95% of train_gt as the new train_gt and the remaining 5% as val_gt.
Output:
4063 samples selected (over 4277)
Note additionally: the selection works by keeping the chosen samples in place and setting the unselected ones to 0 (ignored_label), so the dimensions remain unchanged at (610, 340).
print(train_gt)
# [[0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# ...
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]]
print(train_gt.shape)
# (610, 340)
print(val_gt)
# [[0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# ...
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]]
print(val_gt.shape)
# (610, 340)
train_dataset = HyperX(img, train_gt, **hyperparams)
train_loader = data.DataLoader(train_dataset,
batch_size=hyperparams['batch_size'],
#pin_memory=hyperparams['device'],
shuffle=True)
val_dataset = HyperX(img, val_gt, **hyperparams)
val_loader = data.DataLoader(val_dataset,
#pin_memory=hyperparams['device'],
batch_size=hyperparams['batch_size'])
train_dataset = HyperX(img, train_gt, **hyperparams)
The first line creates train_dataset as an instance of the class HyperX; printing it shows the object:
<datasets.HyperX object at 0x000001403EB62CF8>
Afterwards I inspected the object's attributes with this snippet (my own addition; it is not in the project code):
for attr in dir(train_dataset):
print(attr)
print(getattr(train_dataset, attr))
The attributes observed include:
name
data
label
labels
indices
ignored_labels
name
# PaviaU
data
# [[[0.080875 0.062375 0.058 ... 0.402625 0.40475 0.40625 ]
# [0.0755 0.06825 0.065875 ... 0.30525 0.308 0.316 ]
# [0.077625 0.09325 0.0695 ... 0.2885 0.293125 0.295125]
# ...
# [0.074125 0.048375 0.0535 ... 0.29775 0.300875 0.302875]
# [0.074125 0.093875 0.081875 ... 0.289 0.2885 0.286125]
# [0.111125 0.09 0.056125 ... 0.302 0.305875 0.310625]]]
label
# [[0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# [0 0 0 ... 0 0 0]
# ...
# [0 0 0 ... 0 0 0]
# [2 0 2 ... 0 0 0]
# [0 0 0 ... 0 0 0]]
labels
# [1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 4, 4, 1, 4, 4, 4, 4, 1, 4, 4, 1, 4, 1, 4, 4, 1, 1, 4, 1, 1, 4, 4, 4, 4, 1, 1, 4, 1, 4, 4, 4, 1, 1, 1, 1, 4, 4, 4, 4, 1, 4, 1, 1, 1, 1, 4, 4, 4, …… 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
indices
# [[511 237]
# [551 94]
# [595 176]
# ...
# [142 173]
# [403 75]
# [575 135]]
ignored_labels
# {0}
patch_size
center_pixel
mixture_augmentation
flip
radiation_augmentation
mixture_noise
patch_size
# 1
center_pixel
# True
mixture_augmentation
# False
flip_augmentation
# False
radiation_augmentation
# False
mixture_noise
# <bound method HyperX.mixture_noise of <datasets.HyperX object at 0x...>>
flip
# <bound method HyperX.flip of <datasets.HyperX object at 0x...>>
train_loader = data.DataLoader(train_dataset, ……)
This line creates a DataLoader.
An important data-reading interface in PyTorch is torch.utils.data.DataLoader, defined in dataloader.py; essentially every PyTorch training script uses it. It takes the output of a custom dataset-reading interface (or of one of PyTorch's built-in ones) and packs the read samples into Tensors according to the batch size, after which they can be fed into the model. It therefore bridges the data pipeline and the model, which makes it quite important.
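A self-contained toy example of this batching behavior (toy tensors, not the project's data):

```python
import torch
import torch.utils.data as data

# ten fake "spectra" of 103 bands with random integer labels
dataset = data.TensorDataset(torch.randn(10, 103), torch.randint(0, 10, (10,)))
loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)   # torch.Size([4, 103]) torch.Size([4]); the final batch would hold 2
    break
```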
The following snippet prints information about train_loader:
print(train_loader)
print('Over')
for attr in dir(train_loader):
if '_' not in attr:
print(attr)
print(getattr(train_loader, attr))
print('Done!')
The result is:
<torch.utils.data.dataloader.DataLoader object at 0x000001BFC5A714E0>
Over
dataset
<datasets.HyperX object at 0x000001BFC5A71F98>
sampler
<torch.utils.data.sampler.RandomSampler object at 0x000001BFC5BBCC50>
timeout
0
Done!
The next two lines of code do the same thing, only for val, so I won't repeat them.
The printed result is:
{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0), 'weights': tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'), 'patch_size': 1, 'dropout': False, 'learning_rate': 0.0001, 'epoch': 100, 'batch_size': 100, 'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x000001AC80528208>, 'supervision': 'full', 'center_pixel': True}
print("Network :")
with torch.no_grad():
for input, _ in train_loader:
break
# summary(model.to(hyperparams['device']), input.size()[1:], device=hyperparams['device'])
summary(model.to(hyperparams['device']), input.size()[1:])
# -------------------------------------------
# my own addition
print('begin')
print('input: ',input)
print('input.shape: ',input.shape)
print('input.size()[1:]: ',input.size()[1:])
os.system('pause')
# --------------------------------------------
This part prints the network's information. It uses the summary() method from the torchsummary package, whose syntax is:
summary(your_model, input_size=(channels, H, W))
your_model: the model you defined. input_size: the dimensions of the input data, in channels × H × W order.
The idea of this part:
First enter with torch.no_grad(): (gradient tracking is not needed here, since the summary only requires a forward pass).
Then for input, _ in train_loader: fetches input from train_loader; once it is obtained (on the first loop iteration), break exits the loop. Simply put, train_loader is an iterator, and each traversal (call) of it yields one batch of samples.
Then the summary() method is called to print the model information: summary(model.to(hyperparams['device']), input.size()[1:]).
To make the explanation easier, I wrote a short snippet that prints the relevant information:
# --------------------------------------------
print('begin')
print('input: ',input)
print('input.shape: ',input.shape)
print('input.size()[1:]: ',input.size()[1:])
os.system('pause')
# --------------------------------------------
The printed information is:
Network :
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 2048] 212,992
Linear-2 [-1, 4096] 8,392,704
Linear-3 [-1, 2048] 8,390,656
Linear-4 [-1, 10] 20,490
================================================================
Total params: 17,016,842
Trainable params: 17,016,842
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 64.91
Estimated Total Size (MB): 64.98
----------------------------------------------------------------
begin
input: tensor([[0.1695, 0.1893, 0.1727, ..., 0.1466, 0.1446, 0.1451],
[0.0589, 0.0333, 0.0351, ..., 0.4984, 0.5005, 0.5049],
[0.1375, 0.1199, 0.0784, ..., 0.2763, 0.2820, 0.2906],
...,
[0.0237, 0.0671, 0.0636, ..., 0.4619, 0.4638, 0.4703],
[0.1209, 0.0480, 0.0318, ..., 0.4645, 0.4672, 0.4769],
[0.0845, 0.0955, 0.1168, ..., 0.2921, 0.2962, 0.3014]])
input.shape: torch.Size([100, 103])
input.size()[1:]: torch.Size([103])
Another point worth stressing: why input's shape is torch.Size([100, 103]). data.DataLoader() packs the output of the dataset-reading interface into Tensors according to the batch size, and batch_size here is set to 100, hence torch.Size([100, 103]).
But one question remains (set aside for now):
summary() expects its input_size parameter in the form (channels, H, W), yet the program passes input.size()[1:], which prints as torch.Size([103]) and is not of the (channels, H, W) form. In this network with only four linear layers, H and W are not needed to compute the Output Shape, but can they really just be left out? (My guess: the leading 100 is the batch_size and the 103 is the channel count, and every sample here has dimensions 1 × 1 × 103, so perhaps the 1s can be omitted by default?)
For the verification of this guess and its result, see: summary(): notes on a baffling debug session; I won't repeat it here.
if CHECKPOINT is not None:
model.load_state_dict(torch.load(CHECKPOINT))
load_state_dict() is documented as:
load_state_dict(state_dict): Loads the optimizer state.
Parameters: state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
and state_dict() as:
state_dict(): Returns the state of the optimizer as a dict. It contains two entries:
state - a dict holding current optimization state. Its content differs between optimizer classes.
param_groups - a dict containing all parameter groups
(Strictly speaking, the quote above is the optimizer's documentation; here load_state_dict is called on the model, where it copies the saved parameters and buffers into the module.)
What this code does: it loads the parameters of a previously saved model, and these "parameters" are the model's trainable parameters. The reason we can tell: when summary() ran, it printed:
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 64.91
Estimated Total Size (MB): 64.98
----------------------------------------------------------------
And when we inspect the saved parameters (C:\Users\73416\checkpoints\baseline\OwnData), the size is 65.0 MB (68,191,571 bytes), an almost perfect match.
But since CHECKPOINT is generally None, this part of the code usually does not execute. For reference, the save/load round trip looks roughly like the sketch below.
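A minimal sketch of the round trip that load_state_dict() participates in (the file name is illustrative; the project itself saves checkpoints through save_model in models.py):

```python
torch.save(model.state_dict(), 'checkpoint.pth')      # persist the trainable parameters
model.load_state_dict(torch.load('checkpoint.pth'))   # restore them into the same architecture
```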
try:
train(model, optimizer, loss, train_loader, hyperparams['epoch'],
scheduler=hyperparams['scheduler'], device=hyperparams['device'],
supervision=hyperparams['supervision'], val_loader=val_loader,
display=viz)
except KeyboardInterrupt:
# Allow the user to stop the training
pass
The core of this part is just:
train(model, optimizer, loss, train_loader, hyperparams['epoch'],
scheduler=hyperparams['scheduler'], device=hyperparams['device'],
supervision=hyperparams['supervision'], val_loader=val_loader,
display=viz)
For the details of the training process, see the train() function in models.py.
The try/except merely adds robustness: it lets the user stop training early with a keyboard interrupt without crashing the script.
probabilities = test(model, img, hyperparams)
prediction = np.argmax(probabilities, axis=-1)
This obtains the test result on the image img.
Note that the return value of test() has shape W × H × n_classes; the last dimension n_classes corresponds exactly to the number of classes.
np.argmax() "Returns the indices of the maximum values along an axis". Note that the return value consists of indices, which here correspond exactly to the classes (whichever position holds the largest value gives the predicted class).
Note in particular np.argmax()'s return value:
Returns:
index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
In short, the specified axis is removed, the maximum along that axis is found, and the index of that element is returned.
A small demo of these two lines of code:
# the output of one batch (100 samples) after passing through net()
output = np.array([[-3.4172e+00, 4.7474e-01, 3.6548e+00, 9.0230e-01, 7.8329e-01,
-1.0330e+00, 2.2331e+00, -6.2564e-01, 4.0340e-01, -3.0324e+00],
[-3.1507e+00, 2.2456e+00, 2.9226e-01, 1.8920e+00, -8.8580e-01,
7.2265e-01, 1.0055e+00, -4.8283e-02, 1.5603e+00, -2.6959e+00],
[-5.2190e+00, 3.3996e+00, -1.6052e+00, 2.5934e+00, -1.1674e+00,
5.8359e+00, 4.2171e-01, 2.4511e-01, 2.2210e+00, -3.8763e+00],
[-2.6106e+00, -7.5499e-01, 4.3269e+00, -1.6255e-01, 2.8252e+00,
-1.4511e+00, 1.7670e+00, -8.3349e-01, -5.4468e-01, -2.2654e+00],
[-3.0543e+00, 9.9883e-01, 2.3457e+00, 1.1807e+00, 6.2201e-02,
-6.4643e-01, 1.7674e+00, -4.1443e-01, 7.6956e-01, -2.7224e+00],
[-2.3868e+00, -4.9510e-01, 3.7331e+00, -8.3138e-03, 2.1344e+00,
-1.2079e+00, 1.6056e+00, -6.8800e-01, -3.8621e-01, -2.0586e+00],
[-3.2416e+00, -4.5884e-02, 4.2205e+00, 5.1736e-01, 1.4380e+00,
-1.2833e+00, 2.2156e+00, -7.2556e-01, 2.3806e-02, -2.8677e+00],
[-2.9053e+00, 1.6935e-01, 3.4237e+00, 6.0061e-01, 1.0369e+00,
-9.5306e-01, 1.9077e+00, -5.9553e-01, 1.6974e-01, -2.5703e+00],
[-7.1136e+00, 4.7759e+00, -2.2519e+00, 3.6735e+00, -1.7759e+00,
7.8180e+00, 6.6654e-01, 3.6286e-01, 3.1942e+00, -5.3241e+00],
[-3.0338e+00, -6.3117e-02, 3.9705e+00, 4.5521e-01, 1.3632e+00,
-1.1764e+00, 2.0957e+00, -6.7216e-01, 1.7762e-02, -2.6780e+00],
[-2.4539e+00, 1.5082e+00, 6.6919e-01, 1.3515e+00, -4.8787e-01,
2.2641e-01, 9.3917e-01, -1.2048e-01, 1.0607e+00, -2.1299e+00],
[-3.2204e+00, 1.9349e+00, 9.7008e-01, 1.7553e+00, -6.2425e-01,
2.0920e-01, 1.2744e+00, -1.7639e-01, 1.3713e+00, -2.7858e+00],
[-2.4182e+00, -8.2613e-01, 4.0879e+00, -2.6228e-01, 2.9733e+00,
-1.4046e+00, 1.5851e+00, -8.0221e-01, -5.8830e-01, -2.0608e+00],
[-3.2013e+00, 1.9282e+00, 9.4064e-01, 1.7319e+00, -5.9678e-01,
2.6542e-01, 1.2547e+00, -1.6712e-01, 1.3512e+00, -2.7726e+00],
[-2.5719e+00, -5.3754e-01, 4.0272e+00, -4.3249e-03, 2.3080e+00,
-1.3214e+00, 1.7518e+00, -7.5447e-01, -4.1301e-01, -2.2544e+00],
[-3.1710e+00, 1.8300e+00, 1.0955e+00, 1.7135e+00, -5.8126e-01,
8.5049e-02, 1.3449e+00, -2.0962e-01, 1.3227e+00, -2.7708e+00],
[-2.1461e+00, 1.4345e+00, 3.2470e-01, 1.1824e+00, -5.0001e-01,
5.2405e-01, 7.0300e-01, -6.3770e-02, 9.5943e-01, -1.8195e+00],
[-2.4559e+00, 1.6129e+00, 4.1325e-01, 1.3720e+00, -5.4098e-01,
5.1285e-01, 8.0745e-01, -5.2784e-02, 1.0950e+00, -2.0901e+00],
[-2.5694e+00, 1.3131e+00, 1.1399e+00, 1.2092e+00, -2.4525e-01,
6.9929e-02, 1.0451e+00, -1.8160e-01, 8.5916e-01, -2.2093e+00],
[-2.0823e+00, 1.3539e+00, 4.1539e-01, 1.1599e+00, -4.1376e-01,
3.9781e-01, 6.8572e-01, -6.6757e-02, 9.0060e-01, -1.7785e+00],
[-2.2904e+00, 1.0222e-01, 2.7664e+00, 4.1632e-01, 9.4062e-01,
-8.1115e-01, 1.3992e+00, -4.5315e-01, 4.9409e-02, -1.9998e+00],
[-2.4030e+00, -1.2090e+00, 3.9664e+00, -6.5414e-01, 4.6159e+00,
-1.4920e+00, 1.2595e+00, -8.8912e-01, -8.5802e-01, -1.9434e+00],
[-3.0300e+00, -7.2083e-01, 4.8153e+00, -9.8287e-02, 2.8917e+00,
-1.5489e+00, 2.0796e+00, -8.7492e-01, -5.1128e-01, -2.5959e+00],
[-2.3685e+00, 1.6391e+00, 2.4802e-01, 1.3489e+00, -5.3888e-01,
6.4715e-01, 6.5692e-01, -4.1977e-02, 1.0578e+00, -2.0024e+00],
[-1.9444e+00, -5.2009e-01, 3.1849e+00, -9.3790e-02, 1.9984e+00,
-1.0692e+00, 1.3111e+00, -6.1101e-01, -3.8485e-01, -1.6877e+00],
[-3.0943e+00, 2.3930e+00, -3.5446e-01, 1.8568e+00, -9.6759e-01,
1.6489e+00, 5.9468e-01, 7.7725e-02, 1.5249e+00, -2.5396e+00],
[-2.6556e+00, -3.6977e-01, 3.9245e+00, 1.4978e-01, 1.8240e+00,
-1.2115e+00, 1.8738e+00, -6.9119e-01, -2.3623e-01, -2.3246e+00],
[-3.1395e+00, 4.1125e-02, 3.9727e+00, 5.3787e-01, 1.2732e+00,
-1.1591e+00, 2.1246e+00, -6.8051e-01, 6.9807e-02, -2.7764e+00],
[-2.9746e+00, 1.8802e-01, 3.5310e+00, 6.3683e-01, 9.6213e-01,
-1.0437e+00, 2.0196e+00, -6.0621e-01, 2.0326e-01, -2.6465e+00],
[-2.8328e+00, 1.6977e+00, 8.0364e-01, 1.4708e+00, -4.7245e-01,
3.5556e-01, 1.0646e+00, -1.3904e-01, 1.1448e+00, -2.4328e+00],
[-3.1625e+00, 2.0841e+00, 6.1937e-01, 1.7831e+00, -7.0946e-01,
5.2520e-01, 1.1075e+00, -1.1551e-01, 1.4229e+00, -2.7167e+00],
[-4.0167e+00, 1.7574e+00, 2.2763e+00, 1.7559e+00, -1.8198e-01,
-1.7509e-01, 1.9423e+00, -4.0833e-01, 1.2436e+00, -3.5162e+00],
[-2.4236e+00, -7.8535e-01, 4.0795e+00, -2.2951e-01, 2.8136e+00,
-1.3446e+00, 1.6017e+00, -7.9643e-01, -5.4431e-01, -2.0819e+00],
[-2.2582e+00, -4.6274e-01, 3.5520e+00, 1.3249e-02, 1.8428e+00,
-1.1697e+00, 1.5907e+00, -6.3526e-01, -3.2695e-01, -1.9612e+00],
[-3.4212e+00, 2.2788e+00, 6.3480e-01, 1.9573e+00, -7.8700e-01,
5.6299e-01, 1.1635e+00, -1.1935e-01, 1.5660e+00, -2.9292e+00],
[-2.5760e+00, -3.0569e-01, 3.7433e+00, 1.4583e-01, 1.8048e+00,
-1.1692e+00, 1.7010e+00, -6.4062e-01, -2.5369e-01, -2.2543e+00],
[-3.0344e+00, -1.4159e-01, 4.1256e+00, 3.8586e-01, 1.5449e+00,
-1.2349e+00, 2.0604e+00, -7.1976e-01, -7.2623e-02, -2.6631e+00],
[-2.8772e+00, -1.1097e+00, 4.8425e+00, -4.5681e-01, 4.0864e+00,
-1.6900e+00, 1.7676e+00, -9.4947e-01, -8.0563e-01, -2.4084e+00],
[-3.0412e+00, -1.5252e+00, 4.8663e+00, -8.4481e-01, 5.9819e+00,
-1.8587e+00, 1.5335e+00, -1.0996e+00, -1.0606e+00, -2.4177e+00],
[-2.2495e+00, -6.2242e-01, 3.6935e+00, -1.0642e-01, 2.3495e+00,
-1.2321e+00, 1.4774e+00, -7.2164e-01, -4.5571e-01, -1.9474e+00],
[-2.1554e+00, 1.3898e-01, 2.5519e+00, 4.1213e-01, 8.7232e-01,
-7.1161e-01, 1.2969e+00, -4.2464e-01, 5.3723e-02, -1.8957e+00],
[-2.9576e+00, -1.3114e+00, 5.1378e+00, -5.9229e-01, 4.5756e+00,
-1.7936e+00, 1.7785e+00, -1.0670e+00, -9.2556e-01, -2.4614e+00],
[-3.2555e+00, 1.5876e+00, 1.6341e+00, 1.5986e+00, -3.8576e-01,
-2.5096e-01, 1.5805e+00, -2.9209e-01, 1.1768e+00, -2.8661e+00],
[-3.4841e+00, 2.1306e+00, 9.4273e-01, 1.9013e+00, -6.5666e-01,
3.6733e-01, 1.3094e+00, -1.7591e-01, 1.4909e+00, -3.0080e+00],
[-2.6257e+00, -3.5297e-01, 3.8823e+00, 1.4276e-01, 1.8069e+00,
-1.1846e+00, 1.8600e+00, -6.9508e-01, -2.3632e-01, -2.3041e+00],
[-5.7164e+00, 4.1493e+00, 1.6602e-01, 3.4875e+00, -1.6735e+00,
1.7104e+00, 1.6943e+00, -1.0589e-02, 2.8905e+00, -4.8421e+00],
[-2.2660e+00, -1.0308e+00, 3.8669e+00, -4.9326e-01, 3.7366e+00,
-1.4057e+00, 1.3182e+00, -8.2074e-01, -7.1355e-01, -1.8841e+00],
[-2.5752e+00, -5.6564e-01, 4.0870e+00, -1.0507e-02, 2.2154e+00,
-1.4014e+00, 1.8387e+00, -7.4375e-01, -3.7901e-01, -2.2514e+00],
[-2.9679e+00, -1.0837e+00, 4.9977e+00, -3.9962e-01, 3.9723e+00,
-1.7014e+00, 1.8724e+00, -9.8039e-01, -7.8818e-01, -2.5105e+00],
[-3.2278e+00, 2.0735e+00, 7.2812e-01, 1.7949e+00, -6.9313e-01,
4.3198e-01, 1.1599e+00, -1.4001e-01, 1.4349e+00, -2.7687e+00],
[-2.9973e+00, -1.3724e+00, 5.0783e+00, -6.5903e-01, 4.9946e+00,
-1.8330e+00, 1.6992e+00, -1.0746e+00, -9.6564e-01, -2.4660e+00],
[-2.4994e+00, 1.7602e+00, 1.6353e-01, 1.4317e+00, -6.5011e-01,
7.9777e-01, 7.3128e-01, -2.0274e-02, 1.1598e+00, -2.1139e+00],
[-2.1210e+00, -4.5631e-01, 3.3404e+00, -8.2309e-04, 1.8792e+00,
-1.0719e+00, 1.4665e+00, -6.0914e-01, -3.1663e-01, -1.8488e+00],
[-2.0178e+00, 7.5944e-01, 1.3621e+00, 8.4683e-01, -8.1003e-02,
-3.3616e-01, 1.1177e+00, -2.3672e-01, 5.8855e-01, -1.7959e+00],
[-4.0473e+00, -1.5565e+00, 6.4261e+00, -6.8970e-01, 6.1694e+00,
-2.2335e+00, 2.2404e+00, -1.3308e+00, -1.1467e+00, -3.3744e+00],
[-3.7822e+00, -1.1448e+00, 6.0267e+00, -3.5539e-01, 4.6618e+00,
-1.9564e+00, 2.2661e+00, -1.1758e+00, -8.8636e-01, -3.1916e+00],
[-3.1418e-01, 1.5967e-01, 4.5192e-02, 1.1821e-01, 5.7125e-02,
1.3193e-01, -1.1263e-02, -2.5072e-02, 3.2257e-02, -1.6350e-01],
[-2.5006e+00, -3.7688e-01, 3.7182e+00, 9.5787e-02, 1.8344e+00,
-1.1744e+00, 1.7026e+00, -6.4518e-01, -2.8520e-01, -2.1683e+00],
[-2.7033e+00, -1.2633e-01, 3.6465e+00, 3.3531e-01, 1.3966e+00,
-1.0790e+00, 1.8082e+00, -6.1470e-01, -6.8538e-02, -2.3573e+00],
[-3.1769e+00, 2.5477e-01, 3.6971e+00, 7.0723e-01, 9.9888e-01,
-1.0660e+00, 2.1025e+00, -6.3563e-01, 2.2431e-01, -2.8429e+00],
[-3.1090e+00, -2.1315e-01, 4.2975e+00, 3.8161e-01, 1.6554e+00,
-1.3550e+00, 2.1961e+00, -7.5499e-01, -8.2681e-02, -2.7584e+00],
[-3.3324e+00, 1.0784e+00, 2.5192e+00, 1.2437e+00, 1.9829e-01,
-4.5998e-01, 1.7754e+00, -4.3879e-01, 7.8876e-01, -2.9155e+00],
[-1.9215e+00, -2.6999e-01, 2.8816e+00, 1.0510e-01, 1.3734e+00,
-8.9469e-01, 1.3081e+00, -5.1852e-01, -2.0649e-01, -1.6914e+00],
[-2.7723e+00, -3.9382e-02, 3.5956e+00, 4.6801e-01, 1.2204e+00,
-1.1168e+00, 1.9174e+00, -6.2143e-01, 3.5122e-02, -2.4784e+00],
[-3.1408e+00, 8.7559e-01, 2.7258e+00, 1.1185e+00, 1.7175e-01,
-8.6532e-01, 1.9308e+00, -4.5202e-01, 7.0314e-01, -2.8272e+00],
[-2.8434e+00, -1.1880e+00, 4.9346e+00, -4.9937e-01, 4.2047e+00,
-1.7423e+00, 1.7886e+00, -9.9772e-01, -8.4582e-01, -2.3973e+00],
[-2.6930e+00, 3.8193e-01, 2.8897e+00, 6.8692e-01, 6.8390e-01,
-7.9716e-01, 1.6867e+00, -4.8306e-01, 2.7250e-01, -2.3930e+00],
[-2.7698e+00, 1.6703e+00, 8.2543e-01, 1.4979e+00, -5.2919e-01,
1.8657e-01, 1.0821e+00, -1.5562e-01, 1.1658e+00, -2.3999e+00],
[-2.9499e+00, 2.0103e+00, 4.1318e-01, 1.6914e+00, -7.0104e-01,
6.3629e-01, 9.3164e-01, -5.5044e-02, 1.3423e+00, -2.5152e+00],
[-2.9679e+00, 8.5536e-01, 2.4318e+00, 1.0781e+00, 1.7901e-01,
-6.4314e-01, 1.7377e+00, -4.0450e-01, 6.6847e-01, -2.6346e+00],
[-2.1822e+00, -2.1775e-01, 3.0966e+00, 1.6233e-01, 1.4289e+00,
-9.0650e-01, 1.5137e+00, -5.4407e-01, -1.4786e-01, -1.9107e+00],
[-2.5802e+00, 1.7874e+00, 2.8635e-01, 1.4753e+00, -6.2779e-01,
6.7400e-01, 7.6219e-01, -2.9079e-02, 1.1687e+00, -2.1778e+00],
[-3.1480e+00, 1.5729e+00, 1.4711e+00, 1.5068e+00, -3.1229e-01,
-2.4183e-02, 1.3990e+00, -2.6736e-01, 1.1071e+00, -2.7398e+00],
[-2.4344e+00, -6.0649e-01, 3.9408e+00, -7.0979e-02, 2.3058e+00,
-1.2803e+00, 1.6756e+00, -7.3830e-01, -4.2828e-01, -2.1105e+00],
[-2.3972e+00, -8.4519e-01, 4.0643e+00, -3.0006e-01, 3.0803e+00,
-1.3756e+00, 1.5393e+00, -8.0675e-01, -6.0961e-01, -2.0428e+00],
[-3.2675e+00, 2.1267e+00, 6.8864e-01, 1.8853e+00, -7.4988e-01,
4.1254e-01, 1.1888e+00, -1.2328e-01, 1.5016e+00, -2.8328e+00],
[-2.2283e+00, -5.6476e-01, 3.6542e+00, -7.9373e-02, 2.1571e+00,
-1.1448e+00, 1.5443e+00, -6.7334e-01, -3.8690e-01, -1.9310e+00],
[-2.5529e+00, -1.6107e-01, 3.5172e+00, 2.6857e-01, 1.4543e+00,
-1.0681e+00, 1.7158e+00, -5.9504e-01, -1.2687e-01, -2.2421e+00],
[-2.9471e+00, 1.3661e+00, 1.5572e+00, 1.3502e+00, -1.8738e-01,
-1.0739e-01, 1.3368e+00, -2.8337e-01, 9.6195e-01, -2.5632e+00],
[-4.0490e+00, 2.6916e+00, -1.2811e+00, 2.0386e+00, -9.5085e-01,
4.5006e+00, 2.7572e-01, 1.7540e-01, 1.7477e+00, -3.0254e+00],
[-2.2811e+00, -6.6988e-01, 3.8223e+00, -1.5893e-01, 2.4005e+00,
-1.2246e+00, 1.5716e+00, -7.0980e-01, -4.5187e-01, -1.9646e+00],
[-2.6464e+00, 1.8344e+00, 2.7921e-01, 1.5202e+00, -6.4925e-01,
7.0848e-01, 8.1471e-01, -3.6928e-02, 1.2366e+00, -2.2512e+00],
[-2.2836e+00, 1.5426e+00, 3.0953e-01, 1.2679e+00, -5.0542e-01,
5.8676e-01, 6.8070e-01, -4.7086e-02, 9.9322e-01, -1.9262e+00],
[-2.1354e+00, 2.6699e-02, 2.6765e+00, 3.7227e-01, 9.4378e-01,
-7.5686e-01, 1.3870e+00, -4.7164e-01, 2.3713e-02, -1.8903e+00],
[-3.2814e+00, 1.3032e-01, 3.9823e+00, 6.4409e-01, 1.2050e+00,
-1.1468e+00, 2.1940e+00, -6.6835e-01, 1.5291e-01, -2.9050e+00],
[-2.5112e+00, 1.5445e+00, 7.0261e-01, 1.3935e+00, -5.2214e-01,
1.9404e-01, 9.8879e-01, -1.4678e-01, 1.1036e+00, -2.1830e+00],
[-2.3466e+00, -7.3760e-01, 3.8941e+00, -1.9764e-01, 2.7026e+00,
-1.3215e+00, 1.5114e+00, -7.5438e-01, -5.3157e-01, -1.9959e+00],
[-2.9707e+00, -1.3565e+00, 5.1351e+00, -6.3653e-01, 4.8045e+00,
-1.8174e+00, 1.7231e+00, -1.0866e+00, -9.6608e-01, -2.4615e+00],
[-4.7498e-01, 1.2949e-01, 3.7521e-01, 1.2195e-01, 1.4310e-01,
-2.9742e-02, 1.8941e-01, -6.1047e-02, 4.5253e-02, -3.7414e-01],
[-2.6389e+00, -2.8462e-01, 3.7923e+00, 2.1464e-01, 1.6867e+00,
-1.1546e+00, 1.8157e+00, -6.7523e-01, -1.9182e-01, -2.3107e+00],
[-3.0642e+00, 1.6007e-02, 3.9244e+00, 5.0872e-01, 1.3313e+00,
-1.1347e+00, 2.0441e+00, -6.8004e-01, 3.6863e-02, -2.7183e+00],
[-2.4678e+00, -4.8809e-01, 3.8563e+00, 2.0195e-02, 2.0485e+00,
-1.2082e+00, 1.7291e+00, -7.0810e-01, -3.3191e-01, -2.1626e+00],
[-3.5040e+00, 2.4570e+00, 3.6188e-01, 2.0714e+00, -9.0356e-01,
8.4894e-01, 1.0637e+00, -6.7259e-02, 1.6715e+00, -2.9999e+00],
[-2.2794e+00, -9.1547e-01, 3.9683e+00, -3.7716e-01, 3.1500e+00,
-1.3683e+00, 1.4711e+00, -7.7084e-01, -6.3944e-01, -1.8953e+00],
[-1.7820e+00, 1.2340e+00, 1.4878e-01, 9.9099e-01, -4.0080e-01,
5.7725e-01, 4.6312e-01, -2.5199e-02, 7.7039e-01, -1.4899e+00],
[-2.1745e+00, -1.7468e-01, 3.0339e+00, 1.9426e-01, 1.3234e+00,
-8.8056e-01, 1.4736e+00, -5.1925e-01, -1.3082e-01, -1.8905e+00],
[-2.8420e+00, 1.6299e+00, 9.3658e-01, 1.4826e+00, -4.5678e-01,
1.9835e-01, 1.1390e+00, -1.4821e-01, 1.1446e+00, -2.4459e+00],
[-3.5596e+00, 2.2245e+00, 8.8818e-01, 1.9778e+00, -7.1741e-01,
4.0682e-01, 1.3327e+00, -1.5880e-01, 1.5614e+00, -3.0715e+00],
[-2.8122e+00, -9.3059e-01, 4.7239e+00, -2.9683e-01, 3.4055e+00,
-1.5805e+00, 1.8424e+00, -9.1344e-01, -6.7323e-01, -2.4044e+00],
[-3.1333e+00, 6.4306e-01, 2.9880e+00, 9.6653e-01, 4.6051e-01,
-8.1380e-01, 2.0050e+00, -4.9463e-01, 5.4800e-01, -2.7895e+00]])
print(output)
print(output.shape)
# (100, 10)
prediction = np.argmax(output, axis=-1)
print(prediction)
# [2 1 5 2 2 2 2 2 5 2 1 1 2 1 2 1 1 1 1 1 2 4 2 1 2 1 2 2 2 1 1 2 2 2 1 2 2
# 2 4 2 2 2 2 1 2 1 2 2 2 1 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 1 1 2 2 1 1 2
# 2 1 2 2 2 5 2 1 1 2 2 1 2 2 2 2 2 2 1 2 1 2 1 1 2 2]
print(prediction.shape)
# (100,)
As you can see, in the end we obtain the class to which each sample belongs after the network's prediction.
run_results = metrics(prediction, test_gt, ignored_labels=hyperparams['ignored_labels'], n_classes=N_CLASSES)
The run results include the evaluation metrics; see the metrics() function in utils.py for details.
mask = np.zeros(gt.shape, dtype='bool')
for l in IGNORED_LABELS:
mask[gt == l] = True
prediction[mask] = 0
This follows on from the "test on the test set and obtain the results" part, i.e. the prediction result.
A boolean mask is created here to zero out every position of prediction whose ground-truth class is an ignored_label (since the ignored_label corresponds to class 0).
A small demo of what this does:
IGNORED_LABELS = [0]
gt = np.array([[0,1,0],[2,3,2],[0,3,0]])
prediction = np.array([[10,1,10],[2,3,2],[10,3,10]])
print('former prediction:\n',prediction)
# former prediction:
# [[10 1 10]
# [ 2 3 2]
# [10 3 10]]
mask = np.zeros(gt.shape, dtype='bool')
for l in IGNORED_LABELS:
mask[gt == l] = True
prediction[mask] = 0
print('gt:\n',gt)
# gt:
# [[0 1 0]
# [2 3 2]
# [0 3 0]]
print('mask:\n',mask)
# mask:
# [[ True False True]
# [False False False]
# [ True False True]]
print('processed prediction:\n',prediction)
# processed prediction:
# [[0 1 0]
# [2 3 2]
# [0 3 0]]
color_prediction = convert_to_color(prediction)
display_predictions(color_prediction, viz, gt=convert_to_color(test_gt), caption="Prediction vs. test ground truth")
color_prediction = convert_to_color(prediction) uses the palette to convert the different classes in prediction into RGB colors.
The display_predictions() function visualizes prediction and test_gt in the same figure.
[figure: prediction vs. test ground truth, shown side by side in visdom]
results.append(run_results)
show_results(run_results, viz, label_values=LABEL_VALUES)
results.append(run_results) adds run_results to results. If there is only one run (N_RUNS == 1), results holds a single entry, so printing run_results directly shows the same result; with N_RUNS > 1, printing run_results alone is no longer enough.
show_results visualizes the results with visdom.
At this point we have exited the N_RUNS loop (although there is usually just one run).
if N_RUNS > 1:
show_results(results, viz, label_values=LABEL_VALUES, agregated=True)
This part prints the merged results of the N_RUNS runs (when N_RUNS is greater than 1).
## P.S.
hyperparams is initialized with hyperparams = vars(args), which converts the Namespace object:
Namespace(batch_size=None, class_balancing=False, cuda=0, dataset='PaviaU', download=None, epoch=None, flip_augmentation=False, folder='./Datasets/', lr=None, mixture_augmentation=False, model='nn', patch_size=None, radiation_augmentation=False, restore=None, runs=1, sampling_mode='random', test_set=None, test_stride=1, train_set=None, training_sample=0.1, with_exploration=False)
into a dictionary:
{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'restore': None, 'training_sample': 0.1, 'sampling_mode': 'random', 'train_set': None, 'test_set': None, 'epoch': None, 'patch_size': None, 'lr': None, 'class_balancing': False, 'batch_size': None, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'download': None}
At this point only the four hyperparameters cuda, dataset, model, and training_sample have been read in from the command line; the other hyperparameters (for data processing and model training) have not been stored into hyperparams yet and still hold their default values. A tiny illustration of vars() follows.
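A tiny illustration of what vars() does to a Namespace (standalone example, not project code):

```python
import argparse

ns = argparse.Namespace(cuda=0, model='nn', lr=None)
print(vars(ns))   # {'cuda': 0, 'model': 'nn', 'lr': None}
```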
# Handle the parsed arguments
CUDA_DEVICE = get_device(args.cuda)
……
# Number of classes
N_CLASSES = len(LABEL_VALUES)
# Number of bands (last dimension of the image tensor)
N_BANDS = img.shape[-1]
……
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
……
hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS, 'device': CUDA_DEVICE})
hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None)  # rebuild the dict, keeping only the non-None entries
This updates the data-processing-related hyperparameters and removes from the dict hyperparams every key/value pair whose value is None.
The update is done with the dict update() method.
The filtering is just a for loop plus an if test (see the small illustration below).
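A tiny standalone illustration of this update-then-filter pattern (toy values):

```python
hp = {'lr': None, 'epoch': None, 'runs': 1}
hp.update({'n_classes': 10, 'n_bands': 103})                 # merge new entries
hp = dict((k, v) for k, v in hp.items() if v is not None)    # drop the None values
print(hp)   # {'runs': 1, 'n_classes': 10, 'n_bands': 103}
```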
The resulting hyperparams is:
hyperparams: {'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0)}
# Neural network
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
if CLASS_BALANCING:
weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
hyperparams['weights'] = torch.from_numpy(weights)
# Split train set in train/val
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
# Generate the dataset
train_dataset = HyperX(img, train_gt, **hyperparams)
train_loader = data.DataLoader(train_dataset,
batch_size=hyperparams['batch_size'],
#pin_memory=hyperparams['device'],
shuffle=True)
val_dataset = HyperX(img, val_gt, **hyperparams)
val_loader = data.DataLoader(val_dataset,
#pin_memory=hyperparams['device'],
batch_size=hyperparams['batch_size'])
This code merges the model-related hyperparameters into hyperparams. I have not looked into the finer details yet.
The final hyperparams is:
{'dataset': 'PaviaU', 'model': 'nn', 'folder': './Datasets/', 'cuda': 0, 'runs': 1, 'training_sample': 0.1, 'sampling_mode': 'random', 'class_balancing': False, 'test_stride': 1, 'flip_augmentation': False, 'radiation_augmentation': False, 'mixture_augmentation': False, 'with_exploration': False, 'n_classes': 10, 'n_bands': 103, 'ignored_labels': [0], 'device': device(type='cuda', index=0), 'weights': tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'), 'patch_size': 1, 'dropout': False, 'learning_rate': 0.0001, 'epoch': 100, 'batch_size': 100, 'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x...>, 'supervision': 'full', 'center_pixel': True}