torchvision源码剖析【1】transforms.ToTensor()

PIL转tensor:

def pil2tensor():
    """Load the image at ``img_path`` and print what ``ToTensor`` produces.

    Reads the global ``img_path``, converts the opened PIL image with
    ``torchvision.transforms.ToTensor`` and prints the resulting tensor's
    size, dtype and Python type. Returns nothing.
    """
    totensor = torchvision.transforms.ToTensor()
    # Use a context manager so the file handle opened by Image.open is
    # released as soon as the pixel data has been converted.
    with Image.open(img_path) as pil_img:
        # ToTensor handles the format conversion automatically.
        img = totensor(pil_img)
    # Kept from the original; a no-op when the tensor is already on the CPU.
    img = img.cpu()
    print(img.size())  # e.g. torch.Size([3, 300, 533])
    print(img.dtype)  # e.g. torch.float32
    print(type(img))  # prints the tensor class, i.e. torch.Tensor


pil2tensor()

源码:

class ToTensor(object):
    """Callable transform turning a ``PIL Image`` or ``numpy.ndarray`` into a tensor.

    The input is a PIL Image or numpy.ndarray laid out as (H x W x C) with
    values in [0, 255]; the output is a torch.FloatTensor laid out as
    (C x H x W) with values in [0.0, 1.0].
    """

    def __call__(self, pic):
        """Convert ``pic`` to a tensor.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # The actual conversion lives in the functional API; this class
        # only forwards to it.
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self):
        """Return the class name followed by an empty argument list."""
        return '{}()'.format(type(self).__name__)

在Functional中的源码:

def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
            Handles PIL images as well as OpenCV-style ndarrays; a 2-D
            (H x W) ndarray is treated as a single-channel image.

    Returns:
        Tensor: Converted image in C x H x W layout. Byte (uint8) data is
        converted to float and divided by 255; other dtypes are returned
        without rescaling.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor a numpy image.
    """
    # Type check.
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # A grayscale H x W array has no channel axis; add one so the
            # three-axis transpose below does not fail.
            pic = pic[:, :, None]
        # ndarray images are stored H x W x C, so move the channels first.
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)  # normalize to [0.0, 1.0]
        return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    # Pick the numpy dtype that matches the PIL mode.
    # copy=True is used throughout: with NumPy >= 2.0, copy=False raises
    # ValueError whenever a copy cannot be avoided.
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=True))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=True))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=True))
    elif pic.mode == '1':
        # Multiplied by 255 here and divided by 255 again on the uint8
        # path at the end of the function.
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=True))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        # For the remaining modes the mode string length is used as the
        # channel count (e.g. 'RGB' -> 3, 'L' -> 1).
        nchannel = len(pic.mode)
    # pic.size is (width, height), hence the swapped indices.
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.permute(2, 0, 1).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    return img

你可能感兴趣的:(pytorch)