pytorch快速入门(九)线性层Linear(), torch.flatten()

1、线性层linear

1、官方简介

pytorch快速入门(九)线性层Linear(), torch.flatten()_第1张图片
简单解释:下图中的input layer为in_features, hidden layer为out_features, 经过线性层将x转换为g
pytorch快速入门(九)线性层Linear(), torch.flatten()_第2张图片

2、代码

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

#导入数据集
datasets = torchvision.datasets.CIFAR10("E:/PycharmProjects/Pytoch_learning/dataset/CIFAR10", train=False,transform=torchvision.transforms.ToTensor())
#加载数据集
dataloader = DataLoader(datasets,batch_size=64)

#构建模型
class Tian(nn.Module):
    def __init__(self):
        super(Tian, self).__init__()
        #in_features=196608, out_features=10
        self.linear = Linear(196608, 10)

    def forward(self, input):
        output = self.linear(input)
        return output

#应用模型
ren = Tian()

for data in dataloader:
    imgs,tagets = data
    print(imgs.shape)#输出([64,3,32,32])
    #改变尺寸
    # output = torch.reshape(imgs, (1, 1, 1, -1))
    #使用flatten
    output = torch.flatten(imgs)
    # print("After reshape:", output.shape)
    print("After flatten:", output.shape)
    #应用模型
    output_linear = ren(output)
    print("使用linear后:", output_linear.shape)

说明:关于CIFAR数据集,本人已下载好,若未下载请参考新手数据集下载

3、结果分析

在这里插入图片描述
在这里插入图片描述

直接加载的图片尺寸:torch.Size()
使用torch.reshape后输出尺寸:After reshape:
使用torch.flatten后输出尺寸:After flatten:
经由Linear layer,其中in_features=196608, out_features=10输出尺寸:使用linear后

2、torch.flatten()

pytorch快速入门(九)线性层Linear(), torch.flatten()_第3张图片
其实就是“拉平”的意思,代码在1中

你可能感兴趣的:(pytorch,pytorch,深度学习,机器学习)