Preprocessing Datasets for Meta-Learning

mini-ImageNet

First set up the dataset (images, train.csv, test.csv, val.csv). The directory layout is as follows:

miniimagenet/
├── images/
│   ├── n0210891500001298.jpg
│   ├── n0287152500001298.jpg
│   └── ...
├── test.csv
├── val.csv
├── train.csv
└── proc_images.py

The images folder holds 60,000 images, which are first resized to 84x84. Each csv file lists image filenames and their corresponding labels, and the images are split into training, test and validation sets according to those labels. Inside the training set, every folder represents one class and the folder name is the label, e.g.:
'images/n0153282900000005.jpg' -> 'train/n01532829/n0153282900000005.jpg'

proc_images.py

Given images, train.csv, test.csv and val.csv, the script moves the images into training, test and validation sets according to the csv files, automatically creating the three folders (train, test, val).

'''
Windows version
'''
from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')  # collect all image files with glob

# resize all images to 84x84
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# create the train/val/test directories and move images into per-class folders according to the csv files
for datatype in ['train', 'val', 'test']:
    os.mkdir(datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0: # skip the headers
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = datatype + '/' + label + '/'
                os.mkdir(cur_dir)
                last_label = label
            os.rename('images/' + image_name, cur_dir + image_name)

'''
Linux version
'''

from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')

# Resize images
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# Put in correct directory
for datatype in ['train', 'val', 'test']:
    os.system('mkdir ' + datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0:  # skip the headers
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = datatype + '/' + label + '/'
                os.system('mkdir ' + cur_dir)
                last_label = label
            os.system('mv images/' + image_name + ' ' + cur_dir)
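
Both scripts above perform the same split; only the directory-creation and file-moving calls differ. The sketch below (an illustration, not part of the original scripts) uses os.makedirs and shutil.move, so it runs unchanged on Windows, Linux and macOS and does not assume the csv rows are grouped by label:

'''
Cross-platform sketch (assumes the same images/, train.csv, val.csv, test.csv layout)
'''
import csv
import os
import shutil

for datatype in ['train', 'val', 'test']:
    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue
            image_name, label = row[0], row[1]
            cur_dir = os.path.join(datatype, label)
            os.makedirs(cur_dir, exist_ok=True)  # create split/label dir as needed
            shutil.move(os.path.join('images', image_name),
                        os.path.join(cur_dir, image_name))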
'''
PyTorch version
'''
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import numpy as np
import collections
from PIL import Image
import csv
import random


class MiniImagenet(Dataset):
    """
    put mini-imagenet files as :
    root :
        |- images/*.jpg includes all images
        |- train.csv
        |- test.csv
        |- val.csv
    NOTE: meta-learning differs from ordinary supervised learning, especially in the
    notions of batch and set.
    batch: contains several sets (episodes)
    set: n_way * k_shot images form the support (meta-train) set,
         n_way * k_query images form the query (meta-test) set.
    """

    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        """

        :param root: root path of mini-imagenet
        :param mode: train, val or test
        :param batchsz: batch size of sets, not batch of imgs
        :param n_way: number of classes per set (episode)
        :param k_shot: number of support images per class
        :param k_query: num of query imgs per class
        :param resize: resize images to this size
        :param startidx: start indexing labels from startidx instead of 0
        """

        self.batchsz = batchsz  # batch of set, not batch of imgs
        self.n_way = n_way  # n-way
        self.k_shot = k_shot  # k-shot
        self.k_query = k_query  # for evaluation
        self.setsz = self.n_way * self.k_shot  # num of samples per set
        self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
        self.resize = resize  # resize to
        self.startidx = startidx  # index label not from 0, but from startidx
        print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (
        mode, batchsz, n_way, k_shot, k_query, resize))

        if mode == 'train':
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 # transforms.RandomHorizontalFlip(),
                                                 # transforms.RandomRotation(5),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])
        else:
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])

        self.path = os.path.join(root, 'images')  # image path
        csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
            self.img2label[k] = i + self.startidx  # {"img_name[:9]":label}
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        """
        return a dict saving the information of csv
        :param splitFile: csv file name
        :return: {label:[file1, file2 ...]}
        """
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip (filename, label)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                # append filename to current label
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels

    def create_batch(self, batchsz):
        """
        create batch for meta-learning.
        'episode' here means one set; batchsz is how many sets we want to retain.
        :param batchsz: batch size, i.e. the number of sets
        :return:
        """
        self.support_x_batch = []  # support set batch
        self.query_x_batch = []  # query set batch
        for b in range(batchsz):  # for each batch
            # 1.select n_way classes randomly
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)  # no duplicate
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                # 2. select k_shot + k_query for each class
                selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[:self.k_shot])  # idx for Dtrain
                indexDtest = np.array(selected_imgs_idx[self.k_shot:])  # idx for Dtest
                support_x.append(
                    np.array(self.data[cls])[indexDtrain].tolist())  # get all images filename for current Dtrain
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            # shuffle the corresponding relation between the support set and the query set
            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)  # append set to current sets
            self.query_x_batch.append(query_x)  # append sets to current sets

    def __getitem__(self, index):
        """
        index means index of sets, 0<= index <= batchsz-1
        :param index:
        :return:
        """
        # [setsz, 3, resize, resize]
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        # [setsz]
        support_y = np.zeros((self.setsz), dtype=int)  # np.int was removed in recent NumPy; use the builtin int
        # [querysz, 3, resize, resize]
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        # [querysz]
        query_y = np.zeros((self.querysz), dtype=int)

        flatten_support_x = [os.path.join(self.path, item)
                             for sublist in self.support_x_batch[index] for item in sublist]
        support_y = np.array(
            [self.img2label[item[:9]]  # filename:n0153282900000005.jpg, the first 9 characters treated as label
             for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)

        flatten_query_x = [os.path.join(self.path, item)
                           for sublist in self.query_x_batch[index] for item in sublist]
        query_y = np.array([self.img2label[item[:9]]
                            for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)

        # print('global:', support_y, query_y)
        # support_y: [setsz]
        # query_y: [querysz]
        # unique: [n-way], sorted
        unique = np.unique(support_y)
        random.shuffle(unique)
        # relative means the label ranges from 0 to n-way
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        # print('relative:', support_y_relative, query_y_relative)

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)
        # print(support_set_y)
        # return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y)

        return support_x, torch.LongTensor(support_y_relative), query_x, torch.LongTensor(query_y_relative)

    def __len__(self):
        # we built batchsz sets in create_batch, so a DataLoader can sample mini-batches of sets from them
        return self.batchsz


if __name__ == '__main__':
    # the following block visualizes one episode at a time with matplotlib and tensorboardX
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt
    from tensorboardX import SummaryWriter
    import time

    plt.ion()  # turn on interactive mode

    tb = SummaryWriter('runs', 'mini-imagenet')  # tensorboard logger
    # build the episode dataset
    mini = MiniImagenet('./data/miniimagenet/', mode='train', n_way=5, k_shot=1, k_query=1, batchsz=1000, resize=168)

    for i, set_ in enumerate(mini):
        # support_x: [k_shot*n_way, 3, 84, 84]
        support_x, support_y, query_x, query_y = set_

        support_x = make_grid(support_x, nrow=2)
        query_x = make_grid(query_x, nrow=2)

        plt.figure(1)
        plt.imshow(support_x.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
        plt.pause(0.5)
        plt.figure(2)
        plt.imshow(query_x.permute(1, 2, 0).numpy())
        plt.pause(0.5)

        tb.add_image('support_x', support_x)
        tb.add_image('query_x', query_x)

        time.sleep(5)

    tb.close()
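
In actual meta-training the dataset is usually wrapped in a DataLoader so that each iteration yields a small batch of episodes rather than a single one. A minimal sketch (the meta-batch size of 4 and the 15-query setting are illustrative choices, not values fixed by the class above):

'''
Sketch: consuming episodes through a DataLoader
'''
from torch.utils.data import DataLoader

mini = MiniImagenet('./data/miniimagenet/', mode='train', n_way=5, k_shot=1,
                    k_query=15, batchsz=1000, resize=84)
db = DataLoader(mini, batch_size=4, shuffle=True, num_workers=1, pin_memory=True)

for step, (support_x, support_y, query_x, query_y) in enumerate(db):
    # support_x: [4, n_way*k_shot, 3, 84, 84]   support_y: [4, n_way*k_shot]
    # query_x:   [4, n_way*k_query, 3, 84, 84]  query_y:   [4, n_way*k_query]
    print(step, support_x.shape, query_x.shape)
    break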

Omniglot

# -*- coding: utf-8 -*-
import numpy as np
import os
import random
from sys import platform as sys_pf
import matplotlib
if sys_pf == 'darwin':
	matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
'''
Demonstrates how to load the image and stroke data for a character
'''

def plot_motor_to_image(I,drawing,lw=2):
	'''
	Plot the motor trajectory on top of the image
	:param I: [105 x 105] numpy array, grayscale image
	:param drawing: [ns list] strokes in motor space (numpy arrays)
	:param lw: line width
	:return:
	'''
	drawing = [d[:,0:2] for d in drawing] # strip off the timing data (third column)
	drawing = [space_motor_to_img(d) for d in drawing] # convert to image space
	plt.imshow(I,cmap='gray')
	ns = len(drawing)
	for sid in range(ns): # for each stroke
		plot_traj(drawing[sid],get_color(sid),lw)
	plt.xticks([])
	plt.yticks([])


def plot_traj(stk,color,lw):
	'''
	Plot a single stroke
	:param stk: [n x 2] stroke
	:param color: stroke color
	:param lw: line width
	:return:
	'''
	n = stk.shape[0]
	if n > 1:
		plt.plot(stk[:,0],stk[:,1],color=color,linewidth=lw)
	else:
		plt.plot(stk[0,0],stk[0,1],color=color,linewidth=lw,marker='.')

# color map for the stroke with index k
def get_color(k):	
    scol = ['g','r','b','m','c']
    ncol = len(scol)
    if k < ncol:
       out = scol[k]
    else:
       out = scol[-1]
    return out

# convert an index to a string, zero-padding single digits
def num2str(idx):
	if idx < 10:
		return '0'+str(idx)
	return str(idx)


def load_img(fn):
	'''
	Load the binary image of a character
	:param fn: file name
	:return:
	'''
	I = plt.imread(fn)
	I = np.array(I,dtype=bool)
	return I


def load_motor(fn):
	'''
	Load the stroke data of a character from a text file
	:param fn: file name
	:return: motor: list of strokes (each is an [n x 3] numpy array); the first two
	columns are coordinates and the last column is timing data (in milliseconds)
	'''
	motor = []
	with open(fn,'r') as fid:
		lines = fid.readlines()
	lines = [l.strip() for l in lines]
	for myline in lines:
		if myline =='START': # beginning of character
			stk = []
		elif myline =='BREAK': # break between strokes
			stk = np.array(stk)
			motor.append(stk) # add to list of strokes
			stk = [] 
		else:
			arr = np.fromstring(myline,dtype=float,sep=',')
			stk.append(arr)
	return motor


def space_motor_to_img(pt):
	'''
	Map from motor space to image space (and vice versa)
	:param pt: [n x 2] point (row) coordinates in motor space
	:return: [n x 2] point (row) coordinates in image space
	'''
	pt[:,1] = -pt[:,1]
	return pt
def space_img_to_motor(pt):
	pt[:,1] = -pt[:,1]
	return pt

if __name__ == "__main__":
	img_dir = 'images_background'
	stroke_dir = 'strokes_background'
	nreps = 20 # number of renditions shown per character
	nalpha = 5 # number of alphabets to show

	alphabet_names = [a for a in os.listdir(img_dir) if a[0] != '.'] # get alphabet names
	alphabet_names = random.sample(alphabet_names,nalpha) # choose random alphabets

	for a in range(nalpha): # for each alphabet
		print('generating figure ' + str(a+1) + ' of ' + str(nalpha))
		alpha_name = alphabet_names[a]
		
		# choose a random character from the alphabet
		character_id = random.randint(1,len(os.listdir(os.path.join(img_dir,alpha_name))))

		# get the image and stroke directories for this character
		img_char_dir = os.path.join(img_dir,alpha_name,'character'+num2str(character_id))
		stroke_char_dir = os.path.join(stroke_dir,alpha_name,'character'+num2str(character_id))

		# get the base filename for this character
		fn_example = os.listdir(img_char_dir)[0]
		fn_base = fn_example[:fn_example.find('_')] 

		plt.figure(a,figsize=(10,8))
		plt.clf()
		for r in range(1,nreps+1): # for each rendition
			plt.subplot(4,5,r)
			fn_stk = stroke_char_dir + '/' + fn_base + '_' + num2str(r) + '.txt'
			fn_img = img_char_dir + '/' + fn_base + '_' + num2str(r) + '.png'			
			motor = load_motor(fn_stk)
			I = load_img(fn_img)
			plot_motor_to_image(I,motor)
			if r==1:
				plt.title(alpha_name[:15] + '\n character ' + str(character_id))
		plt.tight_layout()
	plt.show()
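
The script above only visualizes the raw Omniglot drawings. To use Omniglot for few-shot episodes in the same spirit as the mini-ImageNet loader, each alphabet/characterXX folder can be treated as one class. A minimal sketch, assuming the images_background layout shown above; the 28x28 size and the 5-way 1-shot setting are common meta-learning choices and are illustrative assumptions, not part of the original script:

'''
Sketch: sampling one N-way K-shot Omniglot episode from images_background/
(28x28 resize and 5-way 1-shot defaults are illustrative assumptions)
'''
import glob
import os
import random

import numpy as np
from PIL import Image

def sample_episode(root='images_background', n_way=5, k_shot=1, k_query=1, size=28):
    # every alphabet/characterXX folder is one class
    char_dirs = glob.glob(os.path.join(root, '*', 'character*'))
    classes = random.sample(char_dirs, n_way)

    support, query = [], []
    for label, cdir in enumerate(classes):
        files = random.sample(os.listdir(cdir), k_shot + k_query)
        imgs = [np.array(Image.open(os.path.join(cdir, f)).convert('L').resize((size, size)),
                         dtype=np.float32) / 255. for f in files]
        support += [(img, label) for img in imgs[:k_shot]]  # (image, episode-local label)
        query += [(img, label) for img in imgs[k_shot:]]
    random.shuffle(support)
    random.shuffle(query)
    return support, query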
