【PyTorch】教程:学习基础知识-(4) Transforms

Transforms

在训练机器学习算法中,原始数据并不是以最终的形式出现,我们必须将数据做一定的转换,才能在训练中使用。

所有的 TorchVision datasets 有两个参数: transform 改变特征, target_transform 改变标签。它们接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了常用的开箱即用的常用转换。

FashionMNIST 特性为 PIL Image 格式,标签为整数。为了训练,我们需要将特征作为标准化 tensors ,将标签作为 one-hot 编码的 tensors. 为了进行这些转换,我们使用 ToTensorLambda

import torch 
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root = "../../data/",
    train = True,
    download = True,
    transform = ToTensor(),
    target_transform = Lambda(lambda y: 
        torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor 将 PIL image 或者 NumPy ndarray 转换为 FloatTensor, 并且将像素亮度值缩放到 [0., 1.]

Lambda Transforms

Lambda 转换应用任何用户定义的 Lambda 函数。在这里,我们定义了一个函数来将整数转换为一个 one-hot 编码的 tensor 。它首先创建一个大小为 10 (根据数据集标签数量而定) 的 0值 张量,并调用 scatter_ ,它在标签 y 给出的索引上赋值 1

【参考】

Transforms — PyTorch Tutorials 1.13.1+cu117 documentation

你可能感兴趣的:(PyTorch,pytorch,学习,深度学习)