目录
Cifar10 转 png
第一步:下载 cifar-10-python.tar.gz
第二步:运行 1_1_cifar10_to_png.py
主要模块
scipy.misc.imsave()函数:
pickle模块
os.path.join
第三步: 训练集、验证集和测试集的划分
主要模块
glob模块
shutil模块
os.walk()方法
split()和os.path.split()
让 PyTorch 能读数据集
Dataset 类
1. 制作图片数据的索引
2. 构建 Dataset 子类
图片从硬盘到模型
这里主要介绍一下这个数据集的目录结构以及内部数据组织格式.
下载方式:1. 官网: http://www.cs.toronto.edu/~kriz/cifar.html
1_1_cifar10_to_png.py代码如下:
# coding:utf-8
"""
将cifar10的data_batch_12345 转换成 png格式的图片
每个类别单独存放在一个文件夹,文件夹名称为0-9
"""
from scipy.misc import imsave
import numpy as np
import os
import pickle
data_dir = os.path.join("..", "..", "Data", "cifar-10-batches-py")
train_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_train")
test_o_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
Train = False # 不解压训练集,仅解压测试集
# 解压缩,返回解压后的字典
def unpickle(file):
with open(file, 'rb') as fo:
dict_ = pickle.load(fo, encoding='bytes')
return dict_
def my_mkdir(my_dir):
if not os.path.isdir(my_dir):
os.makedirs(my_dir)
# 生成训练集图片,
if __name__ == '__main__':
if Train:
for j in range(1, 6):
data_path = os.path.join(data_dir, "data_batch_" + str(j)) # data_batch_12345
train_data = unpickle(data_path)
print(data_path + " is loading...")
for i in range(0, 10000):
img = np.reshape(train_data[b'data'][i], (3, 32, 32))
img = img.transpose(1, 2, 0)
label_num = str(train_data[b'labels'][i])
o_dir = os.path.join(train_o_dir, label_num)
my_mkdir(o_dir)
img_name = label_num + '_' + str(i + (j - 1)*10000) + '.png'
img_path = os.path.join(o_dir, img_name)
imsave(img_path, img)
print(data_path + " loaded.")
print("test_batch is loading...")
# 生成测试集图片
test_data_path = os.path.join(data_dir, "test_batch")
test_data = unpickle(test_data_path)
for i in range(0, 10000):
img = np.reshape(test_data[b'data'][i], (3, 32, 32))
img = img.transpose(1, 2, 0)
label_num = str(test_data[b'labels'][i])
o_dir = os.path.join(test_o_dir, label_num)
my_mkdir(o_dir)
img_name = label_num + '_' + str(i) + '.png'
img_path = os.path.join(o_dir, img_name)
imsave(img_path, img)
print("test_batch loaded.")
这个函数用于储存图片,将数组保存为图像。此功能仅在安装了Python Imaging Library(PIL)时可用。新的替代它的是imageio.imwrite()
用法:imsave(name, arr, format=None)
参数:
#灰度图像
from scipy.misc import imsave
x = np.zeros((255, 255))
x = np.zeros((255, 255), dtype=np.uint8)
x[:] = np.arange(255)
imsave('gradient.png', x)
#RGB图像
rgb = np.zeros((255, 255, 3), dtype=np.uint8)
rgb[..., 0] = np.arange(255)
rgb[..., 1] = 55
rgb[..., 2] = 1 - np.arange(255)
imsave('rgb_gradient.png', rgb)
值得注意的是,这个函数默认的情况下,会检测你输入的RGB值的范围,如果都在0到1之间的话,那么会自动扩大范围至0到255。也就是说,这个时候你乘不乘255输出图片的效果一样的。
本次主要使用pickle模块将下载的序列化图像信息读出来,转变为bytes编码。
1 2 3 4 5 |
|
Python提供的pickle模块可以序列化对象并保存到磁盘中,并在需要的时候读取出来,任何对象都可以执行序列化操作。
1、pickle.dump(object, file, protocol=) 将object对象序列化到打开的文件夹file中。protocol是序列化协议,默认是0,如果是负数或者HIGHEST_PROTOCOL,则使用最高版本序列化协议。
2、pickle.load(file,encoding) 把file中的对象读出,encoding 参数可置为 'bytes' 来将这些 8 位字符串实例读取为字节对象。
3、pickle.dumps(obj,protocol=None) 将 obj 打包以后的对象作为 'bytes'类型直接返回,而不是将其写入到文件。
4、pickle.loads(bytes_object) 对于打包生成的对象 bytes_object,还原出原对象的结构并返回。
dump() 和 load() 与 dumps() 和 loads()的区别 :dump()函数能一个接着一个地将几个对象序列化存储到同一个文件中,随后调用load()来以同样的顺序反序列化读出这些对象。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
|
输出:
1 2 |
|
python中有join()和os.path.join()两个函数,具体作用如下:
‘sep’.join(seq)
参数:
sep:分隔符。可以为空
seq:要连接的元素序列、字符串、元组、字典
上面的语法即:以sep作为分隔符,将seq所有的元素合并成一个新的字符串
返回值:返回一个以分隔符sep连接各个元素后生成的字符串
os.path.join()函数
语法: os.path.join(path1[,path2[,……]])
返回值:将多个路径组合后返回os.path.join()函数:连接两个或更多的路径名组件
- 如果各组件名首字母不包含’/’,则函数会自动加上
- 如果有一个组件是一个绝对路径,则在它之前的所有组件均会被舍弃
- 如果最后一个组件为空,则生成的路径以一个’/’分隔符结尾
注意:若出现”./”开头的参数,会从”./”开头的参数的上一个参数开始拼接
运行完毕后,可在文件夹 Data/cifar-10-png/raw_test/下看到 0-9 个文件夹,对应 9 个类别。 脚本中未将训练集解压出来,这里只是为了实验,因此未使用过多的数据。这里仅将测试集中的 10000 张图片解压出来,作为原始图片,将从这 10000 张图片中划分出训练集(train),验证集(valid),测试集(test)。
"""
将原始数据集进行划分成训练集、验证集和测试集
"""
import os
import glob
import random
import shutil
dataset_dir = os.path.join("..", "..", "Data", "cifar-10-png", "raw_test")
train_dir = os.path.join("..", "..", "Data", "train")
valid_dir = os.path.join("..", "..", "Data", "valid")
test_dir = os.path.join("..", "..", "Data", "test")
train_per = 0.8
valid_per = 0.1
test_per = 0.1
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir)
if __name__ == '__main__':
for root, dirs, files in os.walk(dataset_dir):
for sDir in dirs:
imgs_list = glob.glob(os.path.join(root, sDir, '*.png'))
random.seed(666)
random.shuffle(imgs_list)
imgs_num = len(imgs_list)
train_point = int(imgs_num * train_per)
valid_point = int(imgs_num * (train_per + valid_per))
for i in range(imgs_num):
if i < train_point:
out_dir = os.path.join(train_dir, sDir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sDir)
else:
out_dir = os.path.join(test_dir, sDir)
makedir(out_dir)
#out_path = os.path.join(out_dir, os.path.split(imgs_list[i])[-1])
#shutil.copy(imgs_list[i], out_path)
shutil.copy(imgs_list[i], out_dir)
print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))
glob模块用来一次性读取对应文件夹下所有符合要求的子文件夹和子文件夹下的文件列表,常见的两个方法有glob.glob()和glob.iglob(),可以和常用的find功能进行类比,glob支持
*?[]
这三种通配符
- *代表0个或多个字符
- ?代表一个字符
- [ ]匹配指定范围内的字符,如[0-9]匹配数字
- copy()
功能:复制文件 格式:shutil.copy('来源文件','目标地址')
其他用法:python中的shutil模块
os.walk的函数声明为:
walk(top, topdown=True, οnerrοr=None, followlinks=False)
参数
- top 是遍历的目录地址
- topdown 为真,则优先遍历top目录,否则优先遍历top的子目录(默认为开启)
- onerror 需要一个 callable 对象,当walk需要异常时,会调用
- followlinks 如果为真,则会遍历目录下的快捷方式(linux 下是 symbolic link)实际所指的目录(默认关闭)
os.walk 的返回值是一个生成器(generator),也就是说我们需要不断的遍历它,来获得所有的内容。
每次遍历的对象都是返回的是一个三元组(root,dirs,files)返回参数说明:
- root 所指的是当前正在遍历的这个文件夹的本身的地址
- dirs 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
- files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录) files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
语法:str.split(str="",num=string.count(str))[n]
参数说明:
- str:表示为分隔符,默认为空格,但是不能为空(’ ’)。若字符串中没有分隔符,则把整个字符串作为列表的一个元素
- num:表示分割次数。如果存在参数num,则仅分隔成 num+1 个子字符串,并且每一个子字符串可以赋给新的变量
- [n]:表示选取第n个分片
注意:当使用空格作为分隔符时,对于中间为空的项会自动忽略
语法:os.path.split(‘PATH’)
参数说明:
- PATH指一个文件的全路径作为参数:
- 如果给出的是一个目录和文件名,则输出路径和文件名
- 如果给出的是一个目录名,则输出路径和为空文件名
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# coding:utf-8
import os
'''
为数据集生成对应的txt文件
'''
train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")
valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")
def gen_txt(txt_path, img_dir):
f = open(txt_path, 'w')
for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
f.write(line)
f.close()
if __name__ == '__main__':
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
首先看看初始化,初始化中从我们准备好的 txt 里获取图片的路径和标签,并且存储 在 self.imgs , self.imgs 就是上面提到的 list ,其一个元素对应一个样本的路径和标签,其实 就是 txt 中的一行。