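1. What is image object detection
2. How a model performs object detection
3. Introduction to deep learning object detection models
4. Training Faster RCNN in PyTorch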
1. What is image object detection
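Object detection: locate the objects in an image and determine what each of them is.
The two elements of object detection:
1. Classification: a class vector [p0, p1, …, pn]
2. Regression: a bounding box [x1, y1, x2, y2]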
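To make the two elements concrete, here is a minimal sketch of how one detection can be read off: the class is the index with the highest probability in the classification vector, and the box gives the top-left and bottom-right corners. The function name `describe_detection` is hypothetical, and it assumes index 0 is the background class, as in the COCO list used later.

```python
def describe_detection(probs, box, class_names):
    """Turn a class vector [p0, ..., pn] and a box [x1, y1, x2, y2] into a readable result.

    Hypothetical helper; index 0 of probs is assumed to be the background class.
    """
    # Classification: pick the class with the highest probability
    best = max(range(len(probs)), key=lambda i: probs[i])
    # Regression: the box is given by its top-left (x1, y1) and bottom-right (x2, y2) corners
    x1, y1, x2, y2 = box
    width, height = x2 - x1, y2 - y1
    return class_names[best], probs[best], (width, height)


print(describe_detection([0.1, 0.7, 0.2], [10, 20, 50, 100], ["__background__", "person", "car"]))
```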
2. How a model performs object detection
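The model maps the 3D input tensor (the image, C×H×W) to two tensors:
1. Classification tensor: shape [N, c+1]
2. Bounding-box tensor: shape [N, 4]
Here N is the number of predicted boxes and c is the number of object classes; the extra 1 is the background class.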
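The mapping to two tensors can be sketched as two fully connected "heads" applied to the same N feature vectors: one produces c+1 class probabilities per box, the other 4 box coordinates per box. This is a toy illustration in plain Python with random placeholder weights (all names are hypothetical); it follows the simplified [N, 4] view above, whereas torchvision's Faster RCNN box predictor actually outputs per-class box offsets of shape [N, 4·(c+1)].

```python
import math
import random


def linear(x, weight, bias):
    """Fully connected layer: x has length d, weight is [out][d], bias has length out."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_detection_head(features, num_classes, seed=0):
    """Map N feature vectors to a [N, c+1] probability tensor and a [N, 4] box tensor.

    The weights are random placeholders (hypothetical), standing in for trained ones.
    """
    rng = random.Random(seed)
    d = len(features[0])
    w_cls = [[rng.gauss(0, 0.01) for _ in range(d)] for _ in range(num_classes + 1)]
    w_box = [[rng.gauss(0, 0.01) for _ in range(d)] for _ in range(4)]
    b_cls, b_box = [0.0] * (num_classes + 1), [0.0] * 4
    cls = [softmax(linear(f, w_cls, b_cls)) for f in features]  # shape [N, c+1]
    boxes = [linear(f, w_box, b_box) for f in features]          # shape [N, 4]
    return cls, boxes


features = [[1.0] * 8 for _ in range(5)]  # N = 5 boxes, 8-dim features each
cls, boxes = toy_detection_head(features, num_classes=20)
print(len(cls), len(cls[0]), len(boxes), len(boxes[0]))
```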
Recommended reading:
"Recent Advances in Deep Learning for Object Detection" (2019)
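How is the number of bounding boxes N determined?
Traditional method: the sliding-window strategy
A very simple, primitive approach: slide a window over every position in the image, classify the window's content, and output the window as a detection if the predicted probability exceeds a threshold.
To cope with changes in object size, windows of several sizes are used (multi-scale sliding windows).
Drawbacks:
1. Heavy redundant computation (overlapping windows are processed again and again)
2. The window size is hard to choose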
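The strategy and its cost can be sketched in a few lines: every window at every scale is a separate classifier call, so N equals the number of windows, and overlapping windows re-process the same pixels. The `classify` callback, the window sizes and the stride are hypothetical placeholders, not values from the original notes.

```python
def sliding_windows(img_w, img_h, sizes, stride):
    """Yield square windows (x1, y1, x2, y2) of every size, moved across the image by `stride`."""
    for s in sizes:  # multi-scale: one full pass over the image per window size
        for y in range(0, img_h - s + 1, stride):
            for x in range(0, img_w - s + 1, stride):
                yield (x, y, x + s, y + s)


def sliding_window_detect(img_w, img_h, classify, sizes=(32, 64, 128), stride=16, thres=0.5):
    """Run the (hypothetical) classifier on every window; keep windows whose probability > thres."""
    detections = []
    for box in sliding_windows(img_w, img_h, sizes, stride):
        p = classify(box)  # one full classifier run per window: this is the repeated computation
        if p > thres:
            detections.append((box, p))
    return detections


# Toy classifier that "fires" on a single window only
fake_classify = lambda box: 1.0 if box == (16, 16, 48, 48) else 0.0
print(sliding_window_detect(64, 64, fake_classify, sizes=(32, 64)))
```

Even this 64×64 toy image needs 10 classifier calls for two window sizes; on a real image with more scales the count grows into the thousands.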
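Using convolution to solve the problems of the sliding-window strategy
Key concepts:
Each pixel of a feature map corresponds to a region of the original image (its receptive field).
The last layer is a classification layer, so each pixel of its output holds the class scores for the corresponding image region.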
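The "one pixel corresponds to a region" idea can be quantified with the standard receptive-field recurrence: each layer with kernel size k grows the region by (k − 1) times the current step between neighbouring pixels ("jump"), and each stride s multiplies that step. A convolutional network therefore evaluates windows of size rf at stride jump all at once, sharing the overlapping computation. The layer lists below are illustrative assumptions, not a specific backbone.

```python
def receptive_field(layers):
    """Compute the receptive field of one output pixel.

    layers: list of (kernel_size, stride) pairs, from input to output.
    Returns (rf, jump): the side length of the input region one output pixel sees,
    and the distance in input pixels between neighbouring output pixels.
    """
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump  # a k-wide kernel adds (k - 1) steps of the current jump
        jump *= s             # striding spreads neighbouring outputs further apart
    return rf, jump


print(receptive_field([(3, 1)] * 3))               # three 3x3 convs
print(receptive_field([(3, 1), (2, 2), (3, 1)]))   # conv, 2x2 pooling, conv
```

For example, the second stack gives each output pixel an 8×8 input region, with neighbouring output pixels 2 input pixels apart: effectively an 8×8 sliding window with stride 2.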
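Overview of object detection models (R-CNN, published in 2014, marks the dividing line between traditional and deep learning detectors)
"Object Detection in 20 Years: A Survey" (2019)
Faster RCNN established the two-stage framework and is a milestone.
By pipeline, object detection models fall into two categories: one-stage and two-stage.
The difference is whether the model has a proposal generation stage (a two-stage detector first generates candidate boxes, then classifies and refines them).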
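Two-stage: Faster RCNN
After the ROI layer, the features of every candidate box are pooled to the same fixed size (3×3 in this illustration), and the pooled feature maps of all candidate boxes are then stacked together.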
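The fixed-size pooling step can be sketched as RoI max pooling on a single-channel feature map: the box region is split into a 3×3 grid of bins and the maximum of each bin is kept, so boxes of any size yield the same 3×3 output, which is what allows them to be stacked. This is a simplified illustration with hypothetical names (integer box coordinates already on the feature map, bin edges rounded as floor/ceil); torchvision's Faster RCNN actually uses RoIAlign with bilinear sampling and a 7×7 output.

```python
def roi_max_pool(feature, box, out_size=3):
    """Max-pool the region box = (x1, y1, x2, y2) of a 2D feature map into out_size x out_size.

    Coordinates are integers on the feature map, x2/y2 exclusive. Bin starts are floored
    and bin ends ceiled, so every bin contains at least one pixel.
    """
    x1, y1, x2, y2 = box
    h, w = y2 - y1, x2 - x1
    pooled = []
    for i in range(out_size):
        ys = y1 + (i * h) // out_size
        ye = y1 + ((i + 1) * h + out_size - 1) // out_size
        row = []
        for j in range(out_size):
            xs = x1 + (j * w) // out_size
            xe = x1 + ((j + 1) * w + out_size - 1) // out_size
            row.append(max(feature[y][x] for y in range(ys, ye) for x in range(xs, xe)))
        pooled.append(row)
    return pooled


# 6x6 toy feature map whose value at (y, x) is y * 6 + x
feature = [[y * 6 + x for x in range(6)] for y in range(6)]
boxes = [(0, 0, 6, 6), (1, 1, 4, 4)]  # two candidate boxes of different sizes
stacked = [roi_max_pool(feature, b) for b in boxes]  # shape [N, 3, 3]
print(stacked)
```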
Faster RCNN data flow:
3. Introduction to deep learning object detection models
4. Training Faster RCNN in PyTorch
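For comparison with ordinary image classification: a classifier maps the 3D image tensor to a single class vector, whereas a detector maps it to two tensors:
1. Classification tensor: shape [N, c+1]
2. Bounding-box tensor: shape [N, 4]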
import os
import time
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# classes_coco: a label index predicted by the model maps to an entry of this list
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

if __name__ == "__main__":
    # path_img = os.path.join(BASE_DIR, "demo_img1.png")
    path_img = os.path.join(BASE_DIR, "demo_img2.png")

    # config: the detection model resizes and normalizes internally, so ToTensor() is enough
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    # pretrained=True is deprecated since torchvision 0.13; newer versions use weights="DEFAULT"
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # 2. preprocess: PIL image -> float tensor of shape [C, H, W] with values in [0, 1]
    img_chw = preprocess(input_image)

    # 3. to device (GPU if available, otherwise CPU)
    img_chw = img_chw.to(device)
    model.to(device)

    # 4. forward: the model takes a list of [C, H, W] tensors and returns one dict per image
    input_list = [img_chw]
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_list[0].shape))
        output_list = model(input_list)
        output_dict = output_list[0]
        print("pass: {:.3f}s".format(time.time() - tic))
    for k, v in output_dict.items():
        print("key:{}, value:{}".format(k, v))

    # 5. visualization
    out_boxes = output_dict["boxes"].cpu()
    out_scores = output_dict["scores"].cpu()
    out_labels = output_dict["labels"].cpu()

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(input_image, aspect='equal')

    num_boxes = out_boxes.shape[0]
    max_vis = 40
    thres = 0.5  # score threshold
    for idx in range(0, min(num_boxes, max_vis)):
        score = out_scores[idx].item()
        bbox = out_boxes[idx].numpy()
        class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx].item()]
        if score < thres:
            continue
        # Rectangle takes the top-left corner plus width and height, so convert from [x1, y1, x2, y2]
        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,
                                   edgecolor='red', linewidth=3.5))
        ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.show()
    plt.close()
# appendix
classes_pascal_voc = ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
Input:
input img tensor shape:torch.Size([3, 624, 1270])
pass: 13.840s
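Output (one dict per image; torchvision keeps at most 100 detections per image by default):
1. boxes: coordinates [x1, y1, x2, y2] of each box, in pixels of the input image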
key:boxes, value:tensor([[2.1388e+01, 4.0842e+02, 5.6309e+01, 5.3991e+02],
[2.7492e+02, 4.1662e+02, 3.1850e+02, 5.2794e+02],
[3.3163e+02, 5.0661e+02, 3.8220e+02, 6.2109e+02],
[1.0627e+03, 5.6293e+02, 1.1683e+03, 6.2370e+02],
[8.8008e+02, 5.0105e+02, 9.3214e+02, 6.2330e+02],
[2.9637e+02, 5.2668e+02, 3.4388e+02, 6.2200e+02],
[1.5381e+02, 3.9281e+02, 1.9044e+02, 4.7901e+02],
[5.2482e+02, 5.5501e+02, 5.9420e+02, 6.2310e+02],
[4.3968e+02, 4.7417e+02, 4.9715e+02, 6.1564e+02],
[9.6593e+02, 4.4682e+02, 1.0049e+03, 5.7216e+02],
[1.0311e+03, 4.7705e+02, 1.0740e+03, 6.1917e+02],
[7.1512e+02, 5.5515e+02, 7.6445e+02, 6.2252e+02],
[5.9517e+02, 5.6860e+02, 6.5844e+02, 6.2400e+02],
[1.9183e+02, 3.9135e+02, 2.1817e+02, 4.5500e+02],
[9.2349e+02, 4.2543e+02, 9.6887e+02, 5.4163e+02],
[8.4533e+02, 4.2685e+02, 8.8472e+02, 5.3358e+02],
[5.7697e-01, 3.6250e+02, 1.9273e+01, 4.2032e+02],
[7.8785e+02, 4.5489e+02, 8.3018e+02, 5.5748e+02],
[5.9754e+02, 4.3986e+02, 6.4339e+02, 5.7236e+02],
[7.5374e+02, 5.4186e+02, 8.4093e+02, 6.2392e+02],
[6.8067e+02, 5.3767e+02, 7.2788e+02, 6.2329e+02],
[1.0173e+03, 5.0085e+02, 1.0504e+03, 5.4632e+02],
[8.1195e+02, 4.2308e+02, 8.4457e+02, 5.3206e+02],
[1.1842e+03, 5.6938e+02, 1.2687e+03, 6.2373e+02],
[7.5443e+02, 3.9087e+02, 7.9018e+02, 4.9879e+02],
[5.3105e+02, 3.9078e+02, 5.6281e+02, 4.8724e+02],
[8.9093e+02, 4.1321e+02, 9.2183e+02, 5.0890e+02],
[9.0156e+02, 4.5113e+02, 9.1980e+02, 4.7044e+02],
[4.9166e+02, 4.8392e+02, 5.1212e+02, 5.2901e+02],
[7.1177e+02, 4.7322e+02, 7.4841e+02, 5.6356e+02],
[1.1422e+03, 4.1848e+02, 1.1851e+03, 5.2724e+02],
[1.1043e+03, 4.1396e+02, 1.1432e+03, 5.1554e+02],
[1.5299e+02, 4.0617e+02, 1.7880e+02, 4.3980e+02],
[9.6673e+02, 4.7046e+02, 9.9373e+02, 5.1017e+02],
[4.8157e+02, 5.2469e+02, 4.9799e+02, 5.6309e+02],
[3.3969e+02, 3.4014e+02, 3.6644e+02, 4.0998e+02],
[1.1228e+01, 3.0504e+02, 2.5386e+01, 3.4646e+02],
[5.6783e+02, 4.4703e+02, 6.0337e+02, 5.6745e+02],
[1.0672e+03, 4.0854e+02, 1.1082e+03, 5.1371e+02],
[7.0498e+02, 4.0970e+02, 7.3973e+02, 4.9432e+02],
[2.5139e+02, 3.2950e+02, 2.7345e+02, 3.9105e+02],
[1.1737e+03, 4.1148e+02, 1.2080e+03, 5.2497e+02],
[1.6692e+02, 2.8142e+02, 1.8108e+02, 3.1289e+02],
[3.0409e+02, 4.6405e+02, 3.1880e+02, 5.0083e+02],
[3.6005e+02, 3.9337e+02, 3.8882e+02, 4.8300e+02],
[1.3484e+02, 3.2502e+02, 1.5079e+02, 3.6977e+02],
[1.0540e+03, 5.0612e+02, 1.0702e+03, 5.3384e+02],
[1.0107e+03, 4.4346e+02, 1.0457e+03, 5.4820e+02],
[9.8965e+02, 3.7219e+02, 1.0162e+03, 4.4769e+02],
[8.3097e+02, 3.9531e+02, 8.5813e+02, 4.6470e+02],
[4.6970e+02, 4.5238e+02, 5.0836e+02, 5.7045e+02],
[1.5863e+02, 3.3730e+02, 1.7745e+02, 3.7866e+02],
[1.5346e+02, 4.0810e+02, 1.7282e+02, 4.4108e+02],
[5.5801e+02, 3.8345e+02, 5.8827e+02, 4.8299e+02],
[9.6700e+02, 4.7294e+02, 9.9018e+02, 5.0854e+02],
[3.8862e+02, 3.7365e+02, 4.2910e+02, 5.0485e+02],
[4.9743e+02, 3.8817e+02, 5.2440e+02, 4.8490e+02],
[6.0864e+01, 2.8248e+02, 7.6250e+01, 3.1515e+02],
[6.8809e+02, 4.9485e+02, 7.2735e+02, 5.5286e+02],
[6.5060e+02, 4.9366e+02, 7.0111e+02, 6.1884e+02],
[6.6512e+02, 4.5297e+02, 6.8753e+02, 4.8096e+02],
[2.0731e+02, 3.9242e+02, 2.2674e+02, 4.5651e+02],
[3.3188e+02, 3.0842e+02, 3.4601e+02, 3.4885e+02],
[6.0170e+01, 3.0998e+02, 7.7779e+01, 3.4435e+02],
[1.0159e+03, 4.9616e+02, 1.0568e+03, 5.4726e+02],
[1.1761e+03, 5.3929e+02, 1.2377e+03, 6.1400e+02],
[6.3865e+02, 4.2180e+02, 6.7230e+02, 5.2063e+02],
[4.6562e+02, 3.9251e+02, 4.8681e+02, 4.2763e+02],
[5.6223e+01, 3.1221e+02, 7.0839e+01, 3.4422e+02],
[4.3124e+02, 3.2679e+02, 4.6978e+02, 3.9345e+02],
[2.8342e+02, 3.0273e+02, 2.9939e+02, 3.3403e+02],
[1.7359e+02, 3.9106e+02, 1.9323e+02, 4.5638e+02],
[4.7852e+02, 4.7405e+02, 5.1008e+02, 5.2955e+02],
[6.6626e+02, 4.1746e+02, 7.0461e+02, 4.9124e+02],
[7.4063e+02, 4.8192e+02, 7.6984e+02, 5.5984e+02],
[4.1674e+02, 3.6833e+02, 4.4251e+02, 4.5964e+02],
[7.4635e+02, 3.5432e+02, 7.7786e+02, 4.1237e+02],
[8.7581e+02, 3.5277e+02, 8.9965e+02, 4.2111e+02],
[9.6405e+02, 3.4505e+02, 9.8138e+02, 3.9366e+02],
[5.9925e+02, 3.9358e+02, 6.2953e+02, 4.5713e+02],
[1.0104e+03, 4.4925e+02, 1.0562e+03, 6.0309e+02],
[1.9035e+02, 3.3621e+02, 2.1258e+02, 3.8854e+02],
[4.4140e+02, 5.9187e+02, 4.8636e+02, 6.2350e+02],
[1.2040e+03, 4.6967e+02, 1.2527e+03, 5.5478e+02],
[1.8247e+02, 2.8248e+02, 1.9429e+02, 3.0939e+02],
[3.0101e+02, 3.0192e+02, 3.2035e+02, 3.5139e+02],
[1.0764e+03, 4.9216e+02, 1.1278e+03, 5.8641e+02],
[1.0449e+03, 3.4934e+02, 1.0654e+03, 4.0475e+02],
[5.1159e+02, 3.8448e+02, 5.3611e+02, 4.8216e+02],
[3.1655e+02, 3.1366e+02, 3.3138e+02, 3.5103e+02],
[9.4753e+02, 3.4062e+02, 9.6557e+02, 3.9751e+02],
[5.8450e+02, 3.9833e+02, 6.1503e+02, 4.8090e+02],
[1.0924e+03, 3.8405e+02, 1.1197e+03, 4.3034e+02],
[7.4493e+02, 3.8010e+02, 7.5996e+02, 4.0455e+02],
[5.1159e+02, 3.4380e+02, 5.3484e+02, 3.9000e+02],
[1.1097e+03, 3.0709e+02, 1.1257e+03, 3.4568e+02],
[1.0731e+03, 4.9901e+02, 1.1604e+03, 6.1909e+02],
[6.6789e+02, 3.3650e+02, 6.8846e+02, 3.8195e+02],
[3.9529e+02, 4.2961e+02, 4.2192e+02, 4.6252e+02],
[9.1005e+02, 4.1943e+02, 9.3937e+02, 5.1179e+02]])
2. labels: class indices into COCO_INSTANCE_CATEGORY_NAMES (1 = person, 27 = backpack, 31 = handbag)
key:labels, value:tensor([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 31, 1, 1, 1, 1, 1, 31, 27, 1, 1, 1, 27, 27, 31, 1,
1, 1, 1, 1, 1, 1, 1, 31, 1, 1, 31, 1, 1, 1, 1, 1, 31, 1,
31, 1, 1, 1, 1, 1, 31, 1, 1, 1, 27, 1, 1, 31, 1, 1, 1, 1,
27, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 31, 1, 1, 1, 1, 31, 1])
3. scores: confidence of each detection, sorted in descending order
key:scores, value:tensor([0.9861, 0.9850, 0.9780, 0.9778, 0.9771, 0.9736, 0.9489, 0.9457, 0.9451,
0.9073, 0.8725, 0.8721, 0.8539, 0.8525, 0.8394, 0.8079, 0.7984, 0.7772,
0.7599, 0.7479, 0.7290, 0.7090, 0.6798, 0.6636, 0.6636, 0.6547, 0.6513,
0.6507, 0.6486, 0.6380, 0.6087, 0.6002, 0.5873, 0.5867, 0.5682, 0.5653,
0.5568, 0.5564, 0.5498, 0.5368, 0.5271, 0.5193, 0.5159, 0.4953, 0.4537,
0.4505, 0.4494, 0.4458, 0.4408, 0.4390, 0.4288, 0.4184, 0.4058, 0.3984,
0.3913, 0.3878, 0.3873, 0.3766, 0.3633, 0.3479, 0.3471, 0.3450, 0.3420,
0.3408, 0.3378, 0.3357, 0.3273, 0.3202, 0.3140, 0.3099, 0.2969, 0.2955,
0.2953, 0.2906, 0.2810, 0.2803, 0.2800, 0.2782, 0.2759, 0.2721, 0.2699,
0.2681, 0.2658, 0.2644, 0.2599, 0.2527, 0.2498, 0.2456, 0.2371, 0.2361,
0.2312, 0.2309, 0.2293, 0.2239, 0.2210, 0.2201, 0.2136, 0.2110, 0.2064,
0.1998])
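The visualization loop in the script keeps only the first `max_vis` detections whose score reaches `thres`. The same filtering can be sketched without plotting, as a per-class count of the kept detections; `summarize_detections` is a hypothetical helper and the sample values below are made up, not taken from the run above.

```python
from collections import Counter


def summarize_detections(labels, scores, class_names, thres=0.5, max_vis=40):
    """Count kept detections per class, mirroring the script's filtering.

    Only the first max_vis detections are considered (scores come sorted in
    descending order), and a detection is dropped if its score < thres.
    """
    kept = [class_names[label] for label, score in list(zip(labels, scores))[:max_vis]
            if score >= thres]
    return Counter(kept)


# Hypothetical class list that is only long enough for this example
names = ["n%d" % i for i in range(32)]
names[1], names[27], names[31] = "person", "backpack", "handbag"

print(summarize_detections([1, 1, 31, 27, 1], [0.98, 0.90, 0.70, 0.40, 0.30], names))
```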