class YoloBody(nn.Module):
    def __init__(self, num_anchors, num_classes):
        super(YoloBody, self).__init__()
        # backbone
        self.backbone = darknet53(None)

        self.conv1 = make_three_conv([512, 1024], 1024)
        self.SPP = SpatialPyramidPooling()
        self.conv2 = make_three_conv([512, 1024], 2048)

        self.upsample1 = Upsample(512, 256)
        self.conv_for_P4 = conv2d(512, 256, 1)
        self.make_five_conv1 = make_five_conv([256, 512], 512)

        self.upsample2 = Upsample(256, 128)
        self.conv_for_P3 = conv2d(256, 128, 1)
        self.make_five_conv2 = make_five_conv([128, 256], 256)

        # num_anchors*(4+1+num_classes) = 3*(5+20) = 75
        final_out_filter2 = num_anchors * (5 + num_classes)
        self.yolo_head3 = yolo_head([256, final_out_filter2], 128)

        self.down_sample1 = conv2d(128, 256, 3, stride=2)
        self.make_five_conv3 = make_five_conv([256, 512], 512)

        # num_anchors*(4+1+num_classes) = 3*(5+20) = 75
        final_out_filter1 = num_anchors * (5 + num_classes)
        self.yolo_head2 = yolo_head([512, final_out_filter1], 256)

        self.down_sample2 = conv2d(256, 512, 3, stride=2)
        self.make_five_conv4 = make_five_conv([512, 1024], 1024)

        # num_anchors*(4+1+num_classes) = 3*(5+20) = 75
        final_out_filter0 = num_anchors * (5 + num_classes)
        self.yolo_head1 = yolo_head([1024, final_out_filter0], 512)

    def forward(self, x):
        # backbone
        x2, x1, x0 = self.backbone(x)

        P5 = self.conv1(x0)
        P5 = self.SPP(P5)
        P5 = self.conv2(P5)

        P5_upsample = self.upsample1(P5)
        P4 = self.conv_for_P4(x1)
        P4 = torch.cat([P4, P5_upsample], dim=1)
        P4 = self.make_five_conv1(P4)

        P4_upsample = self.upsample2(P4)
        P3 = self.conv_for_P3(x2)
        P3 = torch.cat([P3, P4_upsample], dim=1)
        P3 = self.make_five_conv2(P3)

        P3_downsample = self.down_sample1(P3)
        P4 = torch.cat([P3_downsample, P4], dim=1)
        P4 = self.make_five_conv3(P4)

        P4_downsample = self.down_sample2(P4)
        P5 = torch.cat([P4_downsample, P5], dim=1)
        P5 = self.make_five_conv4(P5)

        out2 = self.yolo_head3(P3)
        out1 = self.yolo_head2(P4)
        out0 = self.yolo_head1(P5)

        return out0, out1, out2
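To sanity-check the head wiring, here is a minimal smoke test. It is only a sketch: it assumes the helper blocks used above (darknet53, make_three_conv, SpatialPyramidPooling, conv2d, Upsample, make_five_conv, yolo_head) are importable from the same repo, and that the input is 416x416 so the backbone strides are 8/16/32.

# Minimal smoke test for YoloBody (illustrative only; see assumptions above).
import torch

model = YoloBody(num_anchors=3, num_classes=20).eval()
with torch.no_grad():
    out0, out1, out2 = model(torch.zeros(1, 3, 416, 416))

# each head outputs 3 * (4 + 1 + 20) = 75 channels
print(out0.shape)  # expected: torch.Size([1, 75, 13, 13])  stride 32
print(out1.shape)  # expected: torch.Size([1, 75, 26, 26])  stride 16
print(out2.shape)  # expected: torch.Size([1, 75, 52, 52])  stride 8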
class YoloV6(nn.Module):
    def __init__(self, cfg, ch=3):
        super(YoloV6, self).__init__()
        with open(cfg) as f:
            self.md = yaml.load(f, Loader=yaml.FullLoader)
        self.nc = self.md['nc']
        self.anchors = self.md['anchors']
        self.na = len(self.anchors[0]) // 2  # number of anchors per layer

        # backbone: tap the layer2/3/4 feature maps of a resnet18 (strides 8/16/32)
        self.backbone = resnet18()
        self.backbone = IntermediateLayerGetter(
            self.backbone, {'layer2': 0, 'layer3': 1, 'layer4': 2})

        # FPN
        in_channels = [128, 256, 512]
        self.fpn = FeaturePyramidNetwork(in_channels_list=in_channels, out_channels=256)
        self.detect = Detect(self.nc, self.anchors, [256, 256, 256])

        # run a dummy forward pass to compute the Detect layer strides dynamically
        s = 512  # dummy input size
        self.detect.stride = torch.tensor(
            [s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])
        self.detect.anchors /= self.detect.stride.view(-1, 1, 1)
        check_anchor_order(self.detect)
        self.stride = self.detect.stride
        self._initialize_biases()
        initialize_weights(self)

        # for compatibility with check_anchors in train.py
        self.model = [self.detect]

    def _initialize_biases(self, cf=None):
        m = self.detect  # Detect() module
        for mi, s in zip(m.m, m.stride):  # one output conv per detection layer
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            # obj (assume ~8 objects per 640x640 image)
            b[:, 4] += math.log(8 / (640 / s) ** 2)
            # cls
            b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def forward(self, x, augment=False):
        # augment is not used here
        feas = self.backbone(x)
        # the FPN returns an OrderedDict keyed by the ids given to IntermediateLayerGetter (0, 1, 2)
        a = self.fpn(feas)
        x_s, x_m, x_l = a[0], a[1], a[2]
        x = self.detect([x_s, x_m, x_l])
        return x
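The cfg file consumed by YoloV6 only needs two keys, nc and anchors. Below is a minimal sketch: the anchor values are the standard COCO anchors and are purely illustrative, and check_anchor_order / initialize_weights are assumed to be importable from the usual YOLOv5-style utils.

# Minimal cfg for YoloV6 (illustrative values).
cfg_text = """
nc: 80
anchors:
  - [10,13, 16,30, 33,23]      # P3/8
  - [30,61, 62,45, 59,119]     # P4/16
  - [116,90, 156,198, 373,326] # P5/32
"""
with open('yolov6_min.yaml', 'w') as f:
    f.write(cfg_text)

model = YoloV6('yolov6_min.yaml')
print(model.stride)  # expected: tensor([ 8., 16., 32.]) from the resnet18 layer2/3/4 taps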
Detect implementation
class Detect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        # pixel-space anchors kept separately for the export path; moved to the input device in forward
        self.anchor_grid_awesome = torch.tensor(anchors).float().view(self.nl, 1, -1, 1, 1, 2)

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        # self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            if self.export:
                print('exporting...')
                bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
                x_i = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            else:
                bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
                x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training or self.export:  # inference
                if self.export:
                    self.grid = [g.to(x_i.device) for g in self.grid]
                    self.a = self.anchor_grid_awesome[i].to(x_i.device)
                    if self.grid[i].shape[2:4] != x_i.shape[2:4]:
                        self.grid[i] = self._make_grid(nx, ny).to(x_i.device)
                    y = x_i.sigmoid()
                    print('[WARN] you are calling export...')
                    # split into center xy, wh, objectness and class scores
                    xy, wh, conf, prob = torch.split(y, [2, 2, 1, self.nc], dim=4)
                    xy = ((xy * 2. - 0.5 + self.grid[i]) * self.stride[i]).type(x_i.dtype)
                    wh = (wh * 2) ** 2 * self.a
                    xywh = torch.cat((xy, wh), dim=4)
                    # add the argmax class id before the per-class probabilities
                    idxs = torch.argmax(prob, dim=-1).unsqueeze(dim=-1).type(x_i.dtype)
                    y = torch.cat((xywh, conf, idxs, prob), dim=4)
                    # we added idxs, so the last dim is no+1
                    z.append(y.view(bs, -1, self.no + 1))
                else:
                    if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                        self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                    y = x[i].sigmoid()
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    z.append(y.view(bs, -1, self.no))

        if self.training:
            return x
        elif self.export:
            return torch.cat(z, 1)
        else:
            return (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
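To make the decode used in the eager (non-export) branch concrete, here is a small numeric check of the two formulas above. The values are illustrative: a zero logit puts the box center in the middle of its grid cell and reproduces the anchor size exactly.

# Numeric check of the decode formulas used in Detect.forward.
import torch

stride = 32.0                               # P5 head
anchor_wh = torch.tensor([116., 90.])       # one P5 anchor, in pixels
grid_xy = torch.tensor([6., 6.])            # cell (6, 6) of a 13x13 grid

raw = torch.zeros(4)                        # raw network output (logits)
y = raw.sigmoid()                           # -> [0.5, 0.5, 0.5, 0.5]

xy = (y[:2] * 2. - 0.5 + grid_xy) * stride  # -> [208., 208.] box center in pixels
wh = (y[2:] * 2) ** 2 * anchor_wh           # -> [116.,  90.] box size in pixels
print(xy, wh)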