对于大部分GAN，训练数据没有标签，因此加载数据的情况大致分为以下几种：
1. 所有图片放在一个根目录下（根目录下也可能有若干子文件夹，把图片分散存放在不同的文件夹中）；
2. 原图片整合成一个numpy包（.npy文件）；
3. 原图片整合成一个csv文件。
无论用哪种方式加载数据，最后都要处理成CHW（通道、高、宽）格式，最好再做归一化。
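情况1：图片放在目录中。直接使用PyTorch官方的ImageFolder动态加载即可。只使用ToTensor、不做其他处理时，加载出的数据取值在[0, 1]之间。注意：ImageFolder要求图片放在根目录的子文件夹中（每个子文件夹被当作一个类别），如果图片直接放在根目录下，会报找不到文件的错误。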
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.datasets as ds
import torchvision.transforms as transforms
import numpy as np
import torch
import pathlib
import matplotlib.pyplot as plt
dataroot = r'E:\data\FFHQ\thumbnails128x128'
batch_size = 32

# ImageFolder walks the class sub-folders of `dataroot` and applies the
# transform pipeline to every image as it is loaded. ToTensor alone yields
# float tensors in [0, 1] in CHW layout. Uncomment Resize/CenterCrop to change
# the resolution, and Normalize to map the values from [0, 1] to [-1, 1].
dataset = ds.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        # transforms.Resize(),
        # transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# Create the dataloader (DataLoader is torch.utils.data.DataLoader).
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ImageFolder yields (image, label) pairs, so take element 0 of each batch.
# for i, data in enumerate(dataloader):
#     print(data[0].shape)
# print(next(iter(dataloader))[0].shape)
# plt.imshow(np.transpose(next(iter(dataloader))[0][0].numpy(), (1, 2, 0)))
# plt.show()
也可以自己封装一个加载类（用OpenCV读取图片，输出NCHW格式、取值归一化到[-1, 1]的数组）：
import pathlib
import numpy as np
import cv2
class DataGen():
    """Sample random image batches from a directory of image files.

    Every image is read with OpenCV (3-channel BGR), resized to ``img_shape``,
    converted to RGB and scaled from [0, 255] to [-1, 1]. ``gen`` returns
    float32 batches in NCHW layout.

    Args:
        path: root directory holding the image files.
        img_shape: target ``(height, width, channels)``. OpenCV always decodes
            to 3 channels here, so ``channels`` should be 3.
        once: if True, decode every image up front into one in-memory array
            (fast sampling, high memory use); otherwise images are decoded on
            every call to ``gen``.
        recursive: if True, also collect images from sub-directories of
            ``path``; by default only the top level is scanned.
    """

    # File suffixes (compared case-insensitively) treated as images; anything
    # else found under the root directory is ignored.
    IMAGE_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'})

    def __init__(self, path, img_shape, once=False, *, recursive=False):
        data_root = pathlib.Path(path)
        candidates = data_root.rglob('*') if recursive else data_root.glob('*')
        # glob also yields sub-directories and stray non-image files, which
        # cv2.imread cannot decode; keep only regular files with an image
        # suffix. Sorting makes the file order (and the once=True array)
        # reproducible across runs and platforms.
        self.images_path = sorted(
            p for p in candidates
            if p.is_file() and p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        self.img_shape = img_shape
        self.length = len(self.images_path)
        self.once = once
        if once:
            self.images = self.load_all()

    def load_one(self, path: str):
        """Read one image file and return float32 (H, W, 3) RGB data in [-1, 1].

        Raises:
            OSError: if OpenCV cannot read or decode the file.
        """
        image = cv2.imread(path)
        # imread signals failure by returning None instead of raising; without
        # this check the error surfaces later as a cryptic cv2.resize assertion.
        if image is None:
            raise OSError(f'failed to read image: {path}')
        # cv2 takes the target size as (width, height).
        image = cv2.resize(image, (self.img_shape[1], self.img_shape[0]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return (image.astype(np.float32) - 127.5) / 127.5

    def load_all(self):
        """Decode every image into one float32 array of shape (N, *img_shape)."""
        images = np.zeros((self.length, self.img_shape[0], self.img_shape[1], self.img_shape[2]), np.float32)
        for i, path in enumerate(self.images_path):
            images[i] = self.load_one(str(path))
            if i % 1000 == 0:
                print('loading:', i)
        return images

    def gen(self, batch_size):
        """Return ``batch_size`` random images (sampled with replacement) as float32 NCHW.

        Raises:
            ValueError: if no images were found.
        """
        if self.length == 0:
            raise ValueError('no images found to sample from')
        idx = np.random.randint(0, self.length, batch_size)
        if self.once:
            images = self.images[idx]
        else:
            images = np.array([self.load_one(str(self.images_path[i])) for i in idx])
        # NHWC -> NCHW, the layout PyTorch models expect.
        return np.transpose(images, (0, 3, 1, 2))
# Demo: build the on-the-fly loader and print the shape of one 32-image batch.
dg = DataGen(path=r'E:\Data\anime_girls\images', img_shape=(64, 64, 3), once=False)
print(dg.gen(batch_size=32).shape)
情况2：原图片整合成一个numpy包。这种情况就不适合直接使用PyTorch官方的ImageFolder了，可以自己封装一个类：
import pathlib
import numpy as np
import cv2
class DataGen():
    """Sample random image batches from images stored in a single ``.npy`` file.

    Each sampled image is resized to ``img_shape`` (skipped when it already has
    that size), optionally converted from BGR to RGB, and scaled from
    [0, 255] to [-1, 1]. ``gen`` returns float32 batches in NCHW layout.

    Args:
        path: path of the ``.npy`` file. Presumably it holds an array of shape
            (N, H, W, 3) with values in [0, 255]; TODO confirm per dataset.
        img_shape: target ``(height, width, channels)``.
        bgr: if True (default, the original behavior) stored images are
            treated as BGR and swapped to RGB. Pass False for arrays that
            were saved in RGB order.
        mmap_mode: forwarded to ``np.load``. For example, ``'r'`` memory-maps
            the file so a large dataset is not read into RAM all at once.
    """

    def __init__(self, path, img_shape=(64, 64, 3), *, bgr=True, mmap_mode=None):
        self.img_shape = img_shape
        self.bgr = bgr
        self.images = np.load(path, mmap_mode=mmap_mode)
        self.length = len(self.images)

    def pro(self, image):
        """Turn one stored image into float32 (H, W, C) data in [-1, 1]."""
        height, width = self.img_shape[0], self.img_shape[1]
        # Resizing to the size the image already has is a plain copy in
        # OpenCV, so skip it; the result is identical either way.
        if image.shape[:2] != (height, width):
            # cv2 takes the target size as (width, height).
            image = cv2.resize(image, (width, height))
        if self.bgr:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return (image.astype(np.float32) - 127.5) / 127.5

    def gen(self, batch_size):
        """Return ``batch_size`` random images (sampled with replacement) as float32 NCHW.

        Raises:
            ValueError: if the loaded array holds no images.
        """
        if self.length == 0:
            raise ValueError('the .npy file contains no images')
        idx = np.random.randint(0, self.length, batch_size)
        images = np.array([self.pro(image) for image in self.images[idx]])
        # NHWC -> NCHW, the layout PyTorch models expect.
        return np.transpose(images, (0, 3, 1, 2))
# Demo: load the packed 64x64 .npy dataset and print the shape of one 32-image batch.
dg = DataGen(path=r'E:\Data\anime_girls\numpy\anime_girls_64x64.npy', img_shape=(64, 64, 3))
print(dg.gen(batch_size=32).shape)
待续