torchvision.transforms.ToTensor() and torchvision.transforms.Normalize()

import torchvision.transforms as transforms
transforms.ToTensor()
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

torchvision.transforms.ToTensor()

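Images are usually loaded with pixel values in the range [0, 255]. ToTensor() converts a PIL Image or a uint8 numpy.ndarray of shape (H x W x C) into a torch.FloatTensor of shape (C x H x W) and rescales the values from [0, 255] to [0.0, 1.0]. Other inputs are converted to a tensor without rescaling.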

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

torchvision.transforms.Normalize()

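With mean=(0.5, 0.5, 0.5) and std=(0.5, 0.5, 0.5), Normalize() maps the range [0, 1] to [-1, 1]. It computes value_n = (value - mean) / std independently for each channel, so the former minimum 0 becomes (0 - 0.5) / 0.5 = -1 and the former maximum 1 becomes (1 - 0.5) / 0.5 = 1. Other mean and std values produce a different output range.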

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

Writing torchvision.transforms.Normalize() for different numbers of channels

3-channel form (channels=3)

  • Method 1 (square brackets [])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
  • Method 2 (parentheses ())
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
  • Method 3 (omit the keywords mean and std)
# with lists
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# with tuples
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

1-channel form (channels=1)

  • Method 1 (square brackets [])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
  • Method 2 (parentheses (); note the trailing comma, which makes (0.5,) a one-element tuple)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
  • Method 3 (omit the keywords mean and std)
# with lists
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# with tuples
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

