基于pytorch图像数据的导入

1.transforms()函数

import torchvision.transforms as transforms
transforms = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize()]
)

transforms()函数介绍:

函数 介绍 用法
ToTensor() 将原始的PILImage格式或者numpy格式的数据格式化。 输入模式为的PIL Image 或 numpy
Normalize() 数据标准化归一化,即均值为0,标准差为1,可以加快模型的收敛 output = (input - mean) / std (mean:各通道的均值 std:各通道的标准差 inplace:是否原地操作)
Resize() 调整PILImage对象的尺寸 例如transforms.Resize([28, 28])就能将输入图片转化成28×28的输入特征图。
Grayscale() 将图像转换成灰色图片 num_output_channels=1 是正常的灰图 当为3时,R=G=B

2.ImageFolder()、DataLoader()函数

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
batch_size = 32
train_data = ImageFolder(root= '地址', transform=transform )
train_loader = DataLoader(dataset=train_data,
batch_size=batch_size,
pin_memory=True, num_workers=8, shuffle=True)

ImageFolder()函数介绍:

参数 介绍
root 指定的路径下寻找图片
transform 即上文中操作后的transform
target_transform 对label的转换

DataLoader()函数介绍:

epoch:所有的训练样本输入到模型中,为一个epoch(样本总数)
iteration:一批样本输入到模型中,成为一个Iteration(迭代次数)
batchszie:决定一个epoch有多少个Iteration(批尺寸)
迭代次数(iteration)=样本总数(epoch)/批尺寸(batchszie)

参数 介绍
dataset 读取数据
batch_size 每次训练样本的数量
shuffle epoch是否为乱序
num_workers 是否多进程读取数据
drop_last 样本不能被整除,余出数据是否舍去
pin_memory 是否使用GPU训练

你可能感兴趣的:(pytorch,python)