Plotting the activation functions that ship with YOLOv5

References

  • Introduction to the activation functions in YOLOv5: How to change the activation function in YOLOv5?
  • Code for plotting the activation functions: github: Hardswish-ReLU6-SiLU-Mish-Activation-Function
  • Common activation functions: Sigmoid, ReLU, Swish, Mish, GELU

I only managed to plot the following four functions (waaah, still too much of a noob):

[Figure 1: SiLU, Hardswish, Mish and ReLU6 plotted on one set of axes]

And what they look like separately (to plot them one at a time, just comment out a few lines of code, haha — see the sketch after the figures):

[Figures 2–5: each of the four activation functions plotted individually]
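A minimal sketch of that "comment out a few lines" step, assuming you want only the Mish curve: inside the Curve() function below, keep a single plt.plot(...) call and adjust the title.

    # Inside Curve(), keep one plt.plot(...) call per figure, e.g. only Mish:
    plt.title('Mish')
    plt.plot(x, y3, color='red', label='Mish')  # comment out the other plt.plot(...) lines
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.8)
    plt.show()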

Code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

"""
conference link:
    - 最新的激活函数:https://yolov5.blog.csdn.net/article/details/124413941
    - 将激活函数绘制成图的代码:https://github.com/luokai-dandan/Hardswish-ReLU6-SiLU-Mish-Activation-Function
"""


# 1. SiLU
# SiLU https://arxiv.org/pdf/1606.08415.pdf
class SiLU(nn.Module):  # export-friendly version of nn.SiLU()
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


# 2. Hardswish
class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
    @staticmethod
    def forward(x):
        # return x * F.hardsigmoid(x)  # for TorchScript and CoreML
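        # hardtanh(x + 3, 0, 6) == ReLU6(x + 3), so this computes x * ReLU6(x + 3) / 6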
        return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0  # for TorchScript, CoreML and ONNX


# 3. Mish
# Mish https://github.com/digantamisra98/Mish
class Mish(nn.Module):
    @staticmethod
    def forward(x):
        return x * F.softplus(x).tanh()


# 4. MemoryEfficientMish
class MemoryEfficientMish(nn.Module):
    class F(torch.autograd.Function):
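        # Note: inside the methods below, the name F still resolves to the module-level
        # torch.nn.functional import (Python class scopes are not visible from method bodies).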
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        @staticmethod
        def backward(ctx, grad_output):
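            # d/dx [x * tanh(softplus(x))] = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh(sp(x))^2)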
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)

# 5. FReLU
# FReLU https://arxiv.org/abs/2007.11824
class FReLU(nn.Module):
    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        return torch.max(x, self.bn(self.conv(x)))
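# Note: FReLU expects a 4-D (N, C, H, W) feature map (depthwise conv + BN), so it cannot
# be applied directly to the 1-D tensor used for plotting below -- see the sketch after the script.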


# 6. AconC
class AconC(nn.Module):
    r""" ACON activation (activate or not).
    AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
    according to "Activate or Not: Learning Customized Activation" .
    """

    def __init__(self, c1):
        super().__init__()
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

    def forward(self, x):
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
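# Note: AconC also expects (N, C, H, W) input, and p1/p2 are randomly initialized,
# so its curve depends on the initialization.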

def ReLU6(x):
    # ReLU6(x) = min(max(0, x), 6), i.e. clamp the input to [0, 6]
    return torch.clamp(x, 0.0, 6.0)


def Curve(x):
    silu = SiLU()
    y1 = silu(x)

    # ---------- set the axis limits ---------- #
    x_min, x_max = x.min().item(), x.max().item()
    y_min, y_max = x_min, x_max

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    # ----------------------------------------- #

    hardswish = Hardswish()
    y2 = hardswish(x)

    mish = Mish()
    y3 = mish(x)

    y4 = ReLU6(x)

    # MemoryEfficientMish coincides with Mish, so it is left commented out:
    # memory = MemoryEfficientMish()
    # y5 = memory(x)

    # FReLU and AconC need a channel count and 4-D (N, C, H, W) input, so they are
    # not drawn here -- see the sketch after the script:
    # y6 = FReLU(c1=1)(x.view(1, 1, 1, -1)).view(-1)
    # y7 = AconC(c1=1)(x.view(1, 1, 1, -1)).view(-1)

    plt.title('Hardswish+ReLU6+SiLU+Mish')
    plt.plot(x, y1, color='green', label='SiLU')
    plt.plot(x, y2, color='blue', label='Hardswish')
    plt.plot(x, y3, color='red', label='Mish')
    plt.plot(x, y4, color='orange', label='ReLU6')
    # plt.plot(x, y5, color='purple', label='MemoryEfficientMish')
    # plt.plot(x, y6, color='yellow', label='FReLU')
    # plt.plot(x, y7, color='black', label='AconC')

    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.8)

    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()


if __name__ == '__main__':
    x = torch.linspace(-6, 6, 10000)
    Curve(x)
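For the two modules that did not make it into the figures, here is a minimal sketch of how they could be plotted anyway. The reshaping to (1, 1, 1, N), the manual_seed, and the eval() call are my own assumptions, not part of the original script; both modules contain randomly initialized parameters, so the resulting curves depend on the seed:

import torch
import matplotlib.pyplot as plt

torch.manual_seed(0)       # hypothetical seed: FReLU/AconC have random parameters

x = torch.linspace(-6, 6, 10000)
x4d = x.view(1, 1, 1, -1)  # fake (batch=1, channels=1, height=1, width=N) feature map

frelu = FReLU(c1=1).eval()  # eval() so BatchNorm2d uses its default (identity) running stats
aconc = AconC(c1=1)

with torch.no_grad():
    y_frelu = frelu(x4d).view(-1)
    y_aconc = aconc(x4d).view(-1)

plt.plot(x, y_frelu, color='yellow', label='FReLU (random init)')
plt.plot(x, y_aconc, color='black', label='AconC (random init)')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.8)
plt.show()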
