torchvision-transforms: Summary of Common Functions

Contents

  • torchvision-transforms: Summary of Common Functions
    • I. Overview: why use transforms
    • II. Function reference
      • 1. ToTensor
      • 2. Normalize
      • 3. Resize (very commonly used)
      • 4. Compose
      • 5. RandomCrop
    • III. Using transforms with datasets


I. Overview: why use transforms

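Preprocessing needs vary, so torchvision provides each conversion as a class: you instantiate the transform you need and then call the instance on an image. See the class descriptions in transforms.py for the details of each class.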

II. Function reference

1. ToTensor

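Purpose: converts an image loaded with PIL.Image.open (a PIL image) or with cv2.imread (a numpy.ndarray) into a tensor. A uint8 image in (H, W, C) layout with values in [0, 255] becomes a float tensor in (C, H, W) layout with values in [0.0, 1.0].
This is the simplest transform: it takes no parameters, so you construct it with the default constructor and call it on the image: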

from torchvision import transforms
from PIL import Image

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    img2tensor = transforms.ToTensor()  # no arguments needed
    img_tensor = img2tensor(img)        # (C, H, W) float tensor with values in [0.0, 1.0]
    print(img_tensor)

2. Normalize

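Purpose: takes the per-channel mean and standard deviation of the three RGB channels (mean first, then std) and normalizes a tensor image channel by channel: output[c] = (input[c] - mean[c]) / std[c]. The input must already be a tensor, so Normalize is applied after ToTensor.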

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    img2tensor = transforms.ToTensor()
    img_tensor = img2tensor(img)
    writer = SummaryWriter("logs")
    writer.add_image("original", img_tensor)
    # mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]: maps values in [0, 1] to [-1, 1]
    trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    img_norm = trans_norm(img_tensor)
    writer.add_image("Normalize", img_norm)

    writer.close()

[Figure: original image]

[Figure: image after normalization]

3. Resize (very commonly used)

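Purpose:
1. Resize([h, w]) scales the image to exactly h × w. The aspect ratio may change, but nothing is cropped, so resizing back to the original size recovers the whole image (up to interpolation loss).
2. Resize(x) scales the shorter side to x and keeps the aspect ratio.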

Note:
The size attribute of a PIL image returns (w, h), but Resize takes (h, w). Don't mix them up.

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    writer = SummaryWriter("logs")
    trans_resize = transforms.Resize((512, 512))
    resized_img = trans_resize(img)
    img2tensor = transforms.ToTensor()
    img_tensor = img2tensor(resized_img)
    writer.add_image("resized", img_tensor)
    writer.close()

[Figure: result after resizing]

4. Compose

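Purpose: chains several transforms into one. Its argument is a list of transforms, [transform1, transform2, …], which are applied in order from first to last.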

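from torchvision import transforms
from PIL import Image

img = Image.open("data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg")
trans_resize = transforms.Resize((512, 512))
img2tensor = transforms.ToTensor()
trans = transforms.Compose([trans_resize, img2tensor])  # Resize first, then ToTensor
img_tensor = trans(img)  # tensor of shape (3, 512, 512)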

5. RandomCrop

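Purpose: crops a region of the given size at a random position in the image. As with Resize, the size is given as (h, w).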

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image

if __name__ == '__main__':
    img_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"
    img = Image.open(img_path)
    writer = SummaryWriter("logs")
    # (h, w) = (300, 400); the image must be at least this large
    trans_random_crop = transforms.RandomCrop((300, 400))
    img2tensor = transforms.ToTensor()
    trans = transforms.Compose([trans_random_crop, img2tensor])
    for i in range(5):
        img_tensor = trans(img)
        writer.add_image("random crop", img_tensor, i)
    writer.close()

III. Using transforms with datasets

Approach: first inspect what a dataset sample contains (for example in a debugger), then add the transforms you need.

import torchvision
from torch.utils.tensorboard import SummaryWriter

# Convert every sample from a PIL image to a tensor when it is loaded.
trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

if __name__ == '__main__':
    # download=True fetches CIFAR10 into ./dataset on the first run.
    train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=trans, download=True)
    test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=trans, download=True)

    print(test_set)
    img, target = test_set[1]  # each sample is an (image, class index) pair
    writer = SummaryWriter("logs")
    writer.add_image("pic1", img, 1)
    print(test_set.classes[target])  # class name for the label
    writer.close()
