pytorch应用(一)如何制作数据集

 

 

1.下载fashion-mnist数据集

因为是二进制文件,所以需要自己转换成图片、txt标签

#调用一些和操作系统相关的函数
import os
#输入输出相关
from skimage import io
#dataset相关
import torchvision.datasets.mnist as mnist

#路径
root="/home/s/PycharmProjects/untitled/fashion-mnist/data/fashion"
#读取二进制文件,这里不知道是不是必须使用mnist读
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),#路径拼接,split()是分割路径与文件名,和这个正好相反
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
        )
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
        )

#打印test_set类型
print(type(test_set))
>>>out:
#打印test_set中元素个数
print(len(test_set))
>>>out:2
#打印第元素类型,都是tensor
print(type(test_set[0]))
print(type(test_set[1]))
>>>out:
>>>out:
#打印元素形状,可以第一个元素是所有照片的tensor,第二个元素是所有标签的tensor.这里用test_set[0].shape是一样的
print("test set[0] :",test_set[0].size())
print("test set[1] :",test_set[1].size())
>>>out:('test set[0] :', (10000, 28, 28))
>>>out:('test set[1] :', (10000,))
#取出一个图片看一下,这两种都可以,就是看一下这个tensor的形状
a = test_set[0]
print(a[0].shape)
print(test_set[0][0].shape)
>>>out:(28, 28)
>>>out:(28, 28)

#定义一个tensor转图片的函数
def convert_to_img(train=True):
    if(train):
        #创建一个train.txt文件,用来保存标签
        f=open(root+'train.txt','w')#python中并没有这种路径表示方式,这个不对
        data_path=root+'/train/'
        #如果不存在这个路径,就创建文件夹
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        #zip打包成元组,train_set本来不就是元组么?
        for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
            img_path=data_path+str(i)+'.jpg'
            #tensor与numpy格式转换tensor_img = torch.from_numpy(numpy_img)
            io.imsave(img_path,img.numpy())
            a=str(label)
            a = a.rstrip(')')
            a = a.strip('tensor(')#这里如果不进行字符串的处理,会输出“tensor(9)”而不是“9”
            f.write(img_path+' '+ a +'\n')
        f.close()
    else:
        f = open(root + 'test.txt', 'w')
        data_path = root + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
            img_path = data_path+ str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            a=str(label)
            a = a.rstrip(')')
            a = a.strip('tensor(')
            f.write(img_path + ' ' + a + '\n')
        f.close()

convert_to_img(True)
convert_to_img(False)

2.用图片、txt标签制作torch.dataset格式数据集

主要是以torch.utils.data.Dataset为基类进行编写:

__init__

__getitem__

__len__

这几个函数也都是大同小异,可以添加一些自己需要的返回值

import torch
from torch.utils.data import Dataset
from PIL import Image

#以torch.utils.data.Dataset为基类创建MyDataset
class MyDataset(Dataset):
    #stpe1:初始化
    def __init__(self, txt, transform=None, target_transform=None,):
        fh = open(txt, 'r')#打开标签文件
        imgs = []#创建列表,装东西
        for line in fh:#遍历标签文件每行
            line = line.rstrip()#删除字符串末尾的空格
            words = line.split()#通过空格分割字符串,变成列表
            imgs.append((words[0],int(words[1])))#把图片名words[0],标签int(words[1])放到imgs里
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):#检索函数
        fn, label = self.imgs[index]#读取文件名、标签
        img = Image.open(fn).convert('RGB')#通过PIL.Image读取图片
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

官方文档里关于torch.utils.data.Dataset这个基类的说明

pytorch应用(一)如何制作数据集_第1张图片

用到的一些函数

1)Python rstrip() 删除 string 字符串末尾的指定字符(默认为空格)从后边删

str.rstrip([chars])

str = "     this is string example....wow!!!     ";
print str.rstrip();
str = "88888888this is string example....wow!!!8888888";
print str.rstrip('8');

out:
>>>     this is string example....wow!!!
>>>88888888this is string example....wow!!!

2)Python strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。两头删

str.strip([chars])

str = "00000003210Runoob01230000000"; 
print str.strip( '0' ); 
str2 = "   Runoob      "; 
print str2.strip();

out:
3210Runoob0123
Runoob

3)split() 通过指定分隔符对字符串进行切片,如果参数 num 有指定值,则分隔 num+1 个子字符串。变成了列表

str = "Line1-abcdef \nLine2-abc \nLine4-abcd";
print str.split( );       # 以空格为分隔符,包含 \n
print str.split(' ', 1 ); # 以空格为分隔符,分隔成两个

out:
>>>['Line1-abcdef', 'Line2-abc', 'Line4-abcd']
>>>['Line1-abcdef', '\nLine2-abc \nLine4-abcd']

3.加载自己的dataset并使用

1.对图像进行预处理

from torchvision import transforms as transforms

trans_form = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])

通过 torchvision.transforms模块对数据进行预处理

 

torchvision.transforms.Compose(transforms)可以把许多transforms合在一起

transforms.ToTensor()是必须做的,这里也可以不用官方给的,自己写data_tf函数

具体有那些transforms:https://pytorch.org/docs/stable/torchvision/transforms.html?highlight=transforms

2.dataset、dataloader加载

train_data=MyDataset(txt='train.txt', transform=trans_form)
test_data=MyDataset(txt='test.txt', transform=trans_form)

train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)

dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它

迭代器返回的是一个list[imgs,labels]

imgs是64张28X28X3的图片组成的一个tensor

labels是64个标签的tensor

3.划分训练接、验证集、测试集

https://cloud.tencent.com/developer/article/1435013

train_data=MyDataset(txt='train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt='test.txt', transform=transforms.ToTensor())
print('train:', len(train_data), 'test:', len(test_data))

train_data, val_data = torch.utils.data.random_split(train_data, [55000, 5000])

print('train:', len(train_data), 'validation:', len(val_data))

>>>('train:', 60000, 'test:', 10000)
>>>('train:', 55000, 'validation:', 5000)

pytorch 0.4.1版本以上才支持random_split函数 

参考链接:

https://www.runoob.com/python/att-string-split.html

https://www.cnblogs.com/denny402/p/7520063.html

https://blog.csdn.net/Teeyohuang/article/details/79587125

https://pytorch.org/docs/stable/data.html

https://blog.csdn.net/TH_NUM/article/details/80877687

你可能感兴趣的:(机器学习,pytorch)