# Official dataset download (数据集官方下载): https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
# References (参考资料):
import torch
import os
from PIL import Image
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# Training-set directory, relative to the working directory: ../data/Dog_Cat/train
# Map the file-name prefix (e.g. "cat" in "cat.123.jpg") to an integer class id.
label_mp = {"cat": 0, "dog": 1}


def get_data(data_dir):
    """Collect ``(image_path, label)`` pairs for every labelled file under *data_dir*.

    The label is the part of the file name before the first ``.``
    (``"cat.0.jpg"`` -> ``"cat"``), mapped through ``label_mp``. Files whose
    prefix is not a key of ``label_mp`` (e.g. ``.DS_Store``, ``Thumbs.db``)
    are skipped instead of aborting the whole scan.

    Args:
        data_dir: Root directory to scan; subdirectories are walked as well.

    Returns:
        A list of ``(path, int_label)`` tuples. The order is deterministic:
        directories top-down in sorted order, and files sorted by name
        within each directory.
    """
    data_dc = []
    for root, dirs, files in os.walk(data_dir):
        dirs.sort()  # in-place sort makes os.walk visit subdirectories in a stable order
        for img_name in sorted(files):
            label = label_mp.get(img_name.split('.')[0])
            if label is None:
                continue
            # Join with `root`, not `data_dir`: os.walk also yields files in subdirectories.
            data_dc.append((os.path.join(root, img_name), label))
    return data_dc
# Custom transform: salt-and-pepper noise.
class MyTransform:
    """Add salt-and-pepper noise to a PIL image, for use inside ``transforms.Compose``.

    Each pixel is handled independently:
    - it is kept with probability ``snr``;
    - it is set to 255 (salt) with probability ``(1 - snr) / 2``;
    - it is set to 0 (pepper) with probability ``(1 - snr) / 2``.
    All channels of a selected pixel are changed together.
    """

    def __init__(self, snr=0.9):
        """Store the keep-probability.

        Args:
            snr: Fraction of pixels left untouched; must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``snr`` is outside ``[0, 1]``.
        """
        super().__init__()
        # Validate up front; otherwise the bad value only fails later inside np.random.choice.
        if not 0.0 <= snr <= 1.0:
            raise ValueError(f"snr must be in [0, 1], got {snr!r}")
        self.snr = snr

    def __call__(self, img):
        """Return a noisy copy of *img* converted to RGB; *img* itself is not modified."""
        img_ = np.array(img)  # np.array already copies, so the source image is untouched
        h, w = img_.shape[:2]  # works for both (h, w) and (h, w, c) arrays
        noise_p = (1 - self.snr) / 2
        # One draw per pixel: 0 = keep, 1 = salt (white), 2 = pepper (black).
        mask = np.random.choice((0, 1, 2), size=(h, w), p=[self.snr, noise_p, noise_p])
        # A 2-D boolean mask indexes the pixel axes, so every channel of a chosen pixel is set.
        img_[mask == 1] = 255
        img_[mask == 2] = 0
        return Image.fromarray(img_).convert('RGB')
# Custom Dataset for the Dogs-vs-Cats images.
class DogCatDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from ``get_data(data_dir)``.

    Args:
        data_dir: Directory scanned by ``get_data`` for labelled image files.
        transform: Optional callable applied to each RGB PIL image in ``__getitem__``.
            When ``None``, the PIL image is returned unchanged.
        transfrom: Misspelled keyword-only alias of ``transform``. It is kept so
            existing ``transfrom=...`` callers keep working.

    Raises:
        TypeError: If both ``transform`` and ``transfrom`` are given.
    """

    def __init__(self, data_dir, transform=None, *, transfrom=None):
        super().__init__()
        if transform is not None and transfrom is not None:
            raise TypeError("pass either 'transform' or 'transfrom', not both")
        self.data_dir = data_dir
        self.data_dc = get_data(data_dir)
        # Use the argument passed in, never a module-level name.
        self.transform = transform if transform is not None else transfrom

    def __getitem__(self, index):
        """Load sample *index* as RGB, apply the transform if set, and return ``(img, label)``."""
        img_path, label = self.data_dc[index]
        # convert() returns a new, fully loaded image, so the file handle can be closed here.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of ``(path, label)`` entries collected by ``get_data``."""
        return len(self.data_dc)
# Root of the Dogs-vs-Cats data: <cwd>/../data/Dog_Cat
basic_dir = os.path.join(os.getcwd(), '../', 'data', 'Dog_Cat')

# Preprocessing pipeline:
#   1. resize to 32x32;
#   2. add salt-and-pepper noise (95% of pixels kept);
#   3. convert to a tensor;
#   4. normalise each channel with mean 0.5 and std 0.2.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    MyTransform(snr=0.95),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.2, 0.2, 0.2]),
])

# `transfrom` (sic) is the keyword DogCatDataset accepts for its transform.
train_dataset = DogCatDataset(data_dir=os.path.join(basic_dir, 'train'), transfrom=transform)

# Show the first image as stored on disk (read directly, without the pipeline above).
img_path = train_dataset.data_dc[0][0]
with Image.open(img_path) as src:
    img = src.convert('RGB')
plt.imshow(img)
plt.show()

# DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

# Print the shape and labels of a single batch.
for img, label in train_loader:
    print(img.size())
    print(label)
    break
# Example output of the batch printout above:
# torch.Size([16, 3, 32, 32])
# tensor([1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0])