torchvision包


title: torchvision包
date: 2018-07-09 10:13:47
tags:
- torchvision
categories:
- pytorch
- torchvision


1. 子模块/包

torchvision官网API

  • dataset包

    其下的类都继承自torch.utils.data.Dataset

  • utils

    这里的 torchvison.utils只处理图像,而torch.utils.data有一个重要的class: DataLoader

  • transforms

    进行图像的变换,用于增广数据集

  • models

    直接使用经典的网络结构(也可加载预训练参数)

2. utils

2.1. torchvision.utils.make_grid(tensors)

将tensors合并成tensor。tensor.numpy()为BRG模式的图片。官网API

  • 参数
    • tensors: [ BATCH×C×H×W ]
  • 返回
    • tensor: [C×H×W]
  • 注意
    • 不论传入的图片们(tensors)的通道数C1 or 3, 返回tensor的通道数都是3

3. transforms

3.1. transforms.Compose[list]

核心:对于传入的PIL image每次transform之后,将结果传入到下次transform操作中

def __call__(self, img):
  for t in self.transforms:
      img = t(img)
  return img

3.1.1. PIL图像

  • 输入需是PIL图像,其size是[Width, Height]。因此transforms.Compose(list)是有顺序的
    transforms.Compose([
          transforms.RandomResizedCrop(224), # 输入PIL image
          transforms.RandomHorizontalFlip(), # 输入PIL image
          transforms.ToTensor(), # 放在最后,将PIL image(size是W,H)转换成Tensor(size是C,H,W)
          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
      ])
    
  • PIL图像的转换
    import torchvision.transforms as transforms
    import torch
    import numpy as np
    from PIL import Image
    if __name__=="__main__":
        # 1. 读取成PIL image
        img=Image.open("./test.jpeg")
        print(img.size) # (3840,2160)=>(width, height)
        # 2. PIL转换成ndarray
        np_img=np.array(img)
        print(np_img.shape) # (2160, 3840, 3) => (height,width,c)
        # 3. PIL转换成tensor
        totensor=transforms.ToTensor()
        tensor_img=totensor(img)
        print(tensor_img.size()) # torch.Size([3, 2160, 3840])
        # 4. tensor转换成PIL
        topil=transforms.ToPILImage()
        pil=topil(tensor_img)
        print(pil.size) # (3840, 2160)
    
    • 之所以可以将PIL转换成ndarray,是因为PIL Imagearray_like的,具体见stackoverflow

3.2. transforms.RandomResizedCrop(targetsize)

Crop the given PIL Image to random size and aspect ratio. Then, this crop is finally resized to given size.[选定随机的面积 and 这个面积的纵横比,来裁剪PIL图像。最后将裁剪好的图像resize到高、宽都为targetsize]

3.2.1.核心

def __call__(self, img):
    # (i,j)左上角坐标
    i, j, h, w = self.get_params(img, self.scale, self.ratio)
    # 先对img进行crop,再通过self.interpolation插值成self.size
    return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

3.3. Resize()

  • 参数para
    • 二维sequence like:则缩放到给定的size
    • int: 在等比例缩放的前提下,将最小边缩放到para。 i.e, if height > width, then image will be rescaled to (size * height / width, size)

3.4. ToTensor()

[0, 255]的PIL image或者ndarray(H * W * C)转换成[0.0, 1.0]的Tensor

4. datasets

from torchvision import datasets 构造自己的/已有的数据集

4.1. 公共点

  • datasets模块下所有的类(ImageFolder, mnist等)都继承自torch.util.data.Dataset
    • 因此也常常通过torch.utils.data.DataLoader辅助加载数据
  • 构造函数,都可以传入transform

4.2. ImageFolder

  • 参数
    • root: Root directory path。组织形式如下:
      # root_dir/class_name/*.[png|jpg...]
      root/dog/xxx.png
      root/dog/xxy.png
      root/dog/xxz.png
      
      root/cat/123.png
      root/cat/nsdf3.png        
      root/cat/asd932_.png
      
    • transform: 对于输入图片的变化,可用于数据增广
      • 所有的datasets下面的类,都接受transform
  • 属性
    • classes (list): List of the class names.
    • class_to_idx (dict): Dict with items (class_name, class_index).
    • samples (list): List of (sample path, class_index) tuples

4.3. 辅助类torch.utils.data.DataLoader

可以将Dataset传入

你可能感兴趣的:(torchvision包)