PyTorch将CIFAR100数据按类标归类保存

few-shot learning的采样

Few-shot learning 以任务(task)为单位对模型进行训练。在 N-way-K-shot 设定下,meta-training 中的每个任务包含 N 个类别:每个类别抽取 K 个样本构成 support set;query set 则是从这 N 个类别剩余的样本中再采样一定数量的样本(可以是均匀采样,也可以是不均匀采样)。

对数据按类标归类

针对上述情况,我们需要使用不同类别放置在不同文件夹的数据集。但有时,数据并没有按类放置,这时就需要对数据进行处理。下面以CIFAR100为例(不含N-way-K-shot的采样):

import os
from skimage import io
import torchvision as tv
import numpy as np
import torch


def _images_and_labels(dataset, train):
    """Return ``(images, labels)`` from a torchvision CIFAR dataset object.

    Recent torchvision releases expose the arrays as ``data`` / ``targets``;
    older releases used ``train_data`` / ``train_labels`` for the training
    split and ``test_data`` / ``test_labels`` for the test split. Both
    layouts are supported here.
    """
    if hasattr(dataset, 'data') and hasattr(dataset, 'targets'):
        return dataset.data, dataset.targets
    if train:
        return dataset.train_data, dataset.train_labels
    return dataset.test_data, dataset.test_labels


def _save_split(root, split_name, classes):
    """Write *classes* to ``root/split_name/character_<i>/<j>.jpg``.

    Directories are created on demand and reused if they already exist, so
    an interrupted or repeated run does not fail with ``FileExistsError``.
    """
    split_dir = os.path.join(root, split_name)
    os.makedirs(split_dir, exist_ok=True)
    for i, per_class in enumerate(classes):
        character_path = os.path.join(split_dir, 'character_' + str(i))
        os.makedirs(character_path, exist_ok=True)
        for j, img in enumerate(per_class):
            io.imsave(os.path.join(character_path, str(j) + ".jpg"), img)


def Cifar100(root):
    """Regroup CIFAR-100 by class label and save it as per-class image folders.

    The training and test splits of CIFAR-100 are downloaded into *root* (if
    not already present), merged, and grouped by label. The 100 classes are
    then shuffled with a fixed seed and divided into three class-disjoint
    splits, each written under *root*:

    * ``meta_training``   -- first 64 shuffled classes
    * ``meta_validation`` -- next 16 shuffled classes
    * ``meta_testing``    -- remaining 20 shuffled classes

    Inside each split, class ``i`` is stored in ``character_<i>/`` and its
    images are saved as ``<j>.jpg`` in channel-last (H, W, C) layout, exactly
    as torchvision provides them.

    Args:
        root: Directory used both as the torchvision download location and
            as the parent of the three output split directories.

    Returns:
        None. The result is the directory tree written to disk.
    """
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)

    # Bucket every image (train and test alike) under its class label.
    character = [[] for _ in range(100)]
    for dataset, is_train in ((train_set, True), (test_set, False)):
        images, labels = _images_and_labels(dataset, is_train)
        for img, label in zip(images, labels):
            # int() accepts plain ints as well as NumPy / tensor scalars.
            character[int(label)].append(img)

    # Shuffle the class order. A private RandomState seeded with 6 yields the
    # same permutation as np.random.seed(6) + np.random.shuffle, but leaves
    # the caller's global NumPy RNG state untouched.
    rng = np.random.RandomState(6)
    shuffle_class = np.arange(len(character))
    rng.shuffle(shuffle_class)
    character = [character[c] for c in shuffle_class]

    # meta_training : meta_validation : meta_testing = 64 : 16 : 20 classes
    _save_split(root, 'meta_training', character[:64])
    _save_split(root, 'meta_validation', character[64:80])
    _save_split(root, 'meta_testing', character[80:])


if __name__ == '__main__':
    # Script entry point: export CIFAR-100 into per-class folders under a
    # fixed local directory, then print a separator line once it is done.
    dataset_root = '/home/xie/文档/datasets/cifar_100'
    Cifar100(root=dataset_root)
    print("-----------------")


你可能感兴趣的:(PyTorch)