Pytorch生成数据集均值和方差

在深度学习训练过程中,我们需要对数据集进行normalize,使数据集中的各数据满足同一分布,更加容易收敛

在pytorch中提供了torchvision.transforms 接口来对数据进行归一化

from torchvision import transforms

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = trans(img)

这里的(0.485, 0.456, 0.406),(0.229, 0.224, 0.225)分别是均值和方差

他由是imagNet数据集计算出的,通常情况下,数据为 (people, buildings, animals, varied lighting/angles/backgrounds)等类型的情况下 均可以使用该均值和方差

我们也可以自己生成归一化的方差和均值:

from torchvision import transforms
from tqdm import tqdm 
from PIL import Image

def calculate(path):
    files = os.listdir(path)
    trans = transforms.ToTensor()
    mean = torch.zeros(3)
    std = torch.zeros(3)

    for file in tqdm(files):
        img = Image.open(path + file)
        img = trans(img)
        for i in range(3): # RGB图像的通道数,因为 transforms.Normalize()是在通道上进行归一化
            mean[i] += img[i, :, :].mean()
            std[i] += img[i, :, :].std()
    mean.div_(files.__len__())
    std.div_(files.__len__())
    print(mean, std)
    return mean, std

利用生成的mean, std赋值torchvision.transforms.Normalize(mean, std)
 

你可能感兴趣的:(python,opencv,深度学习)