个人理解是 为了给神经网络引入一些非线性的特质
下面是查到的资料所得结果
如上图的神经网络,在正向传播过程中,若使用线性激活函数(恒等激励函数)
即令
则隐藏层的输出为
可以看到使用线性激活函数神经网络只是把输入线性组合再输出,所以当有很多隐藏层时,在隐藏层使用线性激活函数的训练效果和不使用隐藏层即 标准的Logistic回归是一样的。故我们要在隐藏层使用非线性激活函数而非线性的。
对于这些激励函数所对应的图像及作用,参考这篇大佬写的文章,在这里,我先了解下非线性激活,目前仅了解部分激励函数,也不太好总结
下面以Sigmoid和ReLU激励函数来熟悉下非线性激活
import torch
from torch import nn
from torch.nn import ReLU
input = torch.tensor([[1,-0.5],
[-1,3]])
input = torch.reshape(input,(-1,1,2,2))
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.relu = ReLU()
def forward(self,input):
output = self.relu(input)
return output
test = Test()
output = test(input)
print(output)
输出:
tensor([[[[1., 0.],
[0., 3.]]]])
import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10("dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=64)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.sigmoid = Sigmoid()
def forward(self,input):
output = self.sigmoid(input)
return output
writer = SummaryWriter("logs")
test = Test()
step=0
for data in dataloader:
imgs,targets = data
writer.add_image("input",imgs,step,dataformats="NCHW")
output = test(imgs)
writer.add_image("output",output,step,dataformats="NCHW")
step+=1
writer.close()
线性层又称为全连接层,其每个神经元与上一层所有神经元相连,实现对前一层的线性组合,线性变换。
功能:在卷积神经网络进行分类时,输出分类结果前,通常采用全连接层对特征进行处理。Pytorch中全连接层又称为Linear线性层,因为如果不考虑激活函数的非线性性质,那么全连接层就是对输入数据进行线性组合,因此而得名"线性层"。
nn.Linear(in_features, out_features, bias=True)
in_features:输入结点数
out_features:输出结点数
bias :是否需要偏置
理解:隐藏层中的第一个数6是如何计算的,图中写的很清楚,1x1+2x1+3x1=6
代码:
import torch
from torch import nn
inputs = torch.tensor([[1., 2, 3]])
linear_layer = nn.Linear(3, 4)
linear_layer.weight.data = torch.tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]])
linear_layer.bias.data.fill_(0.5)
output = linear_layer(inputs)
print(inputs, inputs.shape)
print(linear_layer.weight.data, linear_layer.weight.data.shape)
print(output, output.shape)