只绘制出了如下 4 个激活函数的曲线(555,太菜了):
分开绘制的模样(分开画只需注释掉几行代码就行了,哈哈哈):
代码:
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from pylab import *
"""
reference links:
- 最新的激活函数:https://yolov5.blog.csdn.net/article/details/124413941
- 将激活函数绘制成图的代码:https://github.com/luokai-dandan/Hardswish-ReLU6-SiLU-Mish-Activation-Function
"""
# 1. SiLU
# SiLU https://arxiv.org/pdf/1606.08415.pdf
class SiLU(nn.Module):  # export-friendly version of nn.SiLU()
    """SiLU (a.k.a. Swish) activation: ``x * sigmoid(x)``.

    Composed from elementary tensor ops as an export-friendly stand-in
    for ``nn.SiLU``.
    """

    @staticmethod
    def forward(x):
        gate = torch.sigmoid(x)
        return x * gate
# 2. Hardswish
class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
    """Hard-Swish activation: ``x * clamp(x + 3, 0, 6) / 6``.

    The clamp is written with ``F.hardtanh``. Per the original note, that
    form works for TorchScript, CoreML and ONNX. The
    ``x * F.hardsigmoid(x)`` form only covers TorchScript and CoreML.
    """

    @staticmethod
    def forward(x):
        # Alternative for TorchScript and CoreML only: x * F.hardsigmoid(x)
        clipped = F.hardtanh(x + 3, 0.0, 6.0)  # for TorchScript, CoreML and ONNX
        return x * clipped / 6.0
# 3. Mish
# Mish https://github.com/digantamisra98/Mish
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    @staticmethod
    def forward(x):
        soft = F.softplus(x)
        return x * torch.tanh(soft)
# 4. MemoryEfficientMish
class MemoryEfficientMish(nn.Module):
    """Mish (``x * tanh(softplus(x))``) with a hand-written backward pass.

    Only the input tensor is saved for backward. The gradient is recomputed
    from it inside ``backward`` rather than kept from the forward pass.
    """

    class F(torch.autograd.Function):
        # NOTE: this nested class shares its name with the module-level
        # ``F`` (torch.nn.functional) alias. Class bodies are not an
        # enclosing scope for their methods, so a bare ``F`` below would
        # still mean torch.nn.functional; it is spelled out to avoid the
        # ambiguity.

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            # x * tanh(ln(1 + exp(x)))
            return x.mul(torch.tanh(torch.nn.functional.softplus(x)))

        @staticmethod
        def backward(ctx, grad_output):
            (inp,) = ctx.saved_tensors
            sig = torch.sigmoid(inp)  # softplus'(x) == sigmoid(x)
            tsp = torch.nn.functional.softplus(inp).tanh()
            # d/dx [x * tanh(sp(x))] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)
            return grad_output * (tsp + inp * sig * (1 - tsp * tsp))

    def forward(self, x):
        fn = self.F
        return fn.apply(x)
# 5. FReLU
# FReLU https://arxiv.org/abs/2007.11824
class FReLU(nn.Module):
    """FReLU activation: ``max(x, BN(depthwise_conv(x)))``.

    Args:
        c1: number of input channels; the output has the same count.
        k: side length of the depthwise convolution kernel. Padding is
            ``k // 2``, so odd kernels keep the spatial size, which the
            element-wise ``torch.max`` in ``forward`` needs. An even ``k``
            changes the spatial size and ``forward`` will fail.
    """

    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        # Depthwise (groups=c1), stride 1. Padding used to be hard-coded to 1,
        # which only preserved the spatial size for k == 3.
        self.conv = nn.Conv2d(c1, c1, k, 1, k // 2, groups=c1, bias=False)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        # Element-wise max between the input and its conv+BN response.
        return torch.max(x, self.bn(self.conv(x)))
# 6. AconC
class AconC(nn.Module):
    r"""ACON-C activation ("activate or not").

    Computes ``(p1*x - p2*x) * sigmoid(beta * (p1*x - p2*x)) + p2*x``, where
    ``p1``, ``p2`` and ``beta`` are learnable. This follows "Activate or Not:
    Learning Customized Activation".

    Args:
        c1: number of channels. Each parameter has shape ``(1, c1, 1, 1)``
            and broadcasts against an input whose dim 1 has size ``c1``
            (presumably an NCHW feature map; confirm at the call site).
    """

    def __init__(self, c1):
        super().__init__()
        shape = (1, c1, 1, 1)
        # p1 and p2 start random; beta starts at 1.
        self.p1 = nn.Parameter(torch.randn(*shape))
        self.p2 = nn.Parameter(torch.randn(*shape))
        self.beta = nn.Parameter(torch.ones(*shape))

    def forward(self, x):
        gap = (self.p1 - self.p2) * x
        switch = torch.sigmoid(self.beta * gap)
        return gap * switch + self.p2 * x
def ReLU6(x) -> list:
    """Clamp every element of *x* to ``[0, 6]`` and return the values as a list.

    This is a plotting helper; element-wise it computes ``min(max(v, 0), 6)``.

    Args:
        x: a tensor, or anything ``torch.as_tensor`` accepts. Tensors that
            require grad or live on another device are detached and copied
            to the CPU first, instead of failing in ``.numpy()``.

    Returns:
        ``Tensor.tolist()`` of the clamped values. This is a flat list for
        1-D input and nested lists for higher-rank input. NaN inputs stay
        NaN; the old loop silently turned them into 6.
    """
    values = torch.as_tensor(x).detach().cpu()
    # A single vectorized clamp replaces the per-element Python loop.
    return torch.clamp(values, min=0.0, max=6.0).tolist()
def Curve(x):
    """Plot SiLU, Hardswish, Mish and ReLU6 over the sample points *x*.

    Both axes are limited to ``[x.min(), x.max()]``. To look at a single
    curve, remove the other entries from ``curves`` below.

    Args:
        x: 1-D tensor of sample points, e.g. ``torch.linspace(-6, 6, n)``.
    """
    # Plotting needs CPU values without autograd history.
    x = x.detach().cpu()

    # (label, color, y-values); ReLU6 already returns a plain list.
    curves = [
        ('SiLU', 'green', SiLU()(x)),
        ('Hardswish', 'blue', Hardswish()(x)),
        ('Mish', 'red', Mish()(x)),
        ('ReLU6', 'orange', ReLU6(x)),
    ]
    # MemoryEfficientMish()(x) computes the same function as Mish.
    # FReLU and AconC are built with a channel count and carry learnable
    # parameters, so calling them directly on x (as in ``FReLU(x)`` or
    # ``AconC(x)``) does not produce a curve.

    # ---------- axis limits ---------- #
    x_min, x_max = x.min().item(), x.max().item()
    plt.xlim(x_min, x_max)
    plt.ylim(x_min, x_max)  # y deliberately shares the x range
    # --------------------------------- #

    plt.title('Hardswish+ReLU6+SiLU+Mish')
    for label, color, y in curves:
        plt.plot(x, y, color=color, label=label)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.8)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
if __name__ == '__main__':
    # 10000 evenly spaced sample points covering [-6, 6].
    samples = torch.linspace(start=-6, end=6, steps=10000)
    Curve(samples)