链接:数据集
提取码:onda
pytorch给我们提供了很多已经封装好的数据集,但是我们经常得使用自己找到的数据集,因此,想要得到一个好的训练结果,合理的数据处理是必不可少的。我们以1400张猫狗图片来进行分析。
训练集包含500张狗的图片以及500张猫的图片,测试接包含200张狗的图片以及200张猫的图片。
def init_process(path, lens):
data = []
name = find_label(path)
for i in range(lens[0], lens[1]):
data.append([path % i, name])
return data
现有数据的命名都是有序号的,训练集中数据编号为0-499,测试集中编号为1000-1200,因此我们可以根据这个规律来读取文件名,比如参数传入:
path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])
data1就是一个包含五百个文件名以及标签的列表。find_label来判断标签是dog还是cat:
def find_label(str):
first, last = 0, 0
for i in range(len(str) - 1, -1, -1):
if str[i] == '%' and str[i - 1] == '.':
last = i - 1
if (str[i] == 'c' or str[i] == 'd') and str[i - 1] == '/':
first = i
break
name = str[first:last]
if name == 'dog':
return 1
else:
return 0
dog返回1,cat返回0。
有了上面两个函数之后,我们经过四次操作,就可以得到四个列表:
path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])
path2 = 'cnn_data/data/training_data/dogs/dog.%d.jpg'
data2 = init_process(path2, [0, 500])
path3 = 'cnn_data/data/testing_data/cats/cat.%d.jpg'
data3 = init_process(path3, [1000, 1200])
path4 = 'cnn_data/data/testing_data/dogs/dog.%d.jpg'
data4 = init_process(path4, [1000, 1200])
随便输出一个列表的前五个:
[['cnn_data/data/testing_data/dogs/dog.1000.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1001.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1002.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1003.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1004.jpg', 1]]
def Myloader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, data, transform, loder):
self.data = data
self.transform = transform
self.loader = loder
def __getitem__(self, item):
img, label = self.data[item]
img = self.loader(img)
img = self.transform(img)
return img, label
def __len__(self):
return len(self.data)
里面有2个比较重要的函数:
transform = transforms.Compose([
transforms.CenterCrop(224),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # 归一化
])
对上面四个操作做一些解释:
1)、transforms.CenterCrop(224)
,从图像中心开始裁剪图像,224为裁剪大小
2)、transforms.Resize((224, 224))
,重新定义图像大小
3)、transforms.ToTensor()
,很重要的一步,将图像数据转为Tensor
4)、transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
,归一化
因此我们只需要:
train_data = data1 + data2 + data3[0:150] + data4[0:150]
train = MyDataset(train_data, transform=transform, loder=Myloader)
test_data = data3[150:200] + data4[150:200]
test= MyDataset(test_data, transform=transform, loder=Myloader)
就可以得到处理好的Dataset,其中训练集我给了1300张图片,测试集只给了100张。
train_data = DataLoader(dataset=train, batch_size=5, shuffle=True, num_workers=0, pin_memory=True)
test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
最后我们只要给定义好的神经网络模型喂数据就OK了!!!
完整代码:
# -*- coding: utf-8 -*-
"""
@Time : 2020/8/18 9:11
@Author :KI
@File :CNN.py
@Motto:Hungry And Humble
"""
import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
def Myloader(path):
return Image.open(path).convert('RGB')
#得到一个包含路径与标签的列表
def init_process(path, lens):
data = []
name = find_label(path)
for i in range(lens[0], lens[1]):
data.append([path % i, name])
return data
class MyDataset(Dataset):
def __init__(self, data, transform, loder):
self.data = data
self.transform = transform
self.loader = loder
def __getitem__(self, item):
img, label = self.data[item]
img = self.loader(img)
img = self.transform(img)
return img, label
def __len__(self):
return len(self.data)
def find_label(str):
first, last = 0, 0
for i in range(len(str) - 1, -1, -1):
if str[i] == '%' and str[i - 1] == '.':
last = i - 1
if (str[i] == 'c' or str[i] == 'd') and str[i - 1] == '/':
first = i
break
name = str[first:last]
if name == 'dog':
return 1
else:
return 0
def load_data():
transform = transforms.Compose([
#transforms.RandomHorizontalFlip(p=0.5),
#transforms.RandomVerticalFlip(p=0.5),
transforms.CenterCrop(224),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # 归一化
])
path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])
path2 = 'cnn_data/data/training_data/dogs/dog.%d.jpg'
data2 = init_process(path2, [0, 500])
path3 = 'cnn_data/data/testing_data/cats/cat.%d.jpg'
data3 = init_process(path3, [1000, 1200])
path4 = 'cnn_data/data/testing_data/dogs/dog.%d.jpg'
data4 = init_process(path4, [1000, 1200])
train_data = data1 + data2 + data3[0:150] + data4[0:150]
train = MyDataset(train_data, transform=transform, loder=Myloader)
test_data = data3[150:200] + data4[150:200]
test= MyDataset(test_data, transform=transform, loder=Myloader)
train_data = DataLoader(dataset=train, batch_size=5, shuffle=True, num_workers=0, pin_memory=True)
test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
return train_data, test_data
train_data以及test_data就是我们最终需要得到的数据。
对猫狗数据分类的具体实现请见:CNN简单实战:pytorch搭建CNN对猫狗图片进行分类