在训练经典的数据集如cifar10,minsit等,可以用官方自带的数据集格式几行就写出来,如果是自己下载的数据集,那么我们应该如何用pytorch来读取呢?其实是有模板可以直接仿照着写的。
本次案例采用的是pokeman数据集,并用该数据集进行分类。该数据如下所示:
其中文件夹的名字便是标签。数据集大小划分为:皮卡丘 234、超梦239、杰尼龟223、小火龙 238、妙蛙种子234张图。
在深度学习中一般的流程是:加载数据—>构建模型—>训练和测试。
在pytorch读取数据,采用3个步骤
from torch.utils.data.Dataset
from torch.utils.data import Dataset, DataLoader
class NumberDataset(Dataset): #首先要继承Dataset母类
def __init__(self, training=True): #区分训练和测试
if training:
self.samples = list(range(1, 1001)) #加载数据,一般是存放数据的地址,不然内存爆炸
else:
self.samples = list(range(1001, 15001))
def __len__(self):
return len(self.samples) #
def __getitem__(self, idx): # idx 是位置标号,在len(self.samples) 内,一个一个的读取该位置数据
return self.samples[idx]
小结:1、首先得到所有的数据的地址名字(训练或测试);2、给出数据集长度;3、返回指定位置的数据内容,可以在该数据上进行任何预处理操作。
python代码框架为:
from torch.utils.data import Dataset, DataLoader #自定义的母类,必须的
class Pokemon(Dataset):
def __init__(self): #去读数据路径
super(Pokemon, self).__init__()
pass
def __len__(self): #返回数据长度
pass
def __getitem__(self, idx): #返回当前位置的数据和标签
pass
接下来就是填充每一块函数里面的内容了。
首先需要加载数据和标签,因为标签需要转化成0,1,2,3,4,最好保存为csv文件,下次便可以直接加载csv文件。因此我们需要事先写一个函数保存csv文件,不写也可以,最好是写成csv。
下面这个函数可以单独写成一个文件,也可以放在class Pokemon(Dataset)里面。
def load_csv(self, filename):
if not os.path.exists(os.path.join(self.root, filename)):
#如果没有保存csv文件,那么我们需要写一个csv文件,如果有了直接读取csv文件
images = []
for name in self.name2label.keys():
# 'pokemon\\mewtwo\\00001.png
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# 1167, 'pokemon\\bulbasaur\\00000000.png'
print(len(images), images)
random.shuffle(images)
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images: # 'pokemon\\bulbasaur\\00000000.png'
name = img.split(os.sep)[-2] #从名字就可以读取标签
label = self.name2label[name]
# 'pokemon\\bulbasaur\\00000000.png', 0
writer.writerow([img, label]) #写进csv文件
print('writen into csv file:', filename)
# read from csv file
images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
# 'pokemon\\bulbasaur\\00000000.png', 0
img, label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels)
return images, labels
上面函数可以得到数据地址及其标签,接下来就是初始化,得到数据地址名和标签保存:
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
self.name2label = {} # "sq...":0
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root, name)):
continue
self.name2label[name] = len(self.name2label.keys()) #将英文标签名转化数字0-4
# print(self.name2label)
# image, label
self.images, self.labels = self.load_csv('images.csv') #csv文件存在 直接读取
if mode == 'train': # 60%
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.labels))]
elif mode == 'val': # 20% = 60%->80%
self.images = self.images[int(
0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(
0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else: # 20% = 80%->100%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# idx~[0~len(images)]
# self.images, self.labels
# img: 'pokemon\\bulbasaur\\00000000.png'
# label: 0
img, label = self.images[idx], self.labels[idx]
tf = transforms.Compose([ #常用的数据变换器
lambda x:Image.open(x).convert('RGB'), # string path= > image data
#这里开始读取了数据的内容了
transforms.Resize( #数据预处理部分
(int(self.resize * 1.25), int(self.resize * 1.25))),
transforms.RandomRotation(15),
transforms.CenterCrop(self.resize), #防止旋转后边界出现黑框部分
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img = tf(img)
label = torch.tensor(label) #转化tensor
return img, label #返回当前的数据内容和标签
完成上面的步骤,我们只能得到一个一个数据,且需用迭代器表示,即iter:
db = Pokemon('pokemon', 64, 'train')
x, y = next(iter(db))
print('sample:', x.shape, y.shape, y)
因此还需要DataLoader来加载批量的数据:
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
for x, y in loader: #此时x,y是批量的数据
pass
当我们完成数据集读取部分,可视化也是必须的。我们采用的是visdom来可视化。
import visdom
import time
for x, y in loader:
viz.images(
db.denormalize(x), #因为对原始数据归一化,所以可视化需要返回去,该函数需要自己写下。
nrow=8, #每行显示8张图
win='batch',
opts=dict(title='batch'))
viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
time.sleep(10)
如果visdom连接超时,那么需要:
>python -m visdom.server
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hat = (x-mean)/std
# x = x_hat*std = mean
# x: [c, h, w]
# mean: [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
如果文件结构是二级目录,且代码和文件夹在同一个目录:
那么可以用一行代码来写:
tf = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor(),
])
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)
print(db.class_to_idx)
for x,y in loader:
viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
time.sleep(10)
用ImageFolder即可以写,不过该情况受限,因此不建议。还是用前面的函数自己去定义,方便对数据修改,或者额外引入标签。
接下来就是如何训练了,可参考我写的训练模板:https://blog.csdn.net/lifei1229/article/details/105530012
https://blog.csdn.net/lifei1229/article/details/105527312