Paper: https://arxiv.org/abs/1506.02640
Target format: 7 * 7 * 30. The first 20 values are the class scores, followed by
[box1_confidence, x, y, w, h, box2_confidence, x, y, w, h]
Remember to resize the input images to 448 * 448.
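To make that layout concrete, here is a minimal sketch (my own helper, not from the original post) that encodes one image's labels into a 7 * 7 * 30 target tensor. The function name encode_target and its argument format are assumptions; only the first box slot (indices 20:25) is filled, which is all the loss below reads from the target.

import torch

def encode_target(labels, S=7, C=20):
    # labels: list of (class_idx, x_center, y_center, w, h), all normalized to [0, 1]
    target = torch.zeros((S, S, C + 10))                 # 7 x 7 x 30
    for c, x, y, w, h in labels:
        i, j = int(S * y), int(S * x)                    # grid cell containing the box center
        x_cell, y_cell = S * x - j, S * y - i            # center offset within that cell
        if target[i, j, 20] == 0:                        # at most one object per cell
            target[i, j, int(c)] = 1                     # one-hot class
            target[i, j, 20] = 1                         # box1 confidence (objectness)
            target[i, j, 21:25] = torch.tensor([x_cell, y_cell, w * S, h * S])
    return target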
backbone
import torch
import torch.nn as nn
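# Each tuple is (kernel_size, out_channels, stride, padding); "M" is a 2x2 max-pool
# with stride 2; a list means [conv1, conv2, num_repeats].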
architecture_config = [
(7, 64, 2, 3),
"M",
(3, 192, 1, 1),
"M",
(1, 128, 1, 0),
(3, 256, 1, 1),
(1, 256, 1, 0),
(3, 512, 1, 1),
"M",
[(1, 256, 1, 0), (3, 512, 1, 1), 4],
(1, 512, 1, 0),
(3, 1024, 1, 1),
"M",
[(1, 512, 1, 0), (3, 1024, 1, 1), 2],
(3, 1024, 1, 1),
(3, 1024, 2, 1),
(3, 1024, 1, 1),
(3, 1024, 1, 1),
]
class CNN_BLOCK(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(CNN_BLOCK, self).__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.relu(self.bn(self.cnn(x)))
class Yolo_v1(nn.Module):
    def __init__(self, S=7, B=2, C=20):
        super(Yolo_v1, self).__init__()
        self.S = S
        self.B = B
        self.C = C
        self.cnn = self._create_cnn()
        self.fc = self._create_fc()

    def _create_cnn(self):
        in_channels = 3
        layers = []
        for layer in architecture_config:
            if type(layer) == str:
                layers += [
                    nn.MaxPool2d(kernel_size=2, stride=2)
                ]
            elif type(layer) == tuple:
                layers += [
                    CNN_BLOCK(in_channels=in_channels, out_channels=layer[1], kernel_size=layer[0], stride=layer[2],
                              padding=layer[3])
                ]
                in_channels = layer[1]
            else:
                conv1 = layer[0]
                conv2 = layer[1]
                for _ in range(layer[-1]):
                    layers += [
                        CNN_BLOCK(in_channels=in_channels, out_channels=conv1[1], kernel_size=conv1[0], stride=conv1[2],
                                  padding=conv1[3])
                    ]
                    in_channels = conv1[1]
                    layers += [
                        CNN_BLOCK(in_channels=in_channels, out_channels=conv2[1], kernel_size=conv2[0], stride=conv2[2],
                                  padding=conv2[3])
                    ]
                    in_channels = conv2[1]
        return nn.Sequential(*layers)

    def _create_fc(self):
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024 * self.S * self.S, 4096),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, self.S * self.S * (self.C + self.B * 5))
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)
        return x
Input: batch * 3 * 448 * 448; output: batch * 7 * 7 * 30 (the fc head returns it as a flat batch * 1470 vector, which the loss reshapes back to 7 * 7 * 30).
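A quick sanity check of the shapes (my own snippet, not from the post):

model = Yolo_v1(S=7, B=2, C=20)
x = torch.randn(2, 3, 448, 448)
out = model(x)
print(out.shape)                          # torch.Size([2, 1470]), 1470 = 7 * 7 * 30
print(out.reshape(-1, 7, 7, 30).shape)    # torch.Size([2, 7, 7, 30])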
The loss follows the paper.
import torch
import torch.nn as nn
from utils import intersection_over_union
class Yolo_v1_loss(nn.Module):
    def __init__(self, S=7, B=2, C=20):
        super(Yolo_v1_loss, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.S = S
        self.B = B
        self.C = C
        self.coord = 5      # lambda_coord in the paper
        self.noobj = 0.5    # lambda_noobj in the paper
    def forward(self, predict, target):
        # predict: N, (S * S * (B * 5 + C))
        # target:  N, S, S, (C + B * 5)
        predict = predict.reshape(-1, self.S, self.S, self.B * 5 + self.C)
        # IoU between the first predicted box of each cell and the ground-truth box
        iou_b1 = intersection_over_union(predict[..., 21:25], target[..., 21:25])
        # IoU between the second predicted box of each cell and the ground-truth box
        iou_b2 = intersection_over_union(predict[..., 26:30], target[..., 21:25])
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)
        # best_box_idx is 0 if box1 has the higher IoU, 1 if box2 does
        iou_max, best_box_idx = torch.max(ious, dim=0)
        # 1 where the cell contains an object, 0 otherwise
        exists_box = target[..., 20:21]
        # ==================== #
        #  compute the loss    #
        # ==================== #
        # bounding-box loss
        # pick the box with the better IoU in each cell
        box_predict = exists_box * (best_box_idx * predict[..., 26:30] + (1 - best_box_idx) * predict[..., 21:25])
        box_target = exists_box * target[..., 21:25]
        # as in the paper's coord term, take the square root of w and h
        box_predict[..., 2:4] = torch.sign(box_predict[..., 2:4]) * torch.sqrt(torch.abs(box_predict[..., 2:4]) + 1e-6)
        box_target[..., 2:4] = torch.sqrt(box_target[..., 2:4])
        box_loss = self.mse(
            # (N * S * S, 4)
            torch.flatten(box_predict, end_dim=-2),
            torch.flatten(box_target, end_dim=-2)
        )
        # confidence loss
        # confidence of the responsible (better-IoU) box
        confidence_predict = best_box_idx * predict[..., 25:26] + (1 - best_box_idx) * predict[..., 20:21]
        confidence_target = target[..., 20:21]
        # only cells that contain an object push their confidence towards the target; the
        # far more numerous no-object cells are down-weighted by noobj so they do not dominate
        confidence_loss = self.mse(
            torch.flatten(exists_box * confidence_predict),
            torch.flatten(exists_box * confidence_target)
        )
        no_confidence_loss = self.mse(
            torch.flatten((1 - exists_box) * predict[..., 20:21]),
            torch.flatten((1 - exists_box) * confidence_target)
        )
        no_confidence_loss += self.mse(
            torch.flatten((1 - exists_box) * predict[..., 25:26]),
            torch.flatten((1 - exists_box) * confidence_target)
        )
        # class loss
        class_loss = self.mse(
            # (N * S * S, 20)
            torch.flatten(exists_box * predict[..., :20], end_dim=-2),
            torch.flatten(exists_box * target[..., :20], end_dim=-2)
        )
        # weighted sum, as in the paper
        loss = self.coord * box_loss + confidence_loss + self.noobj * no_confidence_loss + class_loss
        return loss
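A small smoke test for the loss (my own snippet with made-up shapes and values, not from the post): one object is placed in cell (3, 3) of the first image, and the random prediction only needs the right flat shape.

criterion = Yolo_v1_loss()
predict = torch.randn(2, 7 * 7 * 30)                        # raw, flat network output
target = torch.zeros(2, 7, 7, 30)                           # encoded ground truth
target[0, 3, 3, 6] = 1                                      # class "car"
target[0, 3, 3, 20] = 1                                     # objectness for box1
target[0, 3, 3, 21:25] = torch.tensor([0.5, 0.5, 1.0, 1.0])
loss = criterion(predict, target)
print(loss.item())                                          # one scalar, summed over the batch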
utils
import xml.etree.ElementTree as ET
import os
import os.path
import numpy as np
import torch
import matplotlib.pyplot as plt  # plotting
import cv2 as cv

class_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
              'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
              'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
# map class name -> index
class_dict = {name: i for i, name in enumerate(class_list)}
def parse_xml():
    xml_path = '../../VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/Annotations/'
    xml_file = os.listdir(xml_path)
    if not os.path.exists('labels'):
        os.makedirs('labels')
    for file in xml_file:
        with open('labels/' + file.replace('.xml', '.txt'), 'w') as f:
            root = ET.parse(xml_path + file).getroot()
            width = float(root.find('size/width').text)
            height = float(root.find('size/height').text)
            for child in root.findall('object'):
                # class index
                c = child.find('name').text
                c = class_dict[c]
                xmin = float(child.find('bndbox').find('xmin').text)
                ymin = float(child.find('bndbox').find('ymin').text)
                xmax = float(child.find('bndbox').find('xmax').text)
                ymax = float(child.find('bndbox').find('ymax').text)
                # convert to normalized (x_center, y_center, w, h)
                x_center = (xmin + xmax) / (2 * width)
                y_center = (ymin + ymax) / (2 * height)
                w = (xmax - xmin) / width
                h = (ymax - ymin) / height
                f.write(' '.join([str(c), str(x_center), str(y_center), str(w), str(h)]) + '\n')
def intersection_over_union(box1, box2, mode='center'):
    if mode == 'center':
        # boxes are (x_center, y_center, w, h); convert to (xmin, ymin, xmax, ymax)
        box1_x1 = box1[..., 0:1] - box1[..., 2:3] / 2
        box1_y1 = box1[..., 1:2] - box1[..., 3:4] / 2
        box1_x2 = box1[..., 0:1] + box1[..., 2:3] / 2
        box1_y2 = box1[..., 1:2] + box1[..., 3:4] / 2
        box2_x1 = box2[..., 0:1] - box2[..., 2:3] / 2
        box2_y1 = box2[..., 1:2] - box2[..., 3:4] / 2
        box2_x2 = box2[..., 0:1] + box2[..., 2:3] / 2
        box2_y2 = box2[..., 1:2] + box2[..., 3:4] / 2
    else:
        # boxes are already (xmin, ymin, xmax, ymax)
        box1_x1 = box1[..., 0:1]
        box1_y1 = box1[..., 1:2]
        box1_x2 = box1[..., 2:3]
        box1_y2 = box1[..., 3:4]
        box2_x1 = box2[..., 0:1]
        box2_y1 = box2[..., 1:2]
        box2_x2 = box2[..., 2:3]
        box2_y2 = box2[..., 3:4]
    # intersection area
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    intersection_area = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    # union area
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    return intersection_area / (box1_area + box2_area - intersection_area + 1e-6)
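# Quick sanity check (made-up boxes, not from the post): identical boxes in center format
# give an IoU of roughly 1 (the 1e-6 keeps it just below), disjoint boxes give 0.
#   intersection_over_union(torch.tensor([0.5, 0.5, 0.2, 0.2]),
#                           torch.tensor([0.5, 0.5, 0.2, 0.2]))   # -> approx. tensor([1.])
#   intersection_over_union(torch.tensor([0.2, 0.2, 0.1, 0.1]),
#                           torch.tensor([0.8, 0.8, 0.1, 0.1]))   # -> tensor([0.])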
def non_max_suppression(bboxes, iou_threshold=0.5, threshold=0.4):
    # bboxes: [[class, confidence, x, y, w, h], ...]
    bboxes = [box for box in bboxes if box[1] > threshold]
    bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
    bboxes_nms = []
    while bboxes:
        chosen_box = bboxes.pop(0)
        # keep a box only if its class differs from the chosen box, or its IoU with the
        # chosen box is below the threshold, i.e. the two boxes do not predict the same object
        bboxes = [
            box for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(torch.tensor(chosen_box[2:6]), torch.tensor(box[2:6])) < iou_threshold
        ]
        bboxes_nms.append(chosen_box)
    return bboxes_nms
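# Usage sketch (made-up boxes, not from the post): two heavily-overlapping "car" (class 6)
# boxes and one "dog" (class 11) box; NMS keeps the higher-confidence car box and the dog box.
#   boxes = [[6, 0.9, 0.50, 0.5, 0.4, 0.4],
#            [6, 0.6, 0.52, 0.5, 0.4, 0.4],
#            [11, 0.8, 0.20, 0.2, 0.2, 0.2]]
#   non_max_suppression(boxes, iou_threshold=0.5, threshold=0.4)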
def plot_box(boxes, img):
    H, W = img.shape[:2]
    plt.imshow(img)
    current_axis = plt.gca()
    for bbox in boxes:
        classes = bbox[0]
        confidence = round(bbox[1].item(), 2)
        x = bbox[2]
        y = bbox[3]
        w = bbox[4]
        h = bbox[5]
        # boxes are normalized to [0, 1]; scale back to pixel coordinates
        xmin = (x - w / 2) * W
        xmax = (x + w / 2) * W
        ymin = (y - h / 2) * H
        ymax = (y + h / 2) * H
        current_axis.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color='green', fill=False, linewidth=2))
        current_axis.text(xmin, ymin, class_list[int(classes)] + ': {}'.format(confidence),
                          color='white', bbox={'facecolor': 'green', 'alpha': 1.0})
    plt.show()
def get_boxes(pre, S=7):
    # pre.shape == 1 * 7 * 7 * 30; returns [[class, confidence, x, y, w, h], ...] after NMS
    cell_indices = torch.arange(7).repeat(1, 7, 1).unsqueeze(-1)
    # convert the cell-relative x, y offsets to coordinates relative to the whole image
    pre[..., 21:22] = (pre[..., 21:22] + cell_indices) / S
    pre[..., 26:27] = (pre[..., 26:27] + cell_indices) / S
    pre[..., 22:23] = (pre[..., 22:23] + cell_indices.permute(0, 2, 1, 3)) / S
    pre[..., 27:28] = (pre[..., 27:28] + cell_indices.permute(0, 2, 1, 3)) / S
    # w and h are predicted relative to the cell size; divide by S to normalize to the image
    pre[..., 23:25] = pre[..., 23:25] / S
    pre[..., 28:30] = pre[..., 28:30] / S
    pre = pre.reshape(7, 7, 30)
    classes = torch.max(pre[..., :20], dim=-1).indices.unsqueeze(-1)
    box1 = pre[..., 21:25]
    box2 = pre[..., 26:30]
    confidence1 = pre[..., 20:21]
    confidence2 = pre[..., 25:26]
    # stack both predicted boxes of every cell into (7 * 7 * 2, 6) rows of [class, confidence, x, y, w, h]
    new_box = torch.zeros((7 * 7 * 2, 6))
    new_box[:49, 2:6] = torch.flatten(box1, end_dim=-2)
    new_box[49:, 2:6] = torch.flatten(box2, end_dim=-2)
    new_box[:49, 0:1] = torch.flatten(classes, end_dim=-2)
    new_box[49:, 0:1] = torch.flatten(classes, end_dim=-2)
    new_box[:49, 1:2] = torch.flatten(confidence1, end_dim=-2)
    new_box[49:, 1:2] = torch.flatten(confidence2, end_dim=-2)
    return non_max_suppression(new_box)
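To tie the pieces together, here is a rough inference sketch (my own glue code, not from the post). It assumes the model and the utils above are importable; the detect helper, the checkpoint name 'yolo_v1.pth' and the image path 'test.jpg' are placeholders.

def detect(model, img_path, device='cpu'):
    # load the image, convert BGR -> RGB, resize to 448 x 448 and scale to [0, 1]
    img = cv.cvtColor(cv.imread(img_path), cv.COLOR_BGR2RGB)
    img = cv.resize(img, (448, 448))
    x = torch.from_numpy(img).float().div(255).permute(2, 0, 1).unsqueeze(0).to(device)  # 1 x 3 x 448 x 448
    model.eval()
    with torch.no_grad():
        pre = model(x).reshape(1, 7, 7, 30).cpu()
    boxes = get_boxes(pre)   # decode the grid and run NMS -> [[class, confidence, x, y, w, h], ...]
    plot_box(boxes, img)

# model = Yolo_v1()
# model.load_state_dict(torch.load('yolo_v1.pth', map_location='cpu'))
# detect(model, 'test.jpg')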
I don't know how long training will take; I rented a server and it has been running the whole time.