ImageFolder是一个非常有用的类,只要数据集按照要求规范文件,就可以很轻松的,得到 文件路径和类型 的元祖。同时加载DataLoader 也非常方便,但是在实际用的时候发现缺少了划分数据集的功能,并且是按照数据也是按照文件夹依次得到的,这对划分数据集上非常不利的。
通过阅读ImageFolder源码并在其基础上继承并添加自己的功能。
import time
import torch
import visdom
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader # DataLoader可以实现一个加载一个batch的功能
import random
def denormalize(x_hat):
"""
将normalize后的图片返回原来正常的图片
:param x_hat:
:return:
"""
# x_hat[channel] = (x[channel] - mean[channel]) / std[channel]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x:[c,h,w] mean:[3]->[3,1,1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
x = x_hat * std + mean
return x
class Pokemon(ImageFolder):
def __init__(self,root,model):
super(Pokemon, self).__init__(root)
# 直接类内定义transforms的compose变换
tf = transforms.Compose([
# 自定义函数
transforms.Resize((280, 280)),
transforms.RandomRotation(15),
# 可能会有边缘没被看到,而出现黑边, 需要进行中心裁剪
transforms.CenterCrop(224),
transforms.ToTensor(),
# 归一化,数值是有imagenet统计得出的,更有普遍性
# 输出不再是0-1之间分布,而是在-1 - 1之间分布
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.transform = tf
# 打乱数据
random.seed(1234)
random.shuffle(self.imgs)
random.seed(1234)
random.shuffle(self.samples)
if model == "train": # 60%
self.imgs = self.imgs[:int(0.6*len(self.imgs))]
self.samples = self.samples[:int(0.6*len(self.samples))]
elif model == "val": # 20%
self.imgs = self.imgs[int(0.6 * len(self.imgs)):int(0.8 * len(self.imgs))]
self.samples = self.samples[int(0.6 * len(self.imgs)):int(0.8 * len(self.imgs))]
else: # 20% = 80% -> 100%
self.imgs = self.imgs[int(0.8 * len(self.imgs)):]
self.samples = self.samples[int(0.8 * len(self.imgs)):]
if __name__ == '__main__':
viz = visdom.Visdom()
# ImageFolder会自动完成文件夹的编码
db = Pokemon(root="pokemon", model='train')
print(db.class_to_idx)
# bath进行加载
loader = DataLoader(db, batch_size=32)
for x, y in loader: # 这里的x,y是一个batch的
viz.images(denormalize(x), nrow=8, win="batch", opts=dict(title="batch"))
print(y)
viz.text(str(y.tolist()), win="label", opts=dict(title="batch-label"))
time.sleep(10)
Setting up a new session...
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
tensor([4, 3, 1, 0, 1, 3, 2, 3, 1, 0, 1, 3, 2, 0, 0, 4, 0, 0, 3, 0, 0, 0, 1, 4,
0, 2, 3, 2, 2, 2, 4, 3])