【学习笔记】【Pytorch】十、线性层

【学习笔记】【Pytorch】九、线性层

  • 学习地址
  • 主要内容
    • 一、前言
    • 二、Pytorch的线性层
    • 三、Linear类的使用
      • 1.使用说明
      • 2.代码实现

学习地址

PyTorch深度学习快速入门教程【小土堆】.

主要内容

一、前言

在神经网络中,我们通常用线性层来完成两层神经元间的线性变换。
【学习笔记】【Pytorch】十、线性层_第1张图片
在普通神经网络里,输入是一个二维矩阵,不需要摊平。而在卷积神经网络里,在网络的最后几层里,会把卷积层摊平放到全连接里进行计算。
:当把一个输出张量从卷积层传递到线性层时,需要进行 flatten 操作,对一个张量进行flatten(扁平化)。

二、Pytorch的线性层

linear-layers

三、Linear类的使用

from torch.nn import Linear

官方解释

作用:对输入数据做线性变换:y=Ax+b。

1.使用说明

【实例化】Linear(in_features, out_features, bias=True, device=None, dtype=None)

  • 作用:创建一个实例。
  • in_features:输入结点数
    out_features:输出结点数
    bias :是否需要偏置
  • 计算公式:
    在这里插入图片描述

2.代码实现

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader


class Model(nn.Module):
    def __init__(self):
        super().__init__()  # 初始化父类参数
        self.linear1 = Linear(196608, 10)  # 创建一个实例

    def forward(self, input):
        output = self.linear1(input)
        return output


# 创建一个实例
dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor())
# 创建一个实例,drop_last=True防止经过线性函数报错
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)
model = Model()  # 创建一个实例

for data in dataloader:
    imgs, targets = data
    # batch_size, 通道数, 图片尺寸
    print(imgs.shape)  # torch.Size([64, 3, 32, 32])
    # 拉伸成向量,自动计算最后一个位置(64*3*32*32=196608)
    output = torch.reshape(imgs, (1, 1, 1, -1))
    print(output.shape)  # torch.Size([1, 1, 1, 196608])
    output = model(output)
    print(output.shape)  # torch.Size([1, 1, 1, 10])

输出

torch.Size([64, 3, 32, 32])
torch.Size([1, 1, 1, 196608])
torch.Size([1, 1, 1, 10])
torch.Size([64, 3, 32, 32])
torch.Size([1, 1, 1, 196608])
torch.Size([1, 1, 1, 10])
....
....

补充:

output = torch.reshape(imgs, (1, 1, 1, -1))
# 也可以换成
output = torch.flatten(imgs)

输出变成:

torch.Size([64, 3, 32, 32])
torch.Size([196608])
torch.Size([10])
torch.Size([64, 3, 32, 32])
torch.Size([196608])
torch.Size([10])
....
....

你可能感兴趣的:(Pytorch,pytorch,学习,深度学习)