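In neural networks, the linear layer, also called a fully connected (FC) layer or dense layer, is one of the most basic and important building blocks. It maps the input into a space of a different dimension through a matrix operation. It is typically used for feature transformation or to produce classification/regression outputs.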
Core principle
Mathematical definition:
Given an input vector $\mathbf{x} \in \mathbb{R}^n$, the output $\mathbf{y} \in \mathbb{R}^m$ of a linear layer is computed as
$$\mathbf{y} = \mathbf{W}\mathbf{x} + \mathbf{b}$$
$\mathbf{W} \in \mathbb{R}^{m \times n}$: the weight matrix (learnable parameters).
$\mathbf{b} \in \mathbb{R}^m$: the bias vector (learnable parameters).
Each output neuron is a weighted sum of the input features plus a bias: $y_i = \sum_{j=1}^{n} W_{ij} x_j + b_i$.
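As a quick sanity check of the formula, here is a minimal sketch (the sizes $n = 4$ and $m = 3$ are arbitrary). `nn.Linear` stores $\mathbf{W}$ in its `weight` attribute with shape `(out_features, in_features)`, i.e. $(m, n)$. It stores $\mathbf{b}$ in `bias` with shape $(m,)$. Computing $\mathbf{W}\mathbf{x} + \mathbf{b}$ by hand should therefore reproduce the layer's output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n, m = 4, 3                      # input and output dimensions (arbitrary example sizes)
layer = nn.Linear(n, m)

W = layer.weight                 # shape (m, n)
b = layer.bias                   # shape (m,)
x = torch.randn(n)               # a single input vector in R^n

y_manual = W @ x + b             # y = Wx + b, written out explicitly
y_layer = layer(x)               # the same computation done by nn.Linear

print(W.shape, b.shape)          # torch.Size([3, 4]) torch.Size([3])
print(torch.allclose(y_manual, y_layer))  # True
```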
Number of parameters: $m \times n + m = m(n + 1)$, that is, $m \times n$ weights plus $m$ biases.
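The parameter count can be confirmed by summing the sizes of a layer's learnable tensors. In the sketch below, `count_params` is a hypothetical helper written for illustration. `bias=False` is a standard `nn.Linear` argument that drops the $m$ bias terms:

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of learnable scalars in a module (hypothetical helper)."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

n, m = 100, 50
layer = nn.Linear(n, m)
no_bias = nn.Linear(n, m, bias=False)

print(count_params(layer))    # 5050 = 50 * 100 weights + 50 biases
print(count_params(no_bias))  # 5000 = 50 * 100 weights only
```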
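Functions and characteristics
Feature transformation:
Maps the input from an $n$-dimensional space to an $m$-dimensional space, either reducing or increasing the dimensionality.
Example: in image classification, the flattened pixel vector is mapped either to a hidden layer or to per-class scores (logits). A softmax then turns those scores into class probabilities.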
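Below is a minimal sketch of the image-classification case. It assumes hypothetical 28×28 grayscale images and 10 classes, MNIST-like sizes chosen only for illustration. The linear layer itself outputs unnormalized scores (logits), and `torch.softmax` converts each row into probabilities:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, height, width, num_classes = 8, 28, 28, 10   # hypothetical sizes

images = torch.randn(batch_size, 1, height, width)       # fake grayscale images (N, C, H, W)
flat = images.reshape(batch_size, -1)                     # (8, 784): one pixel vector per image

classifier = nn.Linear(height * width, num_classes)       # maps R^784 -> R^10
logits = classifier(flat)                                 # (8, 10) unnormalized class scores
probs = torch.softmax(logits, dim=1)                      # each row now sums to 1

print(flat.shape, logits.shape)   # torch.Size([8, 784]) torch.Size([8, 10])
print(probs.sum(dim=1))           # every entry is (approximately) 1.0
```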
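No built-in nonlinearity:
A linear layer on its own can only represent a linear (strictly speaking, affine) relationship.
It is therefore usually followed by an activation function (e.g. ReLU, Sigmoid) that introduces nonlinearity, which allows the network to fit complex functions.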
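One consequence is worth checking numerically. Two linear layers with nothing between them compose into a single affine map: $\mathbf{W}_2(\mathbf{W}_1\mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2 = (\mathbf{W}_2\mathbf{W}_1)\mathbf{x} + (\mathbf{W}_2\mathbf{b}_1 + \mathbf{b}_2)$. A ReLU in between breaks this. The sketch below uses arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(16, 8)                    # batch of 16 inputs, 8 features each
fc1, fc2 = nn.Linear(8, 32), nn.Linear(32, 4)

# Two linear layers with nothing in between ...
stacked = fc2(fc1(x))

# ... equal ONE linear layer with W = W2 @ W1 and b = W2 @ b1 + b2.
W = fc2.weight @ fc1.weight               # shape (4, 8)
b = fc2.weight @ fc1.bias + fc2.bias      # shape (4,)
collapsed = x @ W.T + b

print(torch.allclose(stacked, collapsed, atol=1e-5))   # True: stacking adds no expressive power

# With a ReLU in between, the composition is no longer a single affine map.
with_relu = nn.Sequential(fc1, nn.ReLU(), fc2)
print(torch.allclose(with_relu(x), collapsed, atol=1e-5))  # False
```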
Generality:
It can be used at any position in a neural network.
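Implementation in PyTorch
```python
import torch
import torch.nn as nn

# Define a linear layer: input dimension = 100, output dimension = 50
linear_layer = nn.Linear(in_features=100, out_features=50)

# Forward pass example
x = torch.randn(32, 100)  # assume batch_size = 32
y = linear_layer(x)       # output shape: [32, 50]
```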
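For a batch, `x` holds one sample per row, so the formula is applied to each row. The batched form is $\mathbf{Y} = \mathbf{X}\mathbf{W}^\top + \mathbf{b}$, which is also what `torch.nn.functional.linear` computes. A short check that reuses the sizes from the example above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

linear_layer = nn.Linear(in_features=100, out_features=50)
x = torch.randn(32, 100)                                   # 32 samples, 100 features each

y = linear_layer(x)
y_manual = x @ linear_layer.weight.T + linear_layer.bias   # bias broadcasts over the batch
y_functional = F.linear(x, linear_layer.weight, linear_layer.bias)

print(linear_layer.weight.shape)   # torch.Size([50, 100]), i.e. (out_features, in_features)
print(y.shape)                     # torch.Size([32, 50])
print(torch.allclose(y, y_manual, atol=1e-5))   # True
print(torch.allclose(y, y_functional))          # True
```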
Common issues
Application scenarios
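The linear layer is a cornerstone of neural networks, and understanding how it works helps when designing more complex model architectures.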
Linear Layers:
The last dimension of the input must equal the input size (`in_features`) defined in `Linear` (30 in the code below).
```python
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(30, 10)  # in_features=30, out_features=10

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()
input = torch.randn(32, 30)
print(input.shape)
output = mymodule(input)
print(output.shape)
```
Output:
```
torch.Size([32, 30])
torch.Size([32, 10])
```
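If the last dimension does not match `in_features`, the forward pass fails. The minimal sketch below uses a deliberately wrong input width of 20. The exact wording of the error message may differ between PyTorch versions, so only the exception type is relied on here:

```python
import torch
from torch.nn import Linear

linear1 = Linear(30, 10)

good = torch.randn(32, 30)      # last dim 30 == in_features: OK
bad = torch.randn(32, 20)       # last dim 20 != in_features: shape error

print(linear1(good).shape)      # torch.Size([32, 10])

try:
    linear1(bad)
    raised = False
except RuntimeError as err:     # PyTorch reports the matmul shape mismatch as a RuntimeError
    raised = True
    print("RuntimeError:", err)
```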
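The layer also accepts inputs with more than two dimensions. Only the last dimension is transformed, so an input of shape (32 × 5 × 30) yields an output of shape (32 × 5 × 10):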
```python
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(30, 10)  # acts on the last dimension (size 30)

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()
input = torch.randn(32, 5, 30)
print(input.shape)
output = mymodule(input)
print(output.shape)
```
Output:
```
torch.Size([32, 5, 30])
torch.Size([32, 5, 10])
```
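In other words, `nn.Linear` treats every leading dimension as a batch dimension. The (32, 5, 30) input behaves like 32 × 5 = 160 independent 30-dimensional vectors. The sketch below checks this by reshaping to (160, 30), applying the layer, and reshaping back:

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)

linear1 = Linear(30, 10)
x = torch.randn(32, 5, 30)

out_3d = linear1(x)                               # applied directly to the 3-D tensor: (32, 5, 10)
out_2d = linear1(x.reshape(-1, 30))               # (160, 30) -> (160, 10)
out_2d = out_2d.reshape(32, 5, 10)                # restore the leading dimensions

print(out_3d.shape)                               # torch.Size([32, 5, 10])
print(torch.allclose(out_3d, out_2d, atol=1e-6))  # True: only the last dimension is transformed
```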
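To get a single one-dimensional output for classification or regression, you can first flatten the multi-dimensional input using `torch.reshape` or, more directly, `torch.flatten`. Note that the two examples below flatten the entire tensor, batch dimension included. All 32 samples are merged into one 4800-element vector (32 × 5 × 30 = 4800), so the layer produces a single 10-dimensional output rather than one per sample.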
```python
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(4800, 10)  # 4800 = 32 * 5 * 30, the size of the fully flattened input

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()
input = torch.randn(32, 5, 30)
print(input)
input = input.flatten()          # flatten ALL dimensions into one: shape [4800]
# input = torch.flatten(input)   # equivalent function form
print(input)
print(input.shape)
output = mymodule(input)
print(output.shape)
```
Output:
```
tensor([[[ 0.1927, 1.5316, 0.8629, ..., -2.4938, 0.2859, -1.0615],
[-2.2116, -1.2984, 0.6821, ..., -0.3509, -0.2797, 0.8697],
[ 1.5197, -0.7962, -0.1739, ..., -1.2226, 0.6223, 0.4959],
[ 2.2151, 0.0110, -0.5783, ..., -1.5217, 0.6565, -0.6503],
[-0.1608, -0.0074, -1.7496, ..., 1.2666, 0.3784, 0.0275]],
[[ 1.6370, -0.9992, -1.5591, ..., -0.2811, -1.0423, -1.2642],
[-0.1373, 0.0053, 0.1932, ..., 0.4230, -3.9586, -2.1267],
[ 1.0752, -0.6097, -1.1407, ..., -0.5715, 1.7729, 1.1139],
[ 1.0184, 0.0116, -0.2647, ..., 0.3043, 1.3984, 0.9507],
[-0.1650, 0.4731, 2.0052, ..., 0.7604, -0.1749, 0.7843]],
[[ 0.1273, 0.5358, -0.4169, ..., 0.1524, -0.1577, 0.2616],
[-0.3028, 1.7451, 1.5555, ..., -0.9916, -0.4082, -0.0970],
[ 0.8116, 0.4476, -0.0110, ..., -0.2703, -1.0140, -0.3735],
[-0.3451, -0.3428, 1.8389, ..., 0.7206, -0.2909, -1.3641],
[ 0.5403, 0.0655, 0.0825, ..., 1.7568, -0.8547, -0.3213]],
...,
[[ 1.0200, 0.0375, -1.0133, ..., -0.6202, -1.0658, 2.5017],
[ 1.9269, 0.3800, 1.2563, ..., -1.3712, 0.4820, 0.0980],
[-0.1392, -1.4254, -1.6039, ..., -0.8112, 0.6561, -0.8304],
[-0.4198, 1.5363, 0.9587, ..., -0.7531, -0.0072, 1.9401],
[-0.5996, -1.0278, -0.7967, ..., -1.0011, 0.9552, 1.8910]],
[[ 3.1236, -0.8882, -1.2673, ..., -0.0620, -0.0640, -0.0558],
[-0.9593, 0.6215, 1.0444, ..., -1.0440, -0.1488, 0.9926],
[ 1.9661, -0.6619, -1.2196, ..., 0.6399, -0.2180, 0.1434],
[-0.0252, 0.4723, 0.4821, ..., -1.7556, 1.5335, -0.5048],
[-1.5792, 0.6036, 0.0374, ..., -0.3583, 0.7300, 0.5706]],
[[-0.9191, -1.0339, -1.0069, ..., -0.6775, 1.0001, -0.6501],
[-0.4620, 1.0355, 0.3432, ..., -1.4742, -0.0109, -0.0910],
[ 0.7351, -0.6463, -1.6483, ..., -0.6968, -2.3056, -4.0914],
[ 0.2381, 0.0088, 0.9528, ..., 0.7148, 0.1307, 0.4448],
[ 1.0053, 0.1132, -2.0195, ..., -0.7125, 0.1509, 1.7306]]])
tensor([ 0.1927, 1.5316, 0.8629, ..., -0.7125, 0.1509, 1.7306])
torch.Size([4800])
torch.Size([10])
```
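Flattening the whole tensor merges all 32 samples into one 4800-element vector, so the layer produces a single 10-dimensional output for the entire batch. When each sample needs its own prediction, the usual approach is to flatten from dimension 1 and size the layer accordingly (5 × 30 = 150 input features). The sketch below uses `torch.flatten(..., start_dim=1)` and its module form `nn.Flatten`, whose default `start_dim` is 1:

```python
import torch
from torch import nn

torch.manual_seed(0)

x = torch.randn(32, 5, 30)                  # 32 samples, each of shape (5, 30)

per_sample = torch.flatten(x, start_dim=1)  # keep dim 0 (batch), flatten the rest -> (32, 150)
linear1 = nn.Linear(5 * 30, 10)
out = linear1(per_sample)                   # one 10-dim output per sample -> (32, 10)

# The same pipeline written as modules; nn.Flatten() flattens from dim 1 by default.
model = nn.Sequential(nn.Flatten(), linear1)

print(per_sample.shape, out.shape)          # torch.Size([32, 150]) torch.Size([32, 10])
print(torch.allclose(model(x), out))        # True: identical computation
```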
The same thing can be done with `torch.reshape`. Here the result is a 4-D tensor of shape [1, 1, 1, 4800]:
```python
import torch
from torch import nn
from torch.nn import Linear

class My_module(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(4800, 10)  # 4800 = 32 * 5 * 30

    def forward(self, input):
        output = self.linear1(input)
        return output

mymodule = My_module()
input = torch.randn(32, 5, 30)
print(input)
input = torch.reshape(input, [1, 1, 1, -1])  # -1 lets PyTorch infer the size (4800)
print(input)
print(input.shape)
output = mymodule(input)
print(output.shape)
```
Output:
```
tensor([[[ 0.5807, 0.8254, -0.0113, ..., -0.1703, -1.5895, 0.0387],
[-0.0283, 0.1952, -0.1681, ..., -2.1661, 0.2926, -1.7320],
[ 0.6992, 0.1887, -1.1198, ..., -0.0889, 0.3059, -0.2803],
[ 0.6459, -1.1591, 1.2445, ..., -0.1613, -0.0503, 1.0880],
[ 0.5726, 0.8104, 0.5556, ..., 2.6553, 2.1057, -0.3247]],
[[-1.1444, -0.8201, -0.0893, ..., -0.7890, -1.5243, 0.2365],
[-1.5390, 0.5377, 1.2178, ..., -1.9277, -0.0877, -0.0614],
[-0.4113, 0.6755, -0.6351, ..., 1.5513, 0.2657, 0.8339],
[-0.5718, 2.0597, 1.2934, ..., 1.3997, -0.1791, 0.0965],
[-0.2670, 0.5852, -1.1489, ..., -0.1973, 0.2702, 1.2241]],
[[-0.6137, -0.5469, 0.7655, ..., -0.2240, -1.0068, 1.3224],
[-0.8234, -1.2284, -0.7768, ..., -0.9406, -0.6034, 1.2495],
[ 0.0499, 0.2673, -0.4787, ..., 1.2305, 1.4600, 0.9057],
[ 1.8762, 0.2788, -1.0163, ..., 1.9786, -0.3028, 0.1608],
[ 1.2778, -1.0815, 0.4281, ..., -1.3106, 1.1378, -1.4591]],
...,
[[-0.3145, 1.8475, 2.1150, ..., -0.1534, -1.9137, 0.1358],
[ 0.5194, -0.0613, -0.7031, ..., -1.1658, -0.0467, 0.9179],
[ 0.3834, 0.3050, 0.5577, ..., 0.4669, -0.7856, -0.6688],
[-0.5901, 1.1210, 1.2222, ..., -1.9253, -0.7322, 1.0523],
[-0.9621, -0.0254, 0.2266, ..., -1.3613, 1.1155, 0.6122]],
[[-1.6138, 2.4130, -0.0800, ..., 0.8938, -0.6971, 1.3274],
[ 0.1131, -0.7315, -0.2061, ..., -0.1776, -0.2116, -0.8136],
[ 1.9511, -0.2855, 0.1544, ..., -0.4052, 0.1731, -0.4174],
[ 1.0657, -1.4310, 0.7123, ..., -0.6069, 0.5625, 0.7002],
[-2.6849, 0.9532, 1.8795, ..., -1.3769, 0.2873, -0.1081]],
[[ 0.9848, -1.0543, 1.4870, ..., -0.5557, -0.8857, -1.5043],
[ 0.6338, 0.7136, 1.1014, ..., -1.2066, -1.1157, 1.2984],
[-1.0531, 0.6081, 1.5540, ..., 1.5034, -0.6660, 0.4444],
[ 0.5338, -0.6992, 0.9997, ..., -0.8787, 0.1580, 1.8274],
[ 0.0796, 1.3643, -1.1272, ..., 1.1547, -0.1451, 0.2355]]])
tensor([[[[ 0.5807, 0.8254, -0.0113, ..., 1.1547, -0.1451, 0.2355]]]])
torch.Size([1, 1, 1, 4800])
torch.Size([1, 1, 1, 10])
```
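Both routes feed the layer the same 4800 numbers in the same (row-major) order. This is why, in each printed output, the first and last values of the flattened tensor match the first and last values of the original tensor. The sketch below checks that `flatten()` and `reshape(-1)` agree. It also checks that the extra size-1 leading dimensions from `reshape([1, 1, 1, -1])` are carried through to the output as batch dimensions:

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)

x = torch.randn(32, 5, 30)
linear1 = Linear(4800, 10)

flat = x.flatten()                           # shape (4800,)
reshaped = torch.reshape(x, [1, 1, 1, -1])   # shape (1, 1, 1, 4800)

print(torch.equal(flat, x.reshape(-1)))          # True: same elements, same order
print(torch.equal(reshaped.reshape(-1), flat))   # True

out_flat = linear1(flat)                     # shape (10,)
out_reshaped = linear1(reshaped)             # shape (1, 1, 1, 10): leading 1s act as batch dims
print(torch.allclose(out_reshaped.reshape(-1), out_flat, atol=1e-5))  # True, up to rounding
```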