First, set up the dataset: the images folder plus train.csv, test.csv, and val.csv, laid out as follows:
miniimagenet/
├── images
│   ├── n0210891500001298.jpg
│   ├── n0287152500001298.jpg
│   └── ...
├── test.csv
├── val.csv
├── train.csv
└── proc_images.py
The images folder contains 60,000 images, which are first resized to 84*84; the CSV files map each image filename to its label.
The images are then split into training, test, and validation sets according to those labels. In the training set, each folder corresponds to one class, and the folder name is the label.
For example: 'images/n0153282900000005.jpg' -> 'train/n01532829/n0153282900000005.jpg'
Given the images folder and the train.csv, test.csv, and val.csv files, the scripts below move the images into training, test, and validation sets according to the CSV files, automatically creating the three folders (train, test, val).
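Before running the split, it can be worth sanity-checking a split file. Below is a minimal sketch, assuming each CSV row is "filename,label" with a header line (the same format the scripts below expect):

import csv
from collections import Counter

# Quick sanity check of one split file: count images per class.
# Assumes each row is "filename,label" preceded by a header line.
with open('train.csv') as f:
    reader = csv.reader(f)
    next(reader)  # skip the header
    labels = Counter(label for _, label in reader)

print('classes:', len(labels))
print('images :', sum(labels.values()))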
'''
Windows version
'''
from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')  # collect every image file under images/

# Resize every image to 84*84
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# Read the csv files, split the images into three sets,
# and create the corresponding directories (train, val, test)
for datatype in ['train', 'val', 'test']:
    os.mkdir(datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0:  # skip the header
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = datatype + '/' + label + '/'
                os.mkdir(cur_dir)
                last_label = label
            os.rename('images/' + image_name, cur_dir + image_name)
'''
Linux version
'''
from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')

# Resize images to 84*84
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# Put each image in the correct directory
for datatype in ['train', 'val', 'test']:
    os.system('mkdir ' + datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0:  # skip the header
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = datatype + '/' + label + '/'
                os.system('mkdir ' + cur_dir)
                last_label = label
            os.system('mv images/' + image_name + ' ' + cur_dir)
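The two scripts above do the same job and differ only in how they call the shell. A cross-platform alternative is sketched below, assuming the same images/ folder and {train,val,test}.csv files; it uses pathlib and shutil.move instead of os.rename / os.system, so one script runs unchanged on Windows, Linux, and macOS:

import csv
import shutil
from pathlib import Path

# Portable sketch of the same split; assumes each CSV row is "filename,label"
# with a header line, and that images/ sits in the current directory.
for datatype in ['train', 'val', 'test']:
    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for image_name, label in reader:
            target_dir = Path(datatype) / label
            target_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(str(Path('images') / image_name), str(target_dir / image_name))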
'''
PyTorch version
'''
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import numpy as np
import collections
from PIL import Image
import csv
import random


class MiniImagenet(Dataset):
    """
    put mini-imagenet files as :
    root :
        |- images/*.jpg includes all images
        |- train.csv
        |- test.csv
        |- val.csv
    Note: meta-learning differs from ordinary supervised learning, especially in the
    notions of "batch" and "set".
    batch: contains several sets
    set:   n_way * k_shot images form the meta-training (support) set,
           n_way * k_query images form the meta-test (query) set.
    """
    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        """
        :param root: root path of mini-imagenet
        :param mode: train, val or test
        :param batchsz: batch size of sets, not batch of imgs
        :param n_way:
        :param k_shot:
        :param k_query: num of query imgs per class
        :param resize: resize to
        :param startidx: start indexing labels from startidx
        """
        self.batchsz = batchsz  # batch of sets, not batch of imgs
        self.n_way = n_way  # n-way
        self.k_shot = k_shot  # k-shot
        self.k_query = k_query  # for evaluation
        self.setsz = self.n_way * self.k_shot  # num of samples per set
        self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
        self.resize = resize  # resize to
        self.startidx = startidx  # index label not from 0, but from startidx
        print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (
            mode, batchsz, n_way, k_shot, k_query, resize))

        if mode == 'train':
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 # transforms.RandomHorizontalFlip(),
                                                 # transforms.RandomRotation(5),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])
        else:
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])

        self.path = os.path.join(root, 'images')  # image path
        csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
            self.img2label[k] = i + self.startidx  # {"img_name[:9]": label}
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)
    def loadCSV(self, csvf):
        """
        return a dict holding the information of the csv file
        :param csvf: csv file name
        :return: {label: [file1, file2, ...]}
        """
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip the (filename, label) header
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                # append filename to current label
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels
    def create_batch(self, batchsz):
        """
        create the batch for meta-learning.
        'episode' here means batch, and it means how many sets we want to retain.
        :param batchsz: batch size, i.e. number of episodes
        :return:
        """
        self.support_x_batch = []  # support set batch
        self.query_x_batch = []  # query set batch
        for b in range(batchsz):  # for each batch
            # 1. select n_way classes randomly
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)  # no duplicates
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                # 2. select k_shot + k_query images for each class
                selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[:self.k_shot])  # idx for Dtrain
                indexDtest = np.array(selected_imgs_idx[self.k_shot:])  # idx for Dtest
                support_x.append(
                    np.array(self.data[cls])[indexDtrain].tolist())  # all image filenames for current Dtrain
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            # shuffle the corresponding relation between support set and query set
            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)  # append set to current sets
            self.query_x_batch.append(query_x)  # append sets to current sets
    def __getitem__(self, index):
        """
        index means index of sets, 0 <= index <= batchsz - 1
        :param index:
        :return:
        """
        # [setsz, 3, resize, resize]
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        # [setsz]
        support_y = np.zeros((self.setsz), dtype=int)
        # [querysz, 3, resize, resize]
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        # [querysz]
        query_y = np.zeros((self.querysz), dtype=int)

        flatten_support_x = [os.path.join(self.path, item)
                             for sublist in self.support_x_batch[index] for item in sublist]
        support_y = np.array(
            [self.img2label[item[:9]]  # filename: n0153282900000005.jpg, the first 9 characters are the label
             for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)

        flatten_query_x = [os.path.join(self.path, item)
                           for sublist in self.query_x_batch[index] for item in sublist]
        query_y = np.array([self.img2label[item[:9]]
                            for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)

        # support_y: [setsz]
        # query_y: [querysz]
        # unique: [n-way], sorted
        unique = np.unique(support_y)
        random.shuffle(unique)
        # relative means the labels range from 0 to n_way - 1
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)
        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)

        return support_x, torch.LongTensor(support_y_relative), query_x, torch.LongTensor(query_y_relative)
    def __len__(self):
        # we have built batchsz sets in total, so the DataLoader can sample smaller batches of sets from them
        return self.batchsz
if __name__ == '__main__':
    # The block below visualises one episode of images with tensorboard.
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt
    from tensorboardX import SummaryWriter
    import time

    plt.ion()  # turn on interactive mode
    tb = SummaryWriter('runs', 'mini-imagenet')  # tensorboard logger

    # build the dataset (preprocessing happens inside the transform)
    mini = MiniImagenet('./data/miniimagenet/', mode='train', n_way=5, k_shot=1, k_query=1, batchsz=1000, resize=168)

    for i, set_ in enumerate(mini):
        # support_x: [k_shot * n_way, 3, resize, resize]
        support_x, support_y, query_x, query_y = set_

        support_x = make_grid(support_x, nrow=2)
        query_x = make_grid(query_x, nrow=2)

        plt.figure(1)
        plt.imshow(support_x.transpose(2, 0).numpy())
        plt.pause(0.5)
        plt.figure(2)
        plt.imshow(query_x.transpose(2, 0).numpy())
        plt.pause(0.5)

        tb.add_image('support_x', support_x)
        tb.add_image('query_x', query_x)

        time.sleep(5)

    tb.close()
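For meta-training, the MiniImagenet dataset is usually wrapped in a standard DataLoader so that several episodes are drawn at once. Below is a minimal usage sketch; the meta-batch size of 4, k_query=15, and num_workers value are illustrative choices, not taken from the code above:

from torch.utils.data import DataLoader

# Each DataLoader batch stacks several episodes, so every tensor gains a leading
# meta-batch dimension on top of the per-episode shapes returned by __getitem__.
mini = MiniImagenet('./data/miniimagenet/', mode='train', n_way=5, k_shot=1,
                    k_query=15, batchsz=1000, resize=84)
db = DataLoader(mini, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)

for support_x, support_y, query_x, query_y in db:
    # support_x: [4, n_way * k_shot, 3, 84, 84], support_y: [4, n_way * k_shot]
    # query_x:   [4, n_way * k_query, 3, 84, 84], query_y:  [4, n_way * k_query]
    print(support_x.shape, support_y.shape, query_x.shape, query_y.shape)
    break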
# -*- coding: utf-8 -*-
import numpy as np
import os
import random
from sys import platform as sys_pf

import matplotlib
if sys_pf == 'darwin':
    matplotlib.use("TkAgg")
from matplotlib import pyplot as plt

'''
Demo of how to load the image and stroke data for a character (Omniglot background set)
'''
def plot_motor_to_image(I, drawing, lw=2):
    '''
    Plot the motor trajectory on top of the image

    :param I: [105 x 105 numpy array] grayscale image
    :param drawing: [ns list] of strokes in motor space (numpy arrays)
    :param lw: line width
    :return:
    '''
    drawing = [d[:, 0:2] for d in drawing]  # strip off the timing data (third column)
    drawing = [space_motor_to_img(d) for d in drawing]  # convert to image space
    plt.imshow(I, cmap='gray')
    ns = len(drawing)
    for sid in range(ns):  # for each stroke
        plot_traj(drawing[sid], get_color(sid), lw)
    plt.xticks([])
    plt.yticks([])
def plot_traj(stk, color, lw):
    '''
    Plot a single stroke

    :param stk: [n x 2] stroke
    :param color: stroke color
    :param lw: line width
    :return:
    '''
    n = stk.shape[0]
    if n > 1:
        plt.plot(stk[:, 0], stk[:, 1], color=color, linewidth=lw)
    else:
        plt.plot(stk[0, 0], stk[0, 1], color=color, linewidth=lw, marker='.')
# Color map for the stroke with index k
def get_color(k):
    scol = ['g', 'r', 'b', 'm', 'c']
    ncol = len(scol)
    if k < ncol:
        out = scol[k]
    else:
        out = scol[-1]
    return out
# Convert an index to a string, zero-padding single-digit numbers
def num2str(idx):
    if idx < 10:
        return '0' + str(idx)
    return str(idx)
def load_img(fn):
    '''
    Load the binary image of a character

    :param fn: file name
    :return:
    '''
    I = plt.imread(fn)
    I = np.array(I, dtype=bool)
    return I
def load_motor(fn):
    '''
    Load the stroke data of a character from a text file

    :param fn: file name
    :return: motor: list of strokes (each an [n x 3] numpy array); the first two
             columns are coordinates, the last column is the timing data (in milliseconds)
    '''
    motor = []
    with open(fn, 'r') as fid:
        lines = fid.readlines()
    lines = [l.strip() for l in lines]
    for myline in lines:
        if myline == 'START':  # beginning of character
            stk = []
        elif myline == 'BREAK':  # break between strokes
            stk = np.array(stk)
            motor.append(stk)  # add to the list of strokes
            stk = []
        else:
            arr = np.fromstring(myline, dtype=float, sep=',')
            stk.append(arr)
    return motor
def space_motor_to_img(pt):
    '''
    Map points from motor space to image space (and vice versa)

    :param pt: [n x 2] point (row) coordinates in motor space
    :return: [n x 2] point (row) coordinates in image space
    '''
    pt[:, 1] = -pt[:, 1]
    return pt

def space_img_to_motor(pt):
    pt[:, 1] = -pt[:, 1]
    return pt
if __name__ == "__main__":
    img_dir = 'images_background'
    stroke_dir = 'strokes_background'
    nreps = 20  # number of renditions to show per character
    nalpha = 5  # number of alphabets to show

    alphabet_names = [a for a in os.listdir(img_dir) if a[0] != '.']  # get folder names
    alphabet_names = random.sample(alphabet_names, nalpha)  # choose random alphabets

    for a in range(nalpha):  # for each alphabet
        print('generating figure ' + str(a + 1) + ' of ' + str(nalpha))
        alpha_name = alphabet_names[a]

        # choose a random character from the alphabet
        character_id = random.randint(1, len(os.listdir(os.path.join(img_dir, alpha_name))))

        # get image and stroke directories for this character
        img_char_dir = os.path.join(img_dir, alpha_name, 'character' + num2str(character_id))
        stroke_char_dir = os.path.join(stroke_dir, alpha_name, 'character' + num2str(character_id))

        # get the base file name for this character
        fn_example = os.listdir(img_char_dir)[0]
        fn_base = fn_example[:fn_example.find('_')]

        plt.figure(a, figsize=(10, 8))
        plt.clf()
        for r in range(1, nreps + 1):  # for each rendition
            plt.subplot(4, 5, r)
            fn_stk = stroke_char_dir + '/' + fn_base + '_' + num2str(r) + '.txt'
            fn_img = img_char_dir + '/' + fn_base + '_' + num2str(r) + '.png'
            motor = load_motor(fn_stk)
            I = load_img(fn_img)
            plot_motor_to_image(I, motor)
            if r == 1:
                plt.title(alpha_name[:15] + '\n character ' + str(character_id))
        plt.tight_layout()
    plt.show()
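The helpers above can also be used outside the plotting demo. Below is a small sketch that loads one stroke file with load_motor and reports the stroke count, points per stroke, and approximate total drawing time; the file path is purely illustrative and should be replaced with any real *.txt file under strokes_background/:

# Usage sketch of load_motor; the path below is a placeholder, not a verified file name.
fn_stk = 'strokes_background/Latin/character01/0965_01.txt'
motor = load_motor(fn_stk)

print('number of strokes:', len(motor))
for sid, stk in enumerate(motor):
    print('stroke %d: %d points' % (sid + 1, stk.shape[0]))

# The third column holds per-point timestamps in milliseconds, so the last
# timestamp of the last stroke approximates the total drawing time.
print('total drawing time (ms):', motor[-1][-1, 2])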