1.transforms()函数
import torchvision.transforms as transforms
transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize()]
)
transforms()函数介绍:
函数 |
介绍 |
用法 |
ToTensor() |
将原始的PILImage格式或者numpy格式的数据格式化。 |
输入模式为的PIL Image 或 numpy |
Normalize() |
数据标准化归一化,即均值为0,标准差为1,可以加快模型的收敛 |
output = (input - mean) / std (mean:各通道的均值 std:各通道的标准差 inplace:是否原地操作) |
Resize() |
调整PILImage对象的尺寸 |
例如transforms.Resize([28, 28])就能将输入图片转化成28×28的输入特征图。 |
Grayscale() |
将图像转换成灰色图片 |
num_output_channels=1 是正常的灰图 当为3时,R=G=B |
2.ImageFolder()、DataLoader()函数
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
batch_size = 32
train_data = ImageFolder(root= '地址', transform=transform )
train_loader = DataLoader(dataset=train_data,
batch_size=batch_size,
pin_memory=True, num_workers=8, shuffle=True)
ImageFolder()函数介绍:
参数 |
介绍 |
root |
指定的路径下寻找图片 |
transform |
即上文中操作后的transform |
target_transform |
对label的转换 |
DataLoader()函数介绍:
epoch:所有的训练样本输入到模型中,为一个epoch(样本总数)
iteration:一批样本输入到模型中,成为一个Iteration(迭代次数)
batchszie:决定一个epoch有多少个Iteration(批尺寸)
迭代次数(iteration)=样本总数(epoch)/批尺寸(batchszie)
参数 |
介绍 |
dataset |
读取数据 |
batch_size |
每次训练样本的数量 |
shuffle |
epoch是否为乱序 |
num_workers |
是否多进程读取数据 |
drop_last |
样本不能被整除,余出数据是否舍去 |
pin_memory |
是否使用GPU训练 |