[高光谱] Hyperspectral-Classification Pytorch 数据集的读取、划分、加载

Hyperspectral-Classification Pytorch 数据集的读取、划分、加载

文章目录

  • Hyperspectral-Classification Pytorch 数据集的读取、划分、加载
    • 数据集读取:
        • 流程:
        • 代码:
          • main.py:
          • datasets.py:
          • utils.py:
    • ground truth划分:
        • 流程:
        • 代码:
          • main.py:
          • utils.py:
    • 生成样本的dataset和dataloader:
        • 流程
        • 代码:
        • 打印信息:
      • 生成DataLoader
        • 流程
        • 代码:
        • 打印信息
        • 打印信息

这里只关心 samplegt的读取、划分、加载。

不关心其他参数如ignored_labels等。

数据集读取:

流程:

位置 函数 作用
主程序调用 —— ——
datasets.py get_dataset() 得到img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette
utilis.py open_file() 读取数据集,包括PaviaU.mat和PaviaU_gt.mat

代码:

main.py:
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
datasets.py:
elif dataset_name == 'PaviaU':
    # Load the image
    img = open_file(folder + 'PaviaU.mat')['paviaU']

    rgb_bands = (55, 41, 12)

    gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt']

    label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
                    'Painted metal sheets', 'Bare Soil', 'Bitumen',
                    'Self-Blocking Bricks', 'Shadows']

    ignored_labels = [0]
utils.py:
def open_file(dataset):
    _, ext = os.path.splitext(dataset)
    ext = ext.lower()
    if ext == '.mat':
        # Load Matlab array
        return io.loadmat(dataset)
    elif ext == '.tif' or ext == '.tiff':
        # Load TIFF file
        return misc.imread(dataset)
    elif ext == '.hdr':
        img = spectral.open_image(dataset)
        return img.load()
    else:
        raise ValueError("Unknown file format: {}".format(ext))

ground truth划分:

流程:

位置 函数 作用
主程序调用 —— ——
utils.py sample_gt() 将非ignored_labels通过取索引的方式,随机划分到train_gttest_gt
_spilt.py sklearn.model_selection.train_test_split() 根据所给比例,随机划分输入数据

额外说明:

通过gt划分得到的train_gttest_gt,仍在原图像维度中,只是将取得元素的位置设定为对应的label,而其他位置置零(默认0为ignored_labels)。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bMhMLvGE-1570353298738)(C:\Users\73416\Desktop\MarkDown\图片文件夹\Hyperspectral-Classification Pytorch 数据集的读取、划分、加载\Train ground truth.jpg)]

代码:

main.py:
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)
……
……
test_gt, val_gt = sample_gt(test_gt, 0.95, mode='random')
utils.py:
def sample_gt(gt, train_size, mode='random'):
    """Extract a fixed percentage of samples from an array of labels.   
    Args:
        gt: a 2D array of int labels
        percentage: [0, 1] float
    Returns:
        train_gt, test_gt: 2D arrays of int labels
    """
    
    indices = np.nonzero(gt)
    X = list(zip(*indices)) # x,y features  (r,c)形式的位置的索引
    y = gt[indices].ravel() # classes
    train_gt = np.zeros_like(gt)
    test_gt = np.zeros_like(gt)
    if train_size > 1:
       train_size = int(train_size)
    
    if mode == 'random':
       train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]
    elif mode == 'fixed':
       ……
       ……
       ……
    else:
        raise ValueError("{} sampling is not implemented yet.".format(mode))
    return train_gt, test_gt

生成样本的dataset和dataloader:

###HyperX的对象train_dataset

流程

位置 函数 作用
主程序调用 —— ——
datasets.py class HyperX(torch.utils.data.Dataset): Generic class for a hyperspectral scene

代码:

train_dataset = HyperX(img, train_gt, **hyperparams)
……
val_dataset = HyperX(img, val_gt, **hyperparams)

强调一个疑问

这里生成对象train_datasetval_dataset的时候,用的dataset是整个的img,但是用到的ground_truth是部分gttrain_gtval_gt)。

为了解决这个疑问,我做了一个小测试:

for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):
    # Load the data into the GPU if required
    data, target = data.to(device), target.to(device)

    # ------------自加打印原始的输入维度------------------
    # print(type(target))
    # print(target.shape)
    if batch_idx % 100 ==0:
        print('initial data shape:',data.shape)
    if 0 in target:
        os.system('pause')
    # ------------自加打印原始的输入维度------------------

在训练网络的时候,会从data_loader中取出training sample,我加入了一个判断代码,如果取出的training sample中有**‘0’**标签,则打断程序的运行。

整体的检验只用一个epoch就可以。

实际运行的结果发现并不会传入标签为0的元素。标签为0的元素表示本身为标签为0或者没有被分到相应的数据集(这里是训练集)。

结论就是,生成高光谱训练样本的对象时,可以传入整张高光谱图像,只在gt上区分数据集类型。

打印信息:

打印对象train_dataset的属性:

运行示例命令:python C:\Users\73416\PycharmProjects\HSIproject\main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda 0

center_pixel
True

data
[[[0.080875 0.062375 0.058    ... 0.402625 0.40475  0.40625 ]
  [0.0755   0.06825  0.065875 ... 0.30525  0.308    0.316   ]
  [0.077625 0.09325  0.0695   ... 0.2885   0.293125 0.295125]
  ...
  ...
  [0.074125 0.048375 0.0535   ... 0.29775  0.300875 0.302875]
  [0.074125 0.093875 0.081875 ... 0.289    0.2885   0.286125]
  [0.111125 0.09     0.056125 ... 0.302    0.305875 0.310625]]]
  
……

flip_augmentation
False

ignored_labels
{0}

indices
[[175   4]
 [536 248]
 [362  26]
 ...
 [369  37]
 [295 212]
 [511 223]]
 
label(与像素点对应的label值)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
 
labels(除去ignored_labels的label值)
[4, 1, 1, 4, 4, 1, 4, 4, 4, 1, 1, 1, 4, 4, 4, 4, 1, 1, 4, 4, 4, 1, 1, 4, 1, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 4, 4, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 4, 1, 
……
……
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

mixture_augmentation
False

……

name
PaviaU

patch_size
1

radiation_augmentation
False

……

生成DataLoader

流程

位置 函数 作用
主程序调用 —— ——
dataloader.py —— ——

代码:

train_loader = data.DataLoader(train_dataset,
                               batch_size=hyperparams['batch_size'],
                               #pin_memory=hyperparams['device'],
                               shuffle=True)
……
val_loader = data.DataLoader(val_dataset,
                                     #pin_memory=hyperparams['device'],
                                     batch_size=hyperparams['batch_size'])

打印信息

……
batch_sampler

……
batch_size
100
……
dataset

……

……

打印信息

……
batch_sampler

……
batch_size
100
……
dataset

……

……

你可能感兴趣的:(开源项目使用,高光谱图像分类)