Using the Linear Layers in torch.nn

1. Linear layers in neural networks

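In a neural network, the linear layer (also called a fully connected layer, FC layer, or dense layer) is one of the most basic and central components. It maps the input into a space of another dimension through a matrix multiplication plus a bias, and is typically used for feature transformation or to produce the outputs of classification/regression tasks.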

Core principle
Mathematical definition:

Given an input vector $\mathbf{x} \in \mathbb{R}^n$, the linear layer computes the output $\mathbf{y} \in \mathbb{R}^m$ as
$$
\mathbf{y} = \mathbf{W}\mathbf{x} + \mathbf{b}
$$
$\mathbf{W} \in \mathbb{R}^{m \times n}$: the weight matrix (learnable parameters).
$\mathbf{b} \in \mathbb{R}^m$: the bias vector (learnable parameters).
Each output neuron is a weighted sum of the input features plus a bias.
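As a quick check of the formula, the sketch below computes $\mathbf{W}\mathbf{x} + \mathbf{b}$ by hand from an `nn.Linear` layer's `weight` and `bias` and compares it with the layer's own output. The sizes n = 4 and m = 3 are arbitrary choices for illustration. PyTorch stores `weight` with shape `(out_features, in_features)`, i.e. m × n, which matches $\mathbf{W}$ above.

```python
import torch
from torch import nn

torch.manual_seed(0)

n, m = 4, 3                     # input dimension n, output dimension m (arbitrary)
layer = nn.Linear(n, m)         # weight: (m, n), bias: (m,)

x = torch.randn(n)              # a single input vector x in R^n

# Manual computation of y = W x + b
y_manual = layer.weight @ x + layer.bias
y_layer = layer(x)

print(layer.weight.shape, layer.bias.shape)         # torch.Size([3, 4]) torch.Size([3])
print(torch.allclose(y_manual, y_layer, atol=1e-6))  # True
```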

Number of parameters:

  • Weights: $m \times n$.
  • Biases: $m$.
  • Total: $m \times n + m$.
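The count can be confirmed by summing `numel()` over a layer's parameters. This sketch reuses the 100 → 50 sizes from the PyTorch example in this article, so the expected total is 100 × 50 + 50 = 5050; passing `bias=False` leaves only the weights.

```python
from torch import nn

layer = nn.Linear(in_features=100, out_features=50)   # n = 100, m = 50

n_weight = layer.weight.numel()                        # m * n
n_bias = layer.bias.numel()                            # m
n_total = sum(p.numel() for p in layer.parameters())   # m * n + m
print(n_weight, n_bias, n_total)                       # 5000 50 5050

# With bias=False only the m * n weights remain
no_bias = nn.Linear(100, 50, bias=False)
print(sum(p.numel() for p in no_bias.parameters()))    # 5000
```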

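Functions and characteristics
Feature transformation:
Maps the input from an $n$-dimensional space to an $m$-dimensional space (reducing or increasing the dimension).
For example, in image classification it maps the flattened pixel vector to a hidden layer or to the class scores (logits), which a softmax then turns into class probabilities.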

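No built-in nonlinearity:
A linear layer on its own can only represent a linear (strictly speaking, affine) map, and a stack of linear layers with nothing in between is still a single affine map.
It is therefore usually combined with an activation function (such as ReLU or Sigmoid) to introduce nonlinearity, which lets the network fit complex functions.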
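To see why the activation function is needed, the sketch below builds two stacked linear layers (sizes chosen arbitrarily) and shows that they collapse into one linear layer with weight $\mathbf{W}_2\mathbf{W}_1$ and bias $\mathbf{W}_2\mathbf{b}_1 + \mathbf{b}_2$. Inserting a ReLU between them breaks this equivalence.

```python
import torch
from torch import nn

torch.manual_seed(0)

f1 = nn.Linear(8, 16)   # y1 = W1 x + b1
f2 = nn.Linear(16, 4)   # y2 = W2 y1 + b2

# Merge the two affine maps into one: W = W2 W1, b = W2 b1 + b2
merged = nn.Linear(8, 4)
with torch.no_grad():
    merged.weight.copy_(f2.weight @ f1.weight)
    merged.bias.copy_(f2.weight @ f1.bias + f2.bias)

x = torch.randn(5, 8)
print(torch.allclose(f2(f1(x)), merged(x), atol=1e-5))  # True
```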

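Versatility:
It can be used anywhere in a neural network:

  • Right after the input: processing raw features.
  • In hidden layers: progressively abstracting features.
  • At the output: producing the final prediction (e.g., the logits of a classification task).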
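A minimal sketch of these three positions in one small multilayer perceptron. The 28×28 single-channel input, the hidden sizes 128 and 64, and the 10 classes are assumptions picked for illustration, not values from the text.

```python
import torch
from torch import nn

# Hypothetical sizes: 28x28 grayscale images, 10 classes
model = nn.Sequential(
    nn.Flatten(),          # (N, 1, 28, 28) -> (N, 784)
    nn.Linear(784, 128),   # after the input: process raw features
    nn.ReLU(),
    nn.Linear(128, 64),    # hidden layer: further abstraction
    nn.ReLU(),
    nn.Linear(64, 10),     # output layer: class logits
)

x = torch.randn(32, 1, 28, 28)
logits = model(x)
print(logits.shape)  # torch.Size([32, 10])
```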

Implementation in PyTorch

import torch
import torch.nn as nn

# Define a linear layer: input dimension = 100, output dimension = 50
linear_layer = nn.Linear(in_features=100, out_features=50)

# Forward pass example
x = torch.randn(32, 100)  # assume batch_size = 32
y = linear_layer(x)       # output shape: [32, 50]

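Common questions

  • Why is a bias needed?
    The bias lets the model produce a nonzero output even when the input is all zeros, which increases its expressiveness (similar to an intercept term).
  • How does it differ from a convolutional layer?
    A convolutional layer processes spatial data (such as images) with local receptive fields and weight sharing, whereas a linear layer connects every input to every output.
  • What if the number of parameters is too large?
    Feeding a high-dimensional input (such as an image) directly into a linear layer makes the parameter count explode, so convolutional and pooling layers are usually applied first to reduce the dimension.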
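Two of these answers can be checked numerically. The sketch first shows that for an all-zero input the layer outputs exactly its bias (and zero when `bias=False`), then compares parameter counts for a small 3×32×32 image. The layer sizes (512 output units, 64 filters of size 3×3) are arbitrary illustration choices, and the two layers do not produce outputs of the same shape; the point is only how weight sharing keeps the convolution small.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 1) Bias: for an all-zero input, W x vanishes and only b is left
layer = nn.Linear(30, 10)
zero_out = layer(torch.zeros(30))
print(torch.allclose(zero_out, layer.bias))          # True

no_bias = nn.Linear(30, 10, bias=False)
print(no_bias(torch.zeros(30)).abs().sum().item())   # 0.0

# 2) Parameter count for a flattened 3x32x32 image
def count_params(module):
    return sum(p.numel() for p in module.parameters())

fc = nn.Linear(3 * 32 * 32, 512)         # every pixel connected to every output unit
conv = nn.Conv2d(3, 64, kernel_size=3)   # 64 shared 3x3x3 filters

print(count_params(fc))    # 1573376
print(count_params(conv))  # 1792
```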

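Application scenarios

  • Classification: the output layer is followed by Softmax to produce class probabilities.
  • Regression: the layer outputs continuous values directly (e.g., house price prediction).
  • Embedding: mapping discrete features (e.g., word IDs) to dense vectors; an embedding lookup is equivalent to a bias-free linear layer applied to one-hot vectors.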
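The embedding case can be made concrete. In PyTorch the lookup is usually done with `nn.Embedding`; the sketch below (vocabulary size 6 and dimension 4 are arbitrary) copies its table into a bias-free `nn.Linear` and shows that multiplying one-hot vectors by it gives the same vectors.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

vocab_size, dim = 6, 4
emb = nn.Embedding(vocab_size, dim)            # weight: (vocab_size, dim)

# Bias-free linear layer whose weight is the transposed embedding table
lin = nn.Linear(vocab_size, dim, bias=False)   # weight: (dim, vocab_size)
with torch.no_grad():
    lin.weight.copy_(emb.weight.T)

ids = torch.tensor([0, 3, 5])
one_hot = F.one_hot(ids, num_classes=vocab_size).float()   # (3, vocab_size)

print(torch.allclose(emb(ids), lin(one_hot)))  # True
```

The lookup version is preferred in practice because it avoids building and multiplying large one-hot matrices.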

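The linear layer is a cornerstone of neural networks; understanding how it works helps in designing more complex model architectures.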

2. Linear layers in torch.nn

Linear Layers:

  • nn.Identity
  • nn.Linear
  • nn.Bilinear
  • nn.LazyLinear
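Besides `nn.Linear`, the other three layers in this list behave as follows: `nn.Identity` returns its input unchanged (useful as a placeholder), `nn.Bilinear` combines two inputs as $y = x_1^\top A x_2 + b$, and `nn.LazyLinear` infers `in_features` from the first input it sees. A minimal sketch, with all sizes chosen arbitrarily:

```python
import torch
from torch import nn

x = torch.randn(8, 20)

# nn.Identity: passes the input through unchanged
ident = nn.Identity()
print(torch.equal(ident(x), x))       # True

# nn.Bilinear: two inputs, y = x1^T A x2 + b
bil = nn.Bilinear(in1_features=20, in2_features=30, out_features=5)
x2 = torch.randn(8, 30)
print(bil(x, x2).shape)               # torch.Size([8, 5])
print(bil.weight.shape)               # torch.Size([5, 20, 30])

# nn.LazyLinear: in_features (here 20) is inferred on the first forward pass
lazy = nn.LazyLinear(out_features=10)
out = lazy(x)
print(out.shape, lazy.weight.shape)   # torch.Size([8, 10]) torch.Size([10, 20])
```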

3. Using linear layers in torch.nn (with nn.Linear as the example)

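The last dimension of the input must equal the in_features defined in Linear (30 in the code below); any leading dimensions are treated as batch dimensions and passed through unchanged.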

  1. The input is a batch of 32 vectors with 30 features each, so the input (32×30) is mapped to the output (32×10):
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super(My_module, self).__init__()
        # 30 input features -> 10 output features
        self.linear1 = Linear(30, 10)

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()

input = torch.randn(32, 30)  # batch of 32 samples, 30 features each
print(input.shape)

output = mymodule(input)
print(output.shape)
torch.Size([32, 30])
torch.Size([32, 10])
  2. The input is a batch of 32 tensors of size 5×30.

The input (32×5×30) is mapped to the output (32×5×10):

import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super(My_module, self).__init__()
        self.linear1 = Linear(30, 10)

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()

input = torch.randn(32, 5, 30)  # Linear acts on the last dimension (30) only
print(input.shape)

output = mymodule(input)
print(output.shape)
torch.Size([32, 5, 30])
torch.Size([32, 5, 10])
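The rule that only the last dimension has to match in_features can be checked directly. This sketch shows that applying the layer to the 32×5×30 tensor gives the same result as treating all 32 × 5 = 160 rows as one batch, and that a mismatched last dimension (31 here, chosen arbitrarily) raises a `RuntimeError`.

```python
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(30, 10)

x = torch.randn(32, 5, 30)
y = linear(x)                                     # (32, 5, 10)

# Same result: merge the leading dims into one batch, apply, restore the shape
y_rows = linear(x.reshape(-1, 30)).reshape(32, 5, 10)
print(torch.allclose(y, y_rows, atol=1e-6))       # True

# A last dimension that does not equal in_features is rejected
failed = False
try:
    linear(torch.randn(32, 31))
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```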

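If you want the layer to produce a one-dimensional result, for example as the output of a classification or regression task, you can first flatten the multi-dimensional input with torch.reshape or directly with torch.flatten. Note that both examples below flatten the entire tensor, including the batch dimension, so all 32 × 5 × 30 = 4800 values are treated as a single sample and the layer is Linear(4800, 10):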

  • torch.flatten
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super(My_module, self).__init__()
        self.linear1 = Linear(4800, 10)  # 4800 = 32 * 5 * 30

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()

input = torch.randn(32,5,30)
print(input)
# Flatten every dimension, batch included: (32, 5, 30) -> (4800,)
input = input.flatten()
# equivalent: input = torch.flatten(input)

print(input)
print(input.shape)

output = mymodule(input)
print(output.shape)
tensor([[[ 0.1927,  1.5316,  0.8629,  ..., -2.4938,  0.2859, -1.0615],
         [-2.2116, -1.2984,  0.6821,  ..., -0.3509, -0.2797,  0.8697],
         [ 1.5197, -0.7962, -0.1739,  ..., -1.2226,  0.6223,  0.4959],
         [ 2.2151,  0.0110, -0.5783,  ..., -1.5217,  0.6565, -0.6503],
         [-0.1608, -0.0074, -1.7496,  ...,  1.2666,  0.3784,  0.0275]],

        [[ 1.6370, -0.9992, -1.5591,  ..., -0.2811, -1.0423, -1.2642],
         [-0.1373,  0.0053,  0.1932,  ...,  0.4230, -3.9586, -2.1267],
         [ 1.0752, -0.6097, -1.1407,  ..., -0.5715,  1.7729,  1.1139],
         [ 1.0184,  0.0116, -0.2647,  ...,  0.3043,  1.3984,  0.9507],
         [-0.1650,  0.4731,  2.0052,  ...,  0.7604, -0.1749,  0.7843]],

        [[ 0.1273,  0.5358, -0.4169,  ...,  0.1524, -0.1577,  0.2616],
         [-0.3028,  1.7451,  1.5555,  ..., -0.9916, -0.4082, -0.0970],
         [ 0.8116,  0.4476, -0.0110,  ..., -0.2703, -1.0140, -0.3735],
         [-0.3451, -0.3428,  1.8389,  ...,  0.7206, -0.2909, -1.3641],
         [ 0.5403,  0.0655,  0.0825,  ...,  1.7568, -0.8547, -0.3213]],

        ...,

        [[ 1.0200,  0.0375, -1.0133,  ..., -0.6202, -1.0658,  2.5017],
         [ 1.9269,  0.3800,  1.2563,  ..., -1.3712,  0.4820,  0.0980],
         [-0.1392, -1.4254, -1.6039,  ..., -0.8112,  0.6561, -0.8304],
         [-0.4198,  1.5363,  0.9587,  ..., -0.7531, -0.0072,  1.9401],
         [-0.5996, -1.0278, -0.7967,  ..., -1.0011,  0.9552,  1.8910]],

        [[ 3.1236, -0.8882, -1.2673,  ..., -0.0620, -0.0640, -0.0558],
         [-0.9593,  0.6215,  1.0444,  ..., -1.0440, -0.1488,  0.9926],
         [ 1.9661, -0.6619, -1.2196,  ...,  0.6399, -0.2180,  0.1434],
         [-0.0252,  0.4723,  0.4821,  ..., -1.7556,  1.5335, -0.5048],
         [-1.5792,  0.6036,  0.0374,  ..., -0.3583,  0.7300,  0.5706]],

        [[-0.9191, -1.0339, -1.0069,  ..., -0.6775,  1.0001, -0.6501],
         [-0.4620,  1.0355,  0.3432,  ..., -1.4742, -0.0109, -0.0910],
         [ 0.7351, -0.6463, -1.6483,  ..., -0.6968, -2.3056, -4.0914],
         [ 0.2381,  0.0088,  0.9528,  ...,  0.7148,  0.1307,  0.4448],
         [ 1.0053,  0.1132, -2.0195,  ..., -0.7125,  0.1509,  1.7306]]])
tensor([ 0.1927,  1.5316,  0.8629,  ..., -0.7125,  0.1509,  1.7306])
torch.Size([4800])
torch.Size([10])
  • torch.reshape
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super(My_module, self).__init__()
        self.linear1 = Linear(4800, 10)  # 4800 = 32 * 5 * 30

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()

input = torch.randn(32,5,30)
print(input)
# (32, 5, 30) -> (1, 1, 1, 4800); -1 lets PyTorch infer the remaining size
input = torch.reshape(input, [1,1,1,-1])
print(input)
print(input.shape)

output = mymodule(input)
print(output.shape)
tensor([[[ 0.5807,  0.8254, -0.0113,  ..., -0.1703, -1.5895,  0.0387],
         [-0.0283,  0.1952, -0.1681,  ..., -2.1661,  0.2926, -1.7320],
         [ 0.6992,  0.1887, -1.1198,  ..., -0.0889,  0.3059, -0.2803],
         [ 0.6459, -1.1591,  1.2445,  ..., -0.1613, -0.0503,  1.0880],
         [ 0.5726,  0.8104,  0.5556,  ...,  2.6553,  2.1057, -0.3247]],

        [[-1.1444, -0.8201, -0.0893,  ..., -0.7890, -1.5243,  0.2365],
         [-1.5390,  0.5377,  1.2178,  ..., -1.9277, -0.0877, -0.0614],
         [-0.4113,  0.6755, -0.6351,  ...,  1.5513,  0.2657,  0.8339],
         [-0.5718,  2.0597,  1.2934,  ...,  1.3997, -0.1791,  0.0965],
         [-0.2670,  0.5852, -1.1489,  ..., -0.1973,  0.2702,  1.2241]],

        [[-0.6137, -0.5469,  0.7655,  ..., -0.2240, -1.0068,  1.3224],
         [-0.8234, -1.2284, -0.7768,  ..., -0.9406, -0.6034,  1.2495],
         [ 0.0499,  0.2673, -0.4787,  ...,  1.2305,  1.4600,  0.9057],
         [ 1.8762,  0.2788, -1.0163,  ...,  1.9786, -0.3028,  0.1608],
         [ 1.2778, -1.0815,  0.4281,  ..., -1.3106,  1.1378, -1.4591]],

        ...,

        [[-0.3145,  1.8475,  2.1150,  ..., -0.1534, -1.9137,  0.1358],
         [ 0.5194, -0.0613, -0.7031,  ..., -1.1658, -0.0467,  0.9179],
         [ 0.3834,  0.3050,  0.5577,  ...,  0.4669, -0.7856, -0.6688],
         [-0.5901,  1.1210,  1.2222,  ..., -1.9253, -0.7322,  1.0523],
         [-0.9621, -0.0254,  0.2266,  ..., -1.3613,  1.1155,  0.6122]],

        [[-1.6138,  2.4130, -0.0800,  ...,  0.8938, -0.6971,  1.3274],
         [ 0.1131, -0.7315, -0.2061,  ..., -0.1776, -0.2116, -0.8136],
         [ 1.9511, -0.2855,  0.1544,  ..., -0.4052,  0.1731, -0.4174],
         [ 1.0657, -1.4310,  0.7123,  ..., -0.6069,  0.5625,  0.7002],
         [-2.6849,  0.9532,  1.8795,  ..., -1.3769,  0.2873, -0.1081]],

        [[ 0.9848, -1.0543,  1.4870,  ..., -0.5557, -0.8857, -1.5043],
         [ 0.6338,  0.7136,  1.1014,  ..., -1.2066, -1.1157,  1.2984],
         [-1.0531,  0.6081,  1.5540,  ...,  1.5034, -0.6660,  0.4444],
         [ 0.5338, -0.6992,  0.9997,  ..., -0.8787,  0.1580,  1.8274],
         [ 0.0796,  1.3643, -1.1272,  ...,  1.1547, -0.1451,  0.2355]]])
tensor([[[[ 0.5807,  0.8254, -0.0113,  ...,  1.1547, -0.1451,  0.2355]]]])
torch.Size([1, 1, 1, 4800])
torch.Size([1, 1, 1, 10])
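When the first dimension really is a batch of 32 independent samples, you usually want one output per sample instead. A minimal sketch, assuming each 5×30 sample should be flattened to 150 values: `torch.flatten` with `start_dim=1`, `torch.reshape` with the batch size kept, and the `nn.Flatten` module (whose default `start_dim` is 1) all preserve the batch dimension.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 5, 30)                     # 32 samples, each 5x30

# Option 1: flatten from dim 1 onward, keeping the batch dimension
flat1 = torch.flatten(x, start_dim=1)          # (32, 150)

# Option 2: reshape with the batch size fixed; -1 infers 5 * 30 = 150
flat2 = torch.reshape(x, (x.size(0), -1))      # (32, 150)

print(flat1.shape, torch.equal(flat1, flat2))  # torch.Size([32, 150]) True

# Option 3: flatten inside the model with nn.Flatten
model = nn.Sequential(
    nn.Flatten(),            # (32, 5, 30) -> (32, 150)
    nn.Linear(5 * 30, 10),   # one 10-dim output per sample
)
out = model(x)
print(out.shape)                               # torch.Size([32, 10])
```

Keeping the flattening inside the model means callers can pass the raw (batch, 5, 30) tensor without reshaping it themselves.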
