transforms.normalize如何对特定数据集设定标准化参数

在图像分类实验中,经常能看到对数据集进行数据增强操作,其中包括transforms.Normalize(),这个函数的定义如下:

torchvision.transforms.Normalize(mean, std, inplace=False)

功能:针对RGB3个 channel 分别对图像进行标准化

output = ( input - mean ) / std

  • mean: 各通道的均值
  • std: 各通道的标准差
  • inplace: 是否原地操作

通常ImageNet有自己的标准化参数,是通过抽样统计图像的均值方差得到的,那么针对本地特定数据集,如何获取到适合的参数呢?我参考了PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差_紫芝的博客-CSDN博客_pytorch 数据归一化

原文代码有一处错误,需要先把transform设置为transforms.ToTensor(),而不是None,否则会运行错误。以下是改正后的代码:

def getStat(train_data):
    '''
    Compute mean and variance for training data
    :param train_data: 自定义类Dataset(或ImageFolder即可)
    :return: (mean, std)
    '''
    print('Compute mean and variance for training data.')
    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'/data1/sharedata/leafseg/', transform=transforms.ToTensor())
    print(getStat(train_dataset))

Compute mean and variance for training data.
3257
([0.059938803, 0.08676067, 0.041085023], [0.10522498, 0.1488454, 0.07508467])

将结果写入transform列表中即可。 

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(640),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.0599, 0.0868, 0.0411], [0.1052, 0.1488, 0.0751])
    ]),
    'val': transforms.Compose([
        transforms.Resize(640),
        transforms.ToTensor(),
        transforms.Normalize([0.0599, 0.0868, 0.0411], [0.1052, 0.1488, 0.0751])
    ]),
}

你可能感兴趣的:(图像分类,数据增强,数据集,深度学习,人工智能)