sample
的
gt
的读取、划分、加载。
不关心其他参数如ignored_labels
等。
位置 | 函数 | 作用 |
---|---|---|
主程序调用 | —— | —— |
datasets.py | get_dataset() | 得到img , gt , LABEL_VALUES , IGNORED_LABELS , RGB_BANDS , palette |
utilis.py | open_file() | 读取数据集,包括PaviaU.mat和PaviaU_gt.mat |
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
elif dataset_name == 'PaviaU':
# Load the image
img = open_file(folder + 'PaviaU.mat')['paviaU']
rgb_bands = (55, 41, 12)
gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt']
label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
'Painted metal sheets', 'Bare Soil', 'Bitumen',
'Self-Blocking Bricks', 'Shadows']
ignored_labels = [0]
def open_file(dataset):
_, ext = os.path.splitext(dataset)
ext = ext.lower()
if ext == '.mat':
# Load Matlab array
return io.loadmat(dataset)
elif ext == '.tif' or ext == '.tiff':
# Load TIFF file
return misc.imread(dataset)
elif ext == '.hdr':
img = spectral.open_image(dataset)
return img.load()
else:
raise ValueError("Unknown file format: {}".format(ext))
位置 | 函数 | 作用 |
---|---|---|
主程序调用 | —— | —— |
utils.py | sample_gt() | 将非ignored_labels 通过取索引的方式,随机划分到train_gt 和test_gt |
_spilt.py | sklearn.model_selection.train_test_split() | 根据所给比例,随机划分输入数据 |
额外说明:
通过gt
划分得到的train_gt
和test_gt
,仍在原图像维度中,只是将取得元素的位置设定为对应的label
,而其他位置置零(默认0为ignored_labels
)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bMhMLvGE-1570353298738)(C:\Users\73416\Desktop\MarkDown\图片文件夹\Hyperspectral-Classification Pytorch 数据集的读取、划分、加载\Train ground truth.jpg)]
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)
……
……
test_gt, val_gt = sample_gt(test_gt, 0.95, mode='random')
def sample_gt(gt, train_size, mode='random'):
"""Extract a fixed percentage of samples from an array of labels.
Args:
gt: a 2D array of int labels
percentage: [0, 1] float
Returns:
train_gt, test_gt: 2D arrays of int labels
"""
indices = np.nonzero(gt)
X = list(zip(*indices)) # x,y features (r,c)形式的位置的索引
y = gt[indices].ravel() # classes
train_gt = np.zeros_like(gt)
test_gt = np.zeros_like(gt)
if train_size > 1:
train_size = int(train_size)
if mode == 'random':
train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
train_indices = [list(t) for t in zip(*train_indices)]
test_indices = [list(t) for t in zip(*test_indices)]
train_gt[train_indices] = gt[train_indices]
test_gt[test_indices] = gt[test_indices]
elif mode == 'fixed':
……
……
……
else:
raise ValueError("{} sampling is not implemented yet.".format(mode))
return train_gt, test_gt
###HyperX的对象train_dataset
位置 | 函数 | 作用 |
---|---|---|
主程序调用 | —— | —— |
datasets.py | class HyperX(torch.utils.data.Dataset): | Generic class for a hyperspectral scene |
train_dataset = HyperX(img, train_gt, **hyperparams)
……
val_dataset = HyperX(img, val_gt, **hyperparams)
强调一个疑问
这里生成对象train_dataset
和val_dataset
的时候,用的dataset
是整个的img
,但是用到的ground_truth
是部分gt
(train_gt
和val_gt
)。
为了解决这个疑问,我做了一个小测试:
for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)
# ------------自加打印原始的输入维度------------------
# print(type(target))
# print(target.shape)
if batch_idx % 100 ==0:
print('initial data shape:',data.shape)
if 0 in target:
os.system('pause')
# ------------自加打印原始的输入维度------------------
在训练网络的时候,会从data_loader
中取出training sample,我加入了一个判断代码,如果取出的training sample中有**‘0’**标签,则打断程序的运行。
整体的检验只用一个epoch就可以。
实际运行的结果发现并不会传入标签为0的元素。标签为0的元素表示本身为标签为0或者没有被分到相应的数据集(这里是训练集)。
结论就是,生成高光谱训练样本的对象时,可以传入整张高光谱图像,只在gt
上区分数据集类型。
打印对象train_dataset
的属性:
运行示例命令:python C:\Users\73416\PycharmProjects\HSIproject\main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda 0
。
center_pixel
True
data
[[[0.080875 0.062375 0.058 ... 0.402625 0.40475 0.40625 ]
[0.0755 0.06825 0.065875 ... 0.30525 0.308 0.316 ]
[0.077625 0.09325 0.0695 ... 0.2885 0.293125 0.295125]
...
...
[0.074125 0.048375 0.0535 ... 0.29775 0.300875 0.302875]
[0.074125 0.093875 0.081875 ... 0.289 0.2885 0.286125]
[0.111125 0.09 0.056125 ... 0.302 0.305875 0.310625]]]
……
flip_augmentation
False
ignored_labels
{0}
indices
[[175 4]
[536 248]
[362 26]
...
[369 37]
[295 212]
[511 223]]
label(与像素点对应的label值)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
labels(除去ignored_labels的label值)
[4, 1, 1, 4, 4, 1, 4, 4, 4, 1, 1, 1, 4, 4, 4, 4, 1, 1, 4, 4, 4, 1, 1, 4, 1, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 4, 4, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 4, 1,
……
……
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
mixture_augmentation
False
……
name
PaviaU
patch_size
1
radiation_augmentation
False
……
位置 | 函数 | 作用 |
---|---|---|
主程序调用 | —— | —— |
dataloader.py | —— | —— |
train_loader = data.DataLoader(train_dataset,
batch_size=hyperparams['batch_size'],
#pin_memory=hyperparams['device'],
shuffle=True)
……
val_loader = data.DataLoader(val_dataset,
#pin_memory=hyperparams['device'],
batch_size=hyperparams['batch_size'])
……
batch_sampler
……
batch_size
100
……
dataset
……
……
……
batch_sampler
……
batch_size
100
……
dataset
……
……