【Pytorch】神经网络-非线性激活 - 学习笔记

目的:非线性激活的主要作用是提高泛化能力。

视频网址

首先来看看官方文档(以ReLU为例)
【Pytorch】神经网络-非线性激活 - 学习笔记_第1张图片
其中要注意到参数:inplace,可以举例子解释一下
相当于输出是否覆盖输入,一般情况下inplace=False(默认值)
【Pytorch】神经网络-非线性激活 - 学习笔记_第2张图片
代码

import torch
import time

from torch import nn
from torch.nn import ReLU

start = time.time()

input = torch.tensor([[1, -0.5],
                      [-1, 3]])

input = torch.reshape(input, (-1, 1, 2, 2))
print(input.shape)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        output = self.relu1(input)
        return output

model = Model()
output = model(input)
print(output)

end = time.time()

print('Running time: %s Seconds' % (end - start))

输出结果为

D:\Anaconda3\envs\pytorch\python.exe D:/研究生/代码尝试/nn_relu.py
torch.Size([1, 1, 2, 2])
tensor([[[[1., 0.],
          [0., 3.]]]])
Running time: 0.05785083770751953 Seconds

进程已结束,退出代码为 0

以sigmoid函数为例,来解释一下非线性激活的作用

import torch
import time

import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

start = time.time()

dataset = torchvision.datasets.CIFAR10("./dataset", train=False, download=False, transform=torchvision.transforms.ToTensor())

dataloader = DataLoader(dataset, batch_size=64)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.relu1 = ReLU()
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        output = self.sigmoid1(input)
        return output

model = Model()

step = 0
writer = SummaryWriter("./logs_relu")
for data in dataloader:
    imgs, targets = data
    writer.add_images("input", imgs, global_step=step)
    output = model(imgs)
    writer.add_images("output", output, step)
    step += 1

writer.close()

end = time.time()

print('Running time: %s Seconds' % (end - start))

打开Tensorboard
【Pytorch】神经网络-非线性激活 - 学习笔记_第3张图片
【Pytorch】神经网络-非线性激活 - 学习笔记_第4张图片
可以看出,图像都变灰了

你可能感兴趣的:(Pytorch,深度学习,Python,pytorch,神经网络,深度学习)