Few-shot learning 以任务(task)为单位对模型进行训练。在 N-way-K-shot 设定下,meta-training 阶段的每个任务包含 N 个类别:从每一类中抽取 K 个样本构成 support set;query set 则从这 N 类剩余的样本中采样一定数量的样本(可以是均匀采样,也可以是不均匀采样)。
针对上述情况,我们需要使用按类别分别存放在不同文件夹中的数据集。但有时数据并没有按类存放,这时就需要先对数据进行处理。下面以 CIFAR100 为例(不含 N-way-K-shot 的采样):
import os
from skimage import io
import torchvision as tv
import numpy as np
import torch
def _images_and_labels(dataset):
    """Return ``(images, labels)`` from a torchvision CIFAR dataset object.

    Current torchvision releases expose the raw arrays as ``data`` and
    ``targets``; older releases used ``train_data`` / ``train_labels`` for the
    training split and ``test_data`` / ``test_labels`` for the test split.
    Both layouts are accepted so the script works across versions.
    """
    if hasattr(dataset, 'data'):
        return dataset.data, dataset.targets
    if hasattr(dataset, 'train_data'):
        return dataset.train_data, dataset.train_labels
    return dataset.test_data, dataset.test_labels


def _save_split(root, split_name, classes):
    """Write one split as ``root/<split_name>/character_<i>/<j>.jpg``.

    ``classes`` is iterated class by class, and each class image by image;
    ``i`` and ``j`` are the positions within those iterations.
    """
    split_dir = os.path.join(root, split_name)
    # exist_ok=True keeps the export re-runnable: a second run (or a run after
    # a partial failure) overwrites images instead of dying on os.mkdir.
    os.makedirs(split_dir, exist_ok=True)
    for i, per_class in enumerate(classes):
        character_path = os.path.join(split_dir, 'character_' + str(i))
        os.makedirs(character_path, exist_ok=True)
        for j, img in enumerate(per_class):
            io.imsave(os.path.join(character_path, str(j) + ".jpg"), img)


def Cifar100(root):
    """Download CIFAR-100 and export it as per-class image folders.

    The train and test subsets are merged, images are grouped by label, the
    100 classes are shuffled with a fixed seed, and the shuffled classes are
    split 64 / 16 / 20 into ``meta_training``, ``meta_validation`` and
    ``meta_testing``. Every image is written to
    ``root/<split>/character_<class index within split>/<image index>.jpg``.

    Args:
        root: Directory used both as the torchvision download location and as
            the parent of the three split folders.

    Returns:
        None. The result is the directory tree written under ``root``.

    Note:
        Calls ``np.random.seed(6)``, which reseeds NumPy's global random
        state for the whole process.
    """
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)

    # Group every image (train first, then test) under its class label.
    character = [[] for _ in range(100)]
    for subset in (train_set, test_set):
        images, labels = _images_and_labels(subset)
        for img, label in zip(images, labels):
            character[int(label)].append(img)  # each img is 32*32*3

    # Stacking into a single array relies on every class holding the same
    # number of images. The data stays a NumPy array because that is what
    # skimage.io.imsave takes; no torch round-trip is needed.
    character = np.array(character)

    # Shuffle at class level with a fixed seed so the split is reproducible.
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]

    # meta_training : meta_validation : meta_testing = 64 : 16 : 20 classes
    splits = (
        ('meta_training', character[:64]),
        ('meta_validation', character[64:80]),
        ('meta_testing', character[80:]),
    )
    for split_name, classes in splits:
        _save_split(root, split_name, classes)
if __name__ == '__main__':
    # Script entry point: `root` is the directory handed to Cifar100, which
    # downloads the dataset there and creates the split folders beneath it.
    root = '/home/xie/文档/datasets/cifar_100'
    Cifar100(root)
    # Separator line printed once the export call has returned.
    print("-----------------")