写在前面
在使用CNN训练了MNIST数据集之后,不知道大家会不会有一种想要制作一个属于自己的数据集,并且用来训练的冲动呢?其实想要使用pytorch制作一个简单的数据集一点也不难,下面我们就一起来看看具体的实现方法吧。
在制作一个自己的数据之前,我们需要分析一下MNIST数据集,然后依葫芦画瓢,制作一个属于自己的数据集。首先,我们需要知道,MNIST数据集里面首先有图片,其次,每张图片都有对应的标签。如:
此外,如上图我们是可以将MNIST数据集中的图片和标签一起输出出来的,换句话说,我们可以读取数据集中的图片与标签,所以我们可以知道,制作数据集至少需要几点条件
第一点,图片我们可以从各种途径获得图片,这里需要注意一下图片的版权问题,大部分图片我们可以在网络上下载到。这里就不多做赘述了。
我们在这里着重会讲解第二点和第三点即,怎么制作一个标签文件,与如何才能读取到图片和图片对应的标签。
具体实现
首先我们可以知道一个数据集文件中包括图片与标签文件,如:
从上面这张图我们可以看到,前面是图片名称+格式,中间用一个空格作为分隔,后面接上图片的标签。
那么此时我们面临的最大问题一定是,如何在将我们想要的label以一定格式写入txt文件呢?(刚刚发现,这个问题好像与pytorch没有什么关联啊喂QAQ)
想要将我们的图片写入txt文件其实特别的简单。先上代码,
#定义一下文件地址
dir = '/content/dataset/' #写入你自己的数据集所在位置
label = 0
import os
def generate(dir,label):
files = os.listdir(dir)
files.sort()
listText = open('all_list.txt','a')
for file in files:
fileType = os.path.split(file)
if fileType[1] == '.txt':
continue
name = file + ' ' + str(int(label)) +'\n'
listText.write(name)
listText.close()
outer_path = '/content/dataset/' #这里是你的图片的目录
if __name__ == '__main__':
i = 0
folderlist = os.listdir(outer_path) #列举文件夹
for folder in folderlist:
generate(os.path.join(outer_path, folder),i)
i += 1
通过这段脚本我们就可以实现以这种固定格式写入txt啦。
接下来我们看第二个问题,如何才能读取我们数据集中的文件和标签呢?要知道我们想要训练数据集必然就离不开对数据集的各种操作。其实这个问题pytorch也早就帮我们设计好啦。
首先导入一些必要的包
#导入必要的包
import torch
import torchvision
from torch import nn, optim
from torch.autograd import Variable as var
import torchvision.transforms as transforms
from PIL import Image
具体的实现代码来自这篇博客,十分推荐大家看一看。
Pytorch打怪路(三)Pytorch创建自己的数据集1
其实这段代码的主要目的就是为了让我们能够方便的读取数据集中的图片以及标签信息。
#创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self,root, datatxt, transform=None, target_transform=None): #初始化一些需要传入的参数
super(MyDataset,self).__init__()
#按照传入的路径和txt文本参数,打开这个文本,并读取内容
fh = open(root + datatxt, 'r')
#创建一个名为img的空列表,一会儿用来装东西
imgs = []
#按行循环txt文本中的内容
for line in fh:
# 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
line = line.rstrip()
words = line.split() #通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
imgs.append((words[0],int(words[1]))) #把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定
#word[0]是文件名, word[1]为标签
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
# 很显然,根据我刚才截图所示txt的内容,words[0]是图片信息,words[1]是lable
def __getitem__(self, index):
fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
img = Image.open(root+fn).convert('L') #按照path读入图片from PIL import Image # 按照路径读取图片
if self.transform is not None:
img = self.transform(img) #是否进行transform
return img,label
#return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
return len(self.imgs)
#根据自己定义的那个勒MyDataset来创建数据集!注意是数据集!而不是loader迭代器
train_data=MyDataset(root, datatxt = 'labels.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
至此我们就成功实现了一个属于自己的数据集啦。可以看看效果,这里我读取一下我数据集中的图片:
idx = 2
img = train_data[idx][0].numpy()
plt.imshow(img[0], cmap = 'gray')
plt.axis('off') # 关掉坐标轴为 off
print('label:',train_data[idx][1])#train[][0]为图片信息,train[][1]为label
plt.show()
输出:
我们可以看到,我们已经可以将数据集中的图片以及标签进行输出啦。最后希望大家都能搭建属于自己的数据集,要是有什么问题大家可以多多交流。
下次我们来学习一下基于pytorch的RNN(Recurrent Neural Network)循环神经网络吧。