可视化学习笔记4：PyTorch 可视化激活函数（ReLU、ReLU6、LeakyReLU、Hardswish、Mish）代码

源代码

import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class ReLU(nn.Module):
    """Element-wise rectified linear unit: ``max(0, x)``.

    A thin ``nn.Module`` wrapper around ``torch.nn.functional.relu`` so the
    activation can be instantiated and called like any other layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        out = F.relu(input)
        return out


# def ReLU6(x) -> list:
#     x = x.numpy()
#     y = []
#     for v in x:
#         if v <= 0:
#             y.append(0)
#         elif v > 0 and v <= 6:
#             y.append(v)
#         else:
#             y.append(6)
#     return y
class LeakyReLU(nn.Module):
    """Element-wise leaky ReLU: ``x`` for ``x >= 0``, ``negative_slope * x`` otherwise.

    Args:
        negative_slope: Slope applied to negative inputs. Defaults to 0.01,
            the same default ``F.leaky_relu`` uses, so ``LeakyReLU()`` behaves
            exactly like the original hard-wired version.
    """

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)

    def extra_repr(self):
        # Shown inside the module's repr, e.g. "LeakyReLU(negative_slope=0.01)".
        return f"negative_slope={self.negative_slope}"

class Hardswish(nn.Module):
    """Element-wise hard-swish activation, delegating to ``F.hardswish``.

    Per the PyTorch docs: 0 for ``x <= -3``, ``x`` for ``x >= 3``, and
    ``x * (x + 3) / 6`` in between.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        activated = F.hardswish(input)
        return activated

class Mish(nn.Module):
    """Element-wise Mish activation: ``x * tanh(softplus(x))``.

    ``F.softplus`` computes ``log(1 + exp(x))`` and falls back to the identity
    above its default threshold, which keeps the gate numerically stable.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate


def Curve(x, *, show_mish=False):
    """Plot ReLU, LeakyReLU and Hardswish (and optionally Mish) over ``x``.

    Draws on the current matplotlib figure and then calls ``plt.show()``.

    Args:
        x: Tensor of sample points, presumably 1-D
            (e.g. ``torch.linspace(-10, 10, 1000)``).
        show_mish: When True, also compute and plot the Mish curve and append
            "+Mish" to the title. Off by default so the figure matches the
            original three-curve plot.
    """

    def to_np(t):
        # matplotlib needs a plain CPU array; detach() also allows inputs
        # that require grad, cpu() allows inputs living on a GPU.
        return t.detach().cpu().numpy()

    x_np = to_np(x)
    y0 = to_np(ReLU()(x))
    y1 = to_np(LeakyReLU()(x))
    y2 = to_np(Hardswish()(x))

    title = 'ReLU+LeakyReLU+Hardswish'
    plt.plot(x_np, y0, color='blue', label='ReLU')
    plt.plot(x_np, y1, '--', color='red', label='LeakyReLU')
    plt.plot(x_np, y2, '-.', color='green', label='Hardswish')
    if show_mish:
        # Previously Mish was always computed but never drawn; only pay for
        # it when it is actually plotted.
        y3 = to_np(Mish()(x))
        plt.plot(x_np, y3, color='purple', label='Mish')
        title += '+Mish'

    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.8)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()

下图是 ReLU、LeakyReLU 和 Hardswish 的可视化结果。想要可视化其他激活函数，只需把 F.leaky_relu(input) 这一处替换成对应的函数即可。F 指的是 torch.nn.functional，其中包含了很多激活函数，可以直接调用。运行时调用 Curve(torch.linspace(-10, 10, 1000)) 即可画出曲线。
可视化学习笔记4-pytorch可视化激活函数(relu、relu6、leakyrelu、hardswish、Mish)代码_第1张图片

你可能感兴趣的:(可视化学习,pytorch,学习,深度学习)