PyTorch学习笔记(5)——Kaggle猫狗分类问题数据集读取和构建

Kaggle猫狗分类问题数据集处理

数据集官方下载:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

参考资料:

  • https://github.com/ytchx1999/Pytorch-Camp
  • https://github.com/greebear/pytorch-learning
import torch
import os
from PIL import Image
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# Dataset (training split) directory:
# ../data/Dog_Cat/train
# Class name (filename prefix, e.g. "cat.0.jpg") -> integer label.
label_mp = {"cat": 0, "dog": 1}


# Collect (image path, label) pairs from a directory tree
def get_data(data_dir):
    """Return a list of ``(image_path, label)`` tuples found under *data_dir*.

    Every file below *data_dir* is visited recursively via ``os.walk``.
    Each file's label is looked up in ``label_mp`` using the filename text
    before the first ``.`` (``"cat.12.jpg"`` -> ``"cat"`` -> 0).
    Files and subdirectories are visited in sorted order, so the returned
    list (and thus dataset indices) is deterministic.

    Raises:
        KeyError: if a filename prefix is not a key of ``label_mp``.
    """
    data_dc = []
    for root, dirs, files in os.walk(data_dir):
        # Sorting ``dirs`` in place controls the order os.walk descends into
        # subdirectories; sorting ``files`` fixes the order within each one.
        dirs.sort()
        for img_name in sorted(files):
            # Join with ``root`` (the directory currently being walked), not
            # ``data_dir``, so files inside subdirectories get correct paths.
            img_path = os.path.join(root, img_name)
            label = label_mp[img_name.split('.')[0]]
            data_dc.append((img_path, int(label)))
    return data_dc
# Custom transform: salt-and-pepper noise
class MyTransform(object):
    """Add salt-and-pepper noise to a PIL image.

    Each pixel is independently kept with probability ``snr``. Otherwise it is
    set to white (255) or black (0), each with probability ``(1 - snr) / 2``.
    All channels of a selected pixel are changed together.

    Args:
        snr: probability that a pixel is left untouched; must be in [0, 1].

    Raises:
        ValueError: if ``snr`` is outside [0, 1].
    """

    def __init__(self, snr=0.9):
        super().__init__()
        # Validate early: an out-of-range snr would otherwise only fail later,
        # inside np.random.choice, with a less obvious "probabilities" error.
        if not 0.0 <= snr <= 1.0:
            raise ValueError("snr must be in [0, 1], got {!r}".format(snr))
        self.snr = snr

    def __call__(self, img):
        """Return a noisy RGB copy of *img*; the input image is not modified."""
        # np.array copies the pixel data, so writes below never touch *img*.
        img_ = np.array(img)
        # One mask value per pixel (not per channel): 0 = keep, 1 = salt (255),
        # 2 = pepper (0). Using shape[:2] also accepts 2-D (grayscale) arrays.
        h, w = img_.shape[:2]
        noise_p = (1 - self.snr) / 2
        mask = np.random.choice((0, 1, 2), size=(h, w), p=[self.snr, noise_p, noise_p])
        # A 2-D boolean mask indexes the first two axes, so every channel of
        # each selected pixel is overwritten at once.
        img_[mask == 1] = 255
        img_[mask == 2] = 0
        return Image.fromarray(img_).convert('RGB')
# Custom Dataset for the Kaggle Dogs vs. Cats training images
class DogCatDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Samples come from ``get_data(data_dir)``, a list of ``(path, label)``
    tuples. Images are loaded as RGB ``PIL.Image`` objects and optionally
    passed through a transform.

    Args:
        data_dir: directory scanned by ``get_data``.
        transfrom: legacy (misspelled) name for ``transform``. It is kept so
            existing callers passing ``transfrom=...`` keep working.
        transform: optional callable applied to each loaded image. It takes
            precedence over ``transfrom`` when both are given.
    """

    def __init__(self, data_dir, transfrom=None, transform=None):
        super().__init__()
        self.data_dir = data_dir
        self.data_dc = get_data(data_dir)
        # Use only the arguments actually passed. ``None`` means "no transform",
        # never a module-level ``transform`` that happens to exist.
        self.transform = transform if transform is not None else transfrom

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample *index*."""
        img_path, label = self.data_dc[index]
        # The context manager closes the file handle. convert() loads the
        # pixels into a new image that no longer depends on the open file.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_dc)
# Dataset root: <cwd>/../data/Dog_Cat
basic_dir = os.path.join(os.getcwd(), '../', 'data', 'Dog_Cat')
# Preprocessing: resize to 32x32, add salt-and-pepper noise, convert to a
# tensor, then normalize each channel with mean 0.5 and std 0.2.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    MyTransform(snr=0.95),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.2, 0.2, 0.2]),
])
train_dataset = DogCatDataset(data_dir=os.path.join(basic_dir, 'train'), transfrom=transform)
# Preview the first training image as stored on disk (no transform applied).
img_path = train_dataset.data_dc[0][0]
# The context manager closes the file handle; convert() returns a loaded copy.
with Image.open(img_path) as im:
    img = im.convert('RGB')
plt.imshow(img)
plt.show()

(图1:训练集中第一张图片的显示效果)

# DataLoader over the training set: batches of 16, reshuffled every epoch
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)
# Fetch only the first batch and print its tensor shape and labels
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    img, label = first_batch
    print(img.size())
    print(label)
torch.Size([16, 3, 32, 32])
tensor([1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0])

你可能感兴趣的:(Pytorch学习笔记,pytorch,python,深度学习,kaggle)