pytorch自定义数据集

使用函数

torch.utils.data.Dataset
torch.utils.data.DataLoader

数据准备

以猫狗为例实现分类
按照如下图所示建立文件和文件夹,我这里自己准备了20张猫狗图像。
pytorch自定义数据集_第1张图片
test.txt文件是后面代码生成的,先不用管,cats和dogs里面放上自己的图片,然后通过脚本生成test.txt文件,text.txt的脚本 代码如下:

#!/usr/bin/python
# -*- coding:utf-8 -*-
import os
def generate(dir, label):
    files = os.listdir(dir)
    files.sort()
    listText = open('data/test/test.txt', 'a')
    for file in files:
        fileType = os.path.split(file)
        if fileType[1] == '.txt':
            continue
        name = "/test/cats/" + file + ' ' + str(int(label)) + '\n'
        print(name)
        listText.write(name)
    listText.close()
def generate1(dir, label):
    files = os.listdir(dir)
    files.sort()
    listText = open('data/test/test.txt', 'a')
    for file in files:
        fileType = os.path.split(file)
        if fileType[1] == '.txt':
            continue
        name = "/test/dogs/" + file + ' ' + str(int(label)) + '\n'
        print(name)
        listText.write(name)
    listText.close()

outer_path = 'data/test'  # 这里是你的图片的目录

if __name__ == '__main__':
    i = 0
    folderlist = os.listdir(outer_path)  # 列举文件夹
    for folder in folderlist:
        if i == 0:
            generate(os.path.join(outer_path, folder), i)
        if i == 1:
            generate1(os.path.join(outer_path, folder), i)
        i += 1

由于就两个文件,此处就直接用两个相同的代码生成(鄙人代码功底不好,凑合着看),生成后的text.txt文件如下样式:前面是路径,后面是对应的标签。
pytorch自定义数据集_第2张图片
到这里,样本集的收集以及简单归类已经完成啦,下面我们将开始采用pytorch的数据集相关API和类,也就是我们以后要经常用到的dataset和dataloader。
Dataset类的使用: 是一个抽象类,所有的类都应该是此类的子类(也就是说应该继承该类)。 所有的子类都要重写 len(), getitem() 这两个魔术方法(魔术方法在执行程序的时候会自动执行)。
len() 此方法应该提供数据集的大小(容量)
getitem() 此方法应该提供支持下标索方式引访问数据集
dataloader:对dataset获取的数据可以进行打包,打乱,变换操作。
简言之:dataset是获取数据,dataloader是对获取的数据进行变换等操作。

代码实现:

定义mydataset:

class MyDataset(Dataset):

    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []

        if not os.path.isfile(self.names_file):
            print(self.names_file + 'i does not exist!')
        file = open(self.names_file)
        for f in file:
            self.names_list.append(f)
            self.size += 1

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        image_path = self.root_dir + self.names_list[idx].split(' ')[0]
        print(image_path)
        # image_path = self.names_list[idx].split(' ')[0]
        if not os.path.isfile(image_path):
            print(image_path + 'you does not exist!')
            return None
        image = io.imread(image_path)   # use skitimage
        label = int(self.names_list[idx].split(' ')[1])

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)

        return sample

定义一个实例化对象:

train_dataset = MyDataset(root_dir='./data/train',
                          names_file='./data/train/train.txt',
                          transform=None)

plt.figure()
for (cnt,i) in enumerate(train_dataset):
    image = i['image']
    label = i['label']

    ax = plt.subplot(4, 4, cnt+1)
    ax.axis('off')
    ax.imshow(image)
    ax.set_title('label {}'.format(label))
    plt.pause(0.001)

    if cnt == 15:
        break

pytorch自定义数据集_第3张图片
注意修改一下自己的路径。
以上并没有用到dataloader这个类,下面使用dataloader对dataset得到的数据集进行变换:
先对图像数据集进行resize和转变为tensor向量:

# 变换Resize
class Resize(object):

    def __init__(self, output_size: tuple):
        self.output_size = output_size

    def __call__(self, sample):
        # 图像
        image = sample['image']
        # 使用skitimage.transform对图像进行缩放
        image_new = transform.resize(image, self.output_size)
        return {'image': image_new, 'label': sample['label']}

# # 变换ToTensor
class ToTensor(object):

    def __call__(self, sample):
        image = sample['image']
        image_new = np.transpose(image, (2, 0, 1))
        return {'image': torch.from_numpy(image_new),
                'label': sample['label']}

然后调用dataloader函数对数据集进行处理:

# 对原始的训练数据集进行变换
transformed_trainset = MyDataset(root_dir='./data',
                          names_file='./data/test/test.txt',
                          transform=transforms.Compose(
                              [Resize((512,512)),
                               ToTensor()]
                          ))

dataloader函数可以完成数据集打乱shuffle,batch,numworks(多线程)
可视化使用dataloader后的代码:

def show_images_batch(sample_batched):
    images_batch, labels_batch = \
    sample_batched['image'], sample_batched['label']
    grid = make_grid(images_batch)
    plt.imshow(grid.numpy().transpose(1, 2, 0))


# sample_batch:  Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
    show_images_batch(sample_batch)
    plt.axis('off')
    plt.ioff()
    plt.show()
plt.show()

结果展示:
pytorch自定义数据集_第4张图片

你可能感兴趣的:(pytorch,机器学习,人工智能,python)