PyTorch GAN:加载数据

 

对于大部分 GAN,由于训练数据没有标签,加载数据的情况大致可分为以下几种:

所有图片放在一个根目录下(根目录下也可能有若干子目录,把图片分散存放在不同的文件夹中);

原图片整合成一个 NumPy 数组文件(.npy);

原图片整合成一个 CSV 文件。

无论采用哪种方式加载,最终都要把数据处理成 CHW(通道、高、宽)格式,并且最好做归一化(例如缩放到 [0, 1] 或 [-1, 1])。

 

放在目录

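图片放在目录中时,直接用 PyTorch(torchvision)官方的 ImageFolder 动态加载即可。只使用 ToTensor、不做其他处理时,加载得到的数据取值范围是 [0, 1];若再加上 Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),则会被映射到 [-1, 1]。注意 ImageFolder 要求图片放在根目录的子文件夹中(每个子文件夹被视为一个类别),直接放在根目录下的图片不会被读取。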

from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.datasets as ds
import torchvision.transforms as transforms
import numpy as np
import torch
import pathlib
import matplotlib.pyplot as plt

dataroot = r'E:\data\FFHQ\thumbnails128x128'
batch_size = 32

# ToTensor turns each image into a float CHW tensor with values in [0, 1].
# Uncomment Normalize to rescale to [-1, 1] instead; uncomment Resize /
# CenterCrop (after defining image_size) to change the spatial size.
transform = transforms.Compose([
    # transforms.Resize(image_size),
    # transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = ds.ImageFolder(root=dataroot, transform=transform)

# Create the dataloader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ImageFolder yields (image, label) pairs; a GAN only needs the image,
# so take element 0 of each batch.
# for i, data in enumerate(dataloader):
#     print(data[0].shape)

# print(next(iter(dataloader))[0].shape)
# plt.imshow(np.transpose(next(iter(dataloader))[0][0].numpy(), (1, 2, 0)))
# plt.show()
 

也可以使用自己封装的类(基于 OpenCV 读取图片,输出归一化到 [-1, 1] 的 NCHW 格式 float32 数组):

import pathlib
import numpy as np
import cv2


class DataGen():
    """Random-batch image sampler backed by a flat directory of image files.

    Every regular file directly under ``path`` (not searched recursively) is
    treated as an image. Images are read with OpenCV, resized to
    ``img_shape``, converted BGR -> RGB and scaled from [0, 255] to [-1, 1].
    Batches are returned as float32 arrays in NCHW layout.

    Args:
        path: Directory containing the image files.
        img_shape: Target ``(height, width, channels)``. ``cv2.imread`` with
            its default flag loads 3-channel colour images, so ``channels``
            should be 3.
        once: If True, decode every image up front and keep them all in
            memory; otherwise decode only the sampled images on each
            ``gen`` call.
    """

    def __init__(self, path, img_shape, once=False):
        data_root = pathlib.Path(path)
        # glob('*') also yields sub-directories (which cv2.imread cannot read),
        # so keep regular files only; sorting gives a stable index -> file map.
        self.images_path = sorted(p for p in data_root.glob('*') if p.is_file())
        self.img_shape = img_shape
        self.length = len(self.images_path)
        self.once = once
        if once:
            self.images = self.load_all()

    def load_one(self, path: str):
        """Load one image as a float32 HWC RGB array with values in [-1, 1].

        Raises:
            ValueError: if OpenCV cannot read/decode the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # without this check cv2.resize fails with an obscure assertion.
            # NOTE(review): on Windows, imread may also fail on non-ASCII paths.
            raise ValueError(f'cannot read image: {path}')
        # cv2.resize takes dsize as (width, height).
        image = cv2.resize(image, (self.img_shape[1], self.img_shape[0]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = (image.astype(np.float32) - 127.5) / 127.5
        return image

    def load_all(self):
        """Decode every image into one (N, H, W, C) float32 array, printing progress."""
        images = np.zeros((self.length, self.img_shape[0], self.img_shape[1], self.img_shape[2]), np.float32)
        for i, path in enumerate(self.images_path):
            images[i] = self.load_one(str(path))
            if i % 1000 == 0:
                print('loading:', i)
        return images

    def gen(self, batch_size):
        """Return ``batch_size`` images sampled with replacement, shape (N, C, H, W).

        Raises:
            ValueError: if the directory contained no files.
        """
        if self.length == 0:
            raise ValueError('no image files found; cannot sample a batch')
        # Indices are drawn independently, so an image may repeat in a batch.
        idx = np.random.randint(0, self.length, batch_size)
        if self.once:
            batch = self.images[idx]
        else:
            batch = np.array([self.load_one(str(self.images_path[i])) for i in idx])
        # NHWC -> NCHW, the layout PyTorch convolution layers expect.
        return np.transpose(batch, (0, 3, 1, 2))


# Smoke test: sample one batch of 32 images and print its NCHW shape.
dg = DataGen(path=r'E:\Data\anime_girls\images', img_shape=(64, 64, 3), once=False)
print(dg.gen(batch_size=32).shape)

 

整合成 NumPy 数组文件(.npy)

这种情况无法直接使用 PyTorch 官方的 ImageFolder 加载,需要自己封装读取逻辑(下面的类假设 .npy 中存放的是 0~255 的 BGR 图片,即 cv2.imread 的输出格式,请按实际数据确认):

import pathlib
import numpy as np
import cv2


class DataGen():
    """Random-batch sampler over an image stack stored in a single ``.npy`` file.

    Args:
        path: Location of the ``.npy`` file; it is loaded eagerly with ``np.load``.
        img_shape: Target ``(height, width, channels)`` for every sampled image.
    """

    def __init__(self, path, img_shape=(64, 64, 3)):
        self.img_shape = img_shape
        self.images = np.load(path)
        self.length = len(self.images)

    def pro(self, image):
        """Resize one image, swap BGR -> RGB and rescale it to float32 in [-1, 1].

        NOTE(review): assumes stored images are BGR with values in 0..255
        (as produced by cv2.imread) - confirm against how the .npy was built.
        """
        height, width = self.img_shape[0], self.img_shape[1]
        resized = cv2.resize(image, (width, height))  # dsize is (width, height)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        return (rgb.astype(np.float32) - 127.5) / 127.5

    def gen(self, batch_size):
        """Return ``batch_size`` processed images sampled with replacement, as NCHW."""
        picks = np.random.randint(0, self.length, batch_size)
        processed = [self.pro(img) for img in self.images[picks]]
        nhwc = np.array(processed)
        return nhwc.transpose(0, 3, 1, 2)


# Smoke test: sample one batch of 32 images and print its NCHW shape.
dg = DataGen(path=r'E:\Data\anime_girls\numpy\anime_girls_64x64.npy', img_shape=(64, 64, 3))
print(dg.gen(batch_size=32).shape)

 

整合成CSV文件

待续

你可能感兴趣的:(pytorch)