pytorch学习:准备自己的图片数据

    图片数据一般有两种情况:

    1. 所有图片放在一个文件夹内,另外有一个txt文件显示标签。

    2. 不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

    两种情况,第一种可以自定义Dataset,第二种情况直接调用torchvision.datasets.ImageFolder处理,具体如下:

一、 所有图片均放在一个文件夹内

    以mnist数据集的10000个test为例,先将test集里面的10000图片保存出来,并生着对应的txt标签文件。先在当前目录创建一个空文件夹mnist_test,用于保存10000张图片,接着运行代码:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
    ‘./mnist‘, train=False, download=True
)
print(‘test set:‘, len(mnist_test))

f=open(‘mnist_test.txt‘,‘w‘)
for i,(img,label) in enumerate(mnist_test):
    img_path="./mnist_test/"+str(i)+".jpg"
    io.imsave(img_path,img)
    f.write(img_path+‘ ‘+str(label)+‘\n‘)
f.close()

    如此,图片就保存mnist_test文件夹里面,并在当前目录下生成了一个mnist_test.txt文件,大致如下:

pytorch学习:准备自己的图片数据_第1张图片

    然后就正式开始处理数据:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image


def default_loader(path):
    return Image.open(path).convert(‘RGB‘)


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, ‘r‘)
        imgs = []
        for line in fh:
            line = line.strip(‘\n‘)
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

train_data=MyDataset(txt=‘mnist_test.txt‘, transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))


def show_batch(imgs):
    grid = utils.make_grid(imgs)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title(‘Batch from dataloader‘)


for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(),batch_y.size())
        show_batch(batch_x)
        plt.axis(‘off‘)
        plt.show()

二、 不同类别图片放在不同的文件夹内

    首先依旧是准备数据,以flowers数据集为例,下载地址为:

    http://download.tensorflow.org/example_images/flower_photos.tgz

    一共五类,分别放在5个文件夹中,大致如下图:

pytorch学习:准备自己的图片数据_第2张图片

    路径为d:/flowers/。那么处理数据如下:

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt

img_data = torchvision.datasets.ImageFolder(‘D:/bnu/database/flower‘,
                                            transform=transforms.Compose([
                                                transforms.Scale(256),
                                                transforms.CenterCrop(224),
                                                transforms.ToTensor()])
                                            )

print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))


def show_batch(imgs):
    grid = utils.make_grid(imgs,nrow=5)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title(‘Batch from dataloader‘)


for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(), batch_y.size())

        show_batch(batch_x)
        plt.axis(‘off‘)
        plt.show()
转载链接:http://www.bubuko.com/infodetail-2304938.html


    

你可能感兴趣的:(机器学习,深度学习,图像_Python)