计算pytorch标准化(Normalize)所需要数据集的均值和方差

先说明一下情况

1、如果是自己的数据集,mean 和 std 肯定要在normalize之前自己先算好再传进去的

2、有两种情况:
a)数据集在加载的时候就已经转换成了[0, 1].
b)应用了torchvision.transforms.ToTensor,其作用是
( Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] )

3、[0.485, 0.456, 0.406]这一组平均值一般是抽样算出来的。

如何计算数据集的mean和std

数据存放路劲:

计算pytorch标准化(Normalize)所需要数据集的均值和方差_第1张图片

import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"eyesclosed": 0, "lookingarroud": 1, "safedriving": 2, "smoking": 3, "yawning": 4}


class MyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"eyesclosed": 0, "lookingarroud": 1, "safedriving": 2, "smoking": 3, "yawning": 4}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
from my_dataset import MyDataset
import os


train_dir = os.path.join('.', "train")

train_transform = transforms.Compose([
    transforms.Resize((32, 32)), # 可以改成你图片近似大小或者模型要求大小
    transforms.ToTensor(),
])

train_data = MyDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=500, shuffle=True) # # 500张图片的mean std
train = iter(train_loader).next()[0]  # 500张图片的mean std
train_mean = np.mean(train.numpy(), axis=(0, 2, 3))
train_std = np.std(train.numpy(), axis=(0, 2, 3))


print(train_mean, train_std)

你可能感兴趣的:(计算pytorch标准化(Normalize)所需要数据集的均值和方差)