感谢 @zcl1122指出的倒数第三节代码中的i错误的被转行成大写的I的问题。
上一节粗略的描述了如何关于图像识别,抠图,分类的理论相关,本节主要用代码,来和大家一起分析每一步骤。
看完本节,希望你也能独立完成自己的图片、视频的内容实时定位。
首先,我们需要安装TensorFlow环境,建议利用conda进行安装,配置,90%尝试单独安装的人最后都挂了。
其次,我们需要安装从git上下载训练好的模型,git clone https://github.com/balancap/SSD-Tensorflow
如果没有安装git的朋友,请自行百度安装。
最后找到你下载的位置进行解压,unzip ./SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt.zip
这边务必注意,网上90%的教程这边就结束了,其实你这样是最后跑不通代码的,你需要把解压的文件进行移动到checkpoint的文件夹下面,这个问题git上这个同学解释了,详细的去看下https://github.com/balancap/SSD-Tensorflow/issues/150
最后的最后,下载你需要检测的网路图片,就ok了
预处理步骤完成了,下面让我们看代码。
加载相关的包:
import os
import math
import random
import sys
import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as mpcm
sys.path.append('./SSD-Tensorflow/')
from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing
配置相关TensorFlow环境
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)
做图片的格式的处理,使他满足input的条件
#我们用的TensorFlow下的一个集成包slim,比tensor要更加轻便
slim = tf.contrib.slim
#训练数据中包含了一下已知的类别,也就是我们可以识别出以下的东西,不过后续我们将自己自己训练自己的模型,来识别自己想识别的东西
l_VOC_CLASS = [
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
'bus', 'car', 'cat', 'chair', 'cow',
'diningTable', 'dog', 'horse', 'motorbike', 'person',
'pottedPlant', 'sheep', 'sofa', 'train', 'TV'
]
# 定义数据格式
net_shape = (300, 300)
data_format = 'NHWC' # [Number, height, width, color],Tensorflow backend 的格式
# 预处理将输入图片大小改成 300x300,作为下一步输入
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
img_input,
None,
None,
net_shape,
data_format,
resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE
)
image_4d = tf.expand_dims(image_pre, 0)
下面我们来载入SSD作者已经搞定的模型
# 定义 SSD 模型结构
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
# 导入官方给出的 SSD 模型参数
#这边修改成你自己的路径
ckpt_filename = '/Users/slade/SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)
ssd_anchors = ssd_net.anchors(net_shape)
下面让我们把SSD识别出来的结果在图片中表示出来
#不同类别,我们以不同的颜色表示
def colors_subselect(colors, num_classes=21):
dt = len(colors) // num_classes
sub_colors = []
for i in range(num_classes):
color = colors[i*dt]
if isinstance(color[0], float):
sub_colors.append([int(c * 255) for c in color])
else:
sub_colors.append([c for c in color])
return sub_colors
#画出在图中的位置
def bboxes_draw_on_img(img, classes, scores, bboxes, colors, thickness=5):
shape = img.shape
for i in range(bboxes.shape[0]):
bbox = bboxes[i]
color = colors[classes[i]]
# Draw bounding box...
p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
# Draw text...
s = '%s:%.3f' % ( l_VOC_CLASS[int(classes[i])-1], scores[i])
p1 = (p1[0]-5, p1[1])
cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)
让我们开始训练吧
def process_image(img, select_threshold=0.3, nms_threshold=.8, net_shape=(300, 300)):
#先获取SSD网络的层相关的参数
rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
feed_dict={img_input: img})
#获取分类结果,位置
rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
rpredictions, rlocalisations, ssd_anchors,
select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
# 让我们在图中画出来就行了
rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
bboxes_draw_on_img(img, rclasses, rscores, rbboxes, colors_plasma, thickness=2)
return img
预处理的函数都写完了,我们就可以执行了。
#读取数据
img = cv2.imread("/Users/slade/Documents/Yoho/picture_recognize/test7.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(process_image(img))
plt.show()
img的数据形式如下:
In [8]: img
Out[8]:
array([[[ 35, 59, 43],
[ 37, 60, 44],
[ 38, 61, 45],
...,
[ 73, 99, 62],
[ 74, 99, 60],
[ 72, 97, 57]],
[[ 37, 60, 44],
[ 37, 60, 44],
[ 37, 60, 44],
...,
[ 66, 92, 57],
[ 67, 93, 56],
[ 67, 92, 53]],
[[ 37, 60, 44],
[ 36, 59, 43],
[ 37, 58, 43],
...,
[ 56, 83, 48],
[ 60, 86, 51],
[ 61, 87, 50]],
...,
[[ 96, 101, 95],
[107, 109, 104],
[ 98, 97, 95],
...,
[ 84, 126, 76],
[ 72, 118, 72],
[ 78, 126, 86]],
[[ 98, 103, 96],
[114, 116, 111],
[112, 113, 108],
...,
[ 94, 137, 84],
[ 87, 133, 86],
[105, 153, 111]],
[[ 99, 105, 95],
[110, 113, 106],
[134, 135, 129],
...,
[127, 170, 116],
[121, 167, 118],
[131, 180, 135]]], dtype=uint8)
处理后的结果如下:
是不是非常无脑,上面的代码直接复制就可以完成。
下面在拓展一下视频的处理方式,其实相关的内容是一致的。
利用moviepy.editor包里面的VideoFileClip的切片的功能,然后对每一次切片的结果进行process_image过程就可以了,这边就不贴代码了,需要的朋友私密我。
最后感谢大家阅读。
欢迎大家关注我的个人bolg,更多代码内容欢迎follow我的个人Github,如果有任何算法、代码疑问都欢迎通过公众号发消息给我哦。