pytorch格式变换


title: Pytorch学习笔记-数据格式变换

学习笔记和实现代码详见如下:

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
"""
ToTensor将PIL image或NumPy ndarray转换为FloatTensor。并将图像像素强度值缩放到[0.,1.)

Lambda转换应用任何用户定义的Lambda函数。
这里,我们定义了一个函数来将整数转换为一个单热编码张量。
它首先创建一个大小为10的零张量(数据集中标签的数量),并调用scatter_,它在标签y给出的索引上赋值为1。
"""

你可能感兴趣的:(Pytorch,pytorch,深度学习,python)