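# YOLOv3 reimplementation: the Darknet53 backbone network structure (plus the three detection heads), written in a very basic, plain style.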
from torch import nn
from torch.nn import functional
import torch
class ConvolutionalLayers(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU, the basic Darknet building unit.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added on each spatial side.
        bias: whether the convolution has its own bias term. Defaults to
            False because the following BatchNorm2d already applies a
            learnable shift, which makes a conv bias redundant.
        negative_slope: LeakyReLU slope for negative inputs. Darknet/YOLOv3
            use 0.1; PyTorch's ``nn.LeakyReLU`` default of 0.01 would not
            match the reference network.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 bias=False, negative_slope=0.1):
        super(ConvolutionalLayers, self).__init__()
        self.sub_module = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        """Apply convolution, batch normalization and activation to ``x``."""
        return self.sub_module(x)
class Residual(nn.Module):
    """Darknet residual unit with an identity shortcut.

    The branch squeezes the input to half its channels with a 1x1 conv,
    expands back with a 3x3 conv (stride 1, padding 1), and the block input
    is added to the result, so the output shape equals the input shape.

    Args:
        in_channels: channel count of both the input and the output.
    """

    def __init__(self, in_channels):
        super(Residual, self).__init__()
        hidden = in_channels // 2
        squeeze = ConvolutionalLayers(in_channels, hidden, 1, 1, 0)
        expand = ConvolutionalLayers(hidden, in_channels, 3, 1, 1)
        self.sub_module = nn.Sequential(squeeze, expand)

    def forward(self, x):
        """Return ``x`` plus the output of the two-conv branch."""
        branch = self.sub_module(x)
        return x + branch
class Convolutional_Set(nn.Module):
    """Five alternating conv layers placed in front of each detection head.

    Even-indexed layers are 1x1 convs producing ``out_channels``; odd-indexed
    layers are 3x3 convs (stride 1, padding 1) producing ``in_channels``.
    The stack starts and ends with a 1x1, so the output has ``out_channels``
    channels and the input's spatial size.

    NOTE(review): the reference YOLOv3 config widens the 3x3 convs to
    ``2 * out_channels``. This version widens them back to ``in_channels``,
    which differs whenever ``in_channels != 2 * out_channels`` (e.g. the
    (768, 256) and (384, 128) sets this file builds). Kept unchanged for
    weight compatibility — confirm whether that is intended.

    Args:
        in_channels: channels of the input (and of the 3x3 layers).
        out_channels: channels of the 1x1 layers and of the final output.
    """

    def __init__(self, in_channels, out_channels):
        super(Convolutional_Set, self).__init__()
        layers = []
        for idx in range(5):
            if idx % 2 == 0:
                # 1x1: reduce to out_channels.
                layers.append(ConvolutionalLayers(in_channels, out_channels, 1, 1, 0))
            else:
                # 3x3: widen back to in_channels.
                layers.append(ConvolutionalLayers(out_channels, in_channels, 3, 1, 1))
        self.sub_module = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the five conv layers in order."""
        out = self.sub_module(x)
        return out
class UpSamplingLayers(nn.Module):
    """Parameter-free spatial upsampling (nearest-neighbour, x2 by default).

    Args:
        scale_factor: spatial multiplier forwarded to
            ``functional.interpolate``. Defaults to 2.
        mode: interpolation mode forwarded to ``functional.interpolate``.
            Defaults to ``'nearest'``.
    """

    def __init__(self, scale_factor=2, mode='nearest'):
        super(UpSamplingLayers, self).__init__()
        # Plain attributes (not parameters/buffers): they do not appear in
        # state_dict, so checkpoints are unaffected.
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Upsample the spatial dimensions of ``x`` by ``self.scale_factor``."""
        return functional.interpolate(x, scale_factor=self.scale_factor,
                                      mode=self.mode)
class Darknet53(nn.Module):
    """YOLOv3 network: Darknet53 backbone plus three detection heads.

    Despite the class name, this module is the full YOLOv3 model (backbone
    and heads), not only the backbone.

    For an input whose height and width are multiples of 32 (e.g. 416),
    ``forward`` returns three prediction maps at strides 32, 16 and 8
    (13x13, 26x26 and 52x52 for a 416x416 input), each with
    ``pred_channels`` channels. Other sizes can make an upsampled map
    disagree in size with its skip connection, and ``torch.cat`` then fails.

    Args:
        pred_channels: channels of each prediction map. The default of 24 is
            presumably ``3 anchors * (5 + 3 classes)`` — NOTE(review):
            confirm the intended anchor/class split.
    """

    def __init__(self, pred_channels=24):
        super(Darknet53, self).__init__()
        # Module registration order matches the original so state_dict key
        # order and parameter-initialisation RNG consumption are unchanged.

        # ---- Backbone ----------------------------------------------------
        # Stride 8 (three stride-2 convs), 256 channels.
        self.Residual_Block_52 = nn.Sequential(
            ConvolutionalLayers(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),
            ConvolutionalLayers(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            Residual(64),
            ConvolutionalLayers(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            *[Residual(128) for _ in range(2)],
            ConvolutionalLayers(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            *[Residual(256) for _ in range(8)],
        )
        # Stride 16, 512 channels.
        self.Residual_Block_26 = nn.Sequential(
            ConvolutionalLayers(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
            *[Residual(512) for _ in range(8)],
        )
        # Stride 32, 1024 channels.
        self.Residual_Block_13 = nn.Sequential(
            ConvolutionalLayers(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),
            *[Residual(1024) for _ in range(4)],
        )

        # ---- Head at stride 32 ------------------------------------------
        # NOTE: the "detetion" spelling is kept on purpose: attribute names
        # are part of state_dict keys, so renaming would break checkpoints.
        self.convset_13 = nn.Sequential(Convolutional_Set(1024, 512))
        self.detetion_13 = nn.Sequential(
            ConvolutionalLayers(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(1024, pred_channels, 1, 1, 0),
        )
        # NOTE(review): reference YOLOv3 uses a 1x1 conv before upsampling;
        # this 3x3 conv is kept unchanged for weight compatibility.
        self.up_13to26 = nn.Sequential(
            ConvolutionalLayers(512, 256, 3, 1, 1),
            UpSamplingLayers(),
        )

        # ---- Head at stride 16 (256 upsampled + 512 skip = 768 in) ------
        self.convset_26 = nn.Sequential(Convolutional_Set(768, 256))
        self.detetion_26 = nn.Sequential(
            ConvolutionalLayers(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, pred_channels, 1, 1, 0),
        )
        self.up_26to52 = nn.Sequential(
            ConvolutionalLayers(256, 128, 3, 1, 1),
            UpSamplingLayers(),
        )

        # ---- Head at stride 8 (128 upsampled + 256 skip = 384 in) -------
        self.convset_52 = nn.Sequential(Convolutional_Set(384, 128))
        self.detetion_52 = nn.Sequential(
            ConvolutionalLayers(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(256, pred_channels, 1, 1, 0),
        )

    def forward(self, x):
        """Run the network and return ``(out_13, out_26, out_52)``.

        The maps are ordered coarsest (stride 32) to finest (stride 8).
        """
        # Backbone; stride-8 and stride-16 features are reused as skips.
        feat_52 = self.Residual_Block_52(x)
        feat_26 = self.Residual_Block_26(feat_52)
        feat_13 = self.Residual_Block_13(feat_26)

        # Stride-32 prediction.
        set_13 = self.convset_13(feat_13)
        out_13 = self.detetion_13(set_13)

        # Upsample and concatenate with backbone features along channels.
        route_26 = torch.cat((self.up_13to26(set_13), feat_26), dim=1)
        set_26 = self.convset_26(route_26)
        out_26 = self.detetion_26(set_26)

        route_52 = torch.cat((self.up_26to52(set_26), feat_52), dim=1)
        set_52 = self.convset_52(route_52)
        out_52 = self.detetion_52(set_52)

        return out_13, out_26, out_52
if __name__ == '__main__':
    # Smoke test: push one random 3-channel 416x416 input through the network
    # and print the shape of each of the three prediction maps.
    model = Darknet53()
    dummy = torch.randn(1, 3, 416, 416)
    outputs = model(dummy)
    for out in outputs:
        print(out.shape)