【Pytorch】transforms.Compose,torchvision.datasets.ImageFolder,torch.utils.data.DataLoader的用法

pytorch通过深度学习进行预处理图片,离不开transforms.Compose(),torchvision.datasets.ImageFolder(),torch.utils.data.DataLoader()的用法。

本篇通过实例解读这三个函数的用法。

1.Transform.Compose()详解

导入相应的库

import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils import data
from torchvision import datasets,transforms
from PIL import Image
%matplotlib inline

展示原始图片

pic = "./train/Chihuahua/n02085620_10074.jpg"

img = plt.imread(pic)
plt.imshow(img)

【Pytorch】transforms.Compose,torchvision.datasets.ImageFolder,torch.utils.data.DataLoader的用法_第1张图片
定义图片预处理的对象。

traintransform = transforms.Compose([transforms.RandomRotation(20),           # 随机旋转20°
                                     transforms.ColorJitter(brightness=0.1), #随机改变图像的亮度对比度和饱和度
                                     transforms.Resize([150,150]),          # 转换为需要的尺寸
                                     transforms.ToTensor(),                #convert a PIL image to tensor (H*W*C)
                                    ])
img1 = Image.fromarray(img)   #将numpy对象的img转换为PIL格式
img2 = traintransform(img1)# 图像预处理tensor
img3 = transforms.ToPILImage()(img2)#转换为PIL进行展示
plt.imshow(img3)

展示处理之后的图片,可以看出,图片旋转了20°,并且大小转换为(150,150)
【Pytorch】transforms.Compose,torchvision.datasets.ImageFolder,torch.utils.data.DataLoader的用法_第2张图片

附上——transforms中的函数如何使用?
# Resize:把给定的图片resize到given size
# Normalize:Normalized an tensor image with mean and standard deviation
# ToTensor:convert a PIL image to tensor (H*W*C) in range [0,255] to a torch.Tensor(C*H*W) in the range [0.0,1.0]
# ToPILImage: convert a tensor to PIL image
# Scale:目前已经不用了,推荐用Resize
# CenterCrop:在图片的中间区域进行裁剪
# RandomCrop:在一个随机的位置进行裁剪
# RandomHorizontalFlip:以0.5的概率水平翻转给定的PIL图像
# RandomVerticalFlip:以0.5的概率竖直翻转给定的PIL图像
# RandomResizedCrop:将PIL图像裁剪成任意大小和纵横比
# Grayscale:将图像转换为灰度图像
# RandomGrayscale:将图像以一定的概率转换为灰度图像
# FiceCrop:把图像裁剪为四个角和一个中心
# TenCrop
# Pad:填充
# ColorJitter:随机改变图像的亮度对比度和饱和度。

2.Torchvision.datasets.ImageFolder()详解

trainpath = "./train"      #数据库路径
batch_size = 64

traintransform = transforms.Compose([transforms.RandomRotation(20),           # 随机旋转20°
                                     transforms.ColorJitter(brightness=0.1), #随机改变图像的亮度对比度和饱和度
                                     transforms.Resize([150,150]),          # 转换为需要的尺寸
                                     transforms.ToTensor(),                #convert a PIL image to tensor (H*W*C) 
                                    ])
trainData = torchvision.datasets.ImageFolder(trainpath,transform=traintransform)

torchvision.datasets.ImageFolder 有 root, transform, target_transform, loader四个参数,root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是’./train/’。
在这里插入图片描述
在这里插入图片描述

文件格式如以下:
train/Chihuahua/xxx.png
train/Chihuahua/xxy.png
train/Chihuahua/xxz.png

train/Japanese_spaniel/123.png
train/Japanese_spaniel/nsdf3.png
train/Japanese_spaniel/asd932_.png

transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…loader:表示数据集加载方式,通常默认加载方式即可。

我们得到的trainData,它的结构就是[(img_data,class_id),(img_data,class_id),…],下面我们打印第一个元素:
trainData[0]

【Pytorch】transforms.Compose,torchvision.datasets.ImageFolder,torch.utils.data.DataLoader的用法_第3张图片

3.Torch.utils.data.DataLoader()详解

参数说明:

torch.utils.data.DataLoader(
      dataset   			#数据加载
      batch_size = 1		#批处理样本大小
      shuffle = False		#是否在每一轮epoch打乱样本顺序
      sampler = None		#指定数据加载中使用的索引/键的序列
      batch_sampler = None	#和sampler类似
      num_workers = 0		#是否进行多进程加载数据设置
      collat​​e_fn = None		#是否合并样本列表以形成一小批Tensor
      pin_memory = False	#如果True,数据加载器会在返回之前将Tensors复制到CUDA固定内存
      drop_last = False		#True若数据集大小不能被batch_size整除,则删除最后一个不完整的批处理。
      timeout = 0			#如果为正,则为从工作人员收集批处理的超时值
      worker_init_fn = None )
trainLoader = torch.utils.data.DataLoader(dataset=trainData,batch_size=batch_size,shuffle=True)
加载完之后,trainloader究竟是个什么玩意,可以借助以下的方式打开查看。
for i in trainLoader:
    print(i[0].shape,i[1].shape)

在这里插入图片描述
以上,我们加载了152+185=337张图片。通过以上可以看出,trainLoader每次加载64张照片+64个标签值,并且总和为337=64*5+17.
至此结束,希望对大家有用。

你可能感兴趣的:(pytorch,python)