SSD MobileNet v1 的实现
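主要参考 https://www.cnblogs.com/White-xzx/p/9503203.html 完成。遇到的问题出在训练之后的检测阶段：代码能跑通，但结果图片上没有出现检测框。结合终端输出，我认为是训练参数的问题。另外需要注意：若不显式传入 `min_score_thresh`，`visualize_boxes_and_labels_on_image_array` 默认只画出得分高于 0.5 的检测框，因此即使终端按 0.24 的阈值统计到了目标，图上也可能一个框都没有。原博客的检测代码中有一点小问题，已订正如下：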
# -*- coding: utf-8 -*-
import os
from PIL import Image
import time
import tensorflow as tf
from PIL import Image
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import zipfile
import time
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
# plt.switch_backend('Agg')
from utils import label_map_util
from utils import visualization_utils as vis_util
# Input images to run detection on, and where the annotated results go
# (absolute Windows paths on the author's machine; adjust to your layout).
PATH_TO_TEST_IMAGES = "D:\\ssd_mobilenetv1\\test_images\\"
PATH_TO_RESULTS = "D:\\ssd_mobilenetv1\\results\\"

# Despite its name, MODEL_NAME is the data directory that holds the exported
# model and the label map; both paths below are resolved relative to it.
MODEL_NAME = 'D:/ssd_mobilenetv1/data'
# Frozen inference graph (.pb) produced by the export step.
PATH_TO_CKPT = '/'.join((MODEL_NAME, 'exported_model_directory', 'frozen_inference_graph.pb'))
PATH_TO_LABELS = '/'.join((MODEL_NAME, 'label_map.pbtxt'))

# Number of classes; passed as max_num_classes when reading the label map.
NUM_CLASSES = 1
def load_image_into_numpy_array(image):
    """Return a PIL image as a writable ``(height, width, 3)`` uint8 RGB array.

    Images in any mode other than 'RGB' (e.g. grayscale 'L', 'RGBA',
    palette 'P') are converted to RGB first. Reshaping their raw pixel data
    straight to three channels would raise ValueError.

    Args:
        image: a PIL image; ``image.size`` is ``(width, height)``.

    Returns:
        numpy.ndarray of dtype uint8 with shape ``(height, width, 3)``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    (im_width, im_height) = image.size
    # np.array copies the pixels in C, which is much faster than building a
    # Python list via getdata(). Unlike np.asarray, the copy is writable,
    # which the detection script needs because boxes are drawn onto this
    # array afterwards.
    return np.array(image, dtype=np.uint8).reshape((im_height, im_width, 3))
def _load_frozen_graph(pb_path):
    """Return a new ``tf.Graph`` populated from the frozen GraphDef at *pb_path*."""
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_path, 'rb') as fid:
            graph_def.ParseFromString(fid.read())
        tf.import_graph_def(graph_def, name='')
    return graph


def _load_category_index(labels_path, num_classes):
    """Build the class-id -> category mapping used to name detections."""
    label_map = label_map_util.load_labelmap(labels_path)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=num_classes, use_display_name=True)
    return label_map_util.create_category_index(categories)


def _class_name(category_index, class_id):
    """Return the label-map name for *class_id*, or 'N/A' if it is not listed."""
    category = category_index.get(int(class_id))
    return category['name'] if category is not None else 'N/A'


def save_object_detection_result(score_threshold=0.24):
    """Run the frozen detector over every image in PATH_TO_TEST_IMAGES.

    The model is loaded from PATH_TO_CKPT and the label map from
    PATH_TO_LABELS. For each readable image file the function:

    * runs detection;
    * draws the boxes scoring above ``score_threshold``;
    * prints how many such boxes there are, and each box's centre in pixels;
    * saves the annotated figure to PATH_TO_RESULTS under the same file name.

    Sub-directories and files PIL cannot open are skipped with a message.
    The printed "Time taken" covers inference, drawing and saving for that
    image.

    Args:
        score_threshold: detections scoring strictly above this value are
            counted, reported and drawn. The default 0.24 is the threshold
            the counting code always used. Drawing now uses the same value
            instead of vis_util's own default, which previously hid
            low-scoring boxes that were still being counted.
    """
    IMAGE_SIZE = (12, 8)  # matplotlib figure size, in inches

    detection_graph = _load_frozen_graph(PATH_TO_CKPT)
    category_index = _load_category_index(PATH_TO_LABELS, NUM_CLASSES)
    # savefig fails if the output directory does not exist yet.
    os.makedirs(PATH_TO_RESULTS, exist_ok=True)

    # Tensor handles are the same for every image, so look them up once.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
    classes_tensor = detection_graph.get_tensor_by_name('detection_classes:0')
    num_tensor = detection_graph.get_tensor_by_name('num_detections:0')

    with tf.Session(graph=detection_graph) as sess:
        # os.listdir yields bare file names, which are reused for the output.
        for test_image in os.listdir(PATH_TO_TEST_IMAGES):
            image_path = os.path.join(PATH_TO_TEST_IMAGES, test_image)
            if not os.path.isfile(image_path):
                continue
            try:
                with Image.open(image_path) as image:
                    image_np = load_image_into_numpy_array(image)
            except IOError as err:  # e.g. Thumbs.db or another non-image file
                print('skipping {0}: {1}'.format(test_image, err))
                continue
            im_height, im_width = image_np.shape[:2]

            start = time.time()
            # The model expects a batch: shape [1, height, width, 3].
            boxes, scores, classes, num_detections = sess.run(
                [boxes_tensor, scores_tensor, classes_tensor, num_tensor],
                feed_dict={image_tensor: np.expand_dims(image_np, axis=0)})
            # Drop only the batch axis, so a single detection keeps its
            # per-box shape instead of being squeezed further.
            boxes = np.squeeze(boxes, axis=0)
            scores = np.squeeze(scores, axis=0)
            classes = np.squeeze(classes, axis=0).astype(np.int32)
            num_valid = int(num_detections[0])

            # Draws onto image_np in place; use the same threshold as the
            # count below so every counted box is also visible.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                boxes,
                classes,
                scores,
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8,
                min_score_thresh=score_threshold)

            # Select by score explicitly rather than assuming the first
            # `count` entries are the confident ones.
            keep = [i for i in range(num_valid) if scores[i] > score_threshold]
            count = len(keep)
            print()
            print("the count of objects is: ", count)
            for i in keep:
                # Boxes are normalized [y_min, x_min, y_max, x_max].
                y_min = boxes[i][0] * im_height
                x_min = boxes[i][1] * im_width
                y_max = boxes[i][2] * im_height
                x_max = boxes[i][3] * im_width
                x = int((x_min + x_max) / 2)
                y = int((y_min + y_max) / 2)
                name = _class_name(category_index, classes[i])
                if name == "ship":
                    print("this image has a ship!")
                    # For ships, report a point three quarters of the way
                    # down the box instead of its vertical centre.
                    y = int((y_max - y_min) / 4 * 3 + y_min)
                print("object{0}: {1}".format(i, name),
                      ',Center_X:', x, ',Center_Y:', y)

            fig = plt.figure(figsize=IMAGE_SIZE)
            plt.imshow(image_np)
            plt.savefig(os.path.join(PATH_TO_RESULTS, test_image))
            # Without closing, one figure stays alive per image and memory
            # grows with the size of the test set.
            plt.close(fig)
            print(test_image + ' succeed')
            seconds = time.time() - start
            print("Time taken : {0} seconds".format(seconds))
# Run detection only when executed as a script, not when the module is
# imported (e.g. to reuse load_image_into_numpy_array).
if __name__ == '__main__':
    save_object_detection_result()
主要流程参考 https://blog.csdn.net/hezuo1181/article/details/91380182 。
其中的训练命令需要改为相对路径才能正常运行。
制作数据集参考 https://blog.csdn.net/duanyajun987/article/details/81507656
批量重命名参考 https://blog.csdn.net/weixin_39853245/article/details/90421440
统计标注框数目和图片数目的代码如下（仅适用于单类别，只需看键为 '0' 的那一栏）：
import os

# Directory holding the VOC-style ImageSets list files (author's machine).
path = r'D:\SSD-Tensorflow-master\voc2007\ImageSets\Main'

# Keys of the statistics dict, in printing order. Only the single characters
# '0'-'9' can ever be incremented; 'none' always stays [0, 0].
_STAT_KEYS = ['none'] + [str(digit) for digit in range(10)] + ['total']


def count_label_statistics(list_file):
    """Tally per-character statistics over the lines of *list_file*.

    Each line is stripped, echoed to stdout, and then read character by
    character, with every character used as a key of the result.
    NOTE(review): this presumably expects each line to encode one label
    digit per annotated object (TODO confirm). For plain VOC image ids it
    counts the digits of the ids, not boxes.

    Args:
        list_file: path of a text file such as ``train.txt`` or ``test.txt``.

    Returns:
        dict mapping each key in ``_STAT_KEYS`` to a ``[lines, occurrences]``
        pair:

        * for each digit key: ``[lines containing it, total occurrences]``;
        * ``'total'``: ``[number of lines in the file, total characters]``.

    Raises:
        KeyError: if a line contains a character other than '0'-'9'.
    """
    stats = {key: [0, 0] for key in _STAT_KEYS}
    with open(list_file) as handle:
        lines = handle.readlines()
    for line in lines:
        line = line.strip()
        print(line)
        for char in line:
            stats[char][1] += 1
            stats['total'][1] += 1
        # Each distinct character counts once per line.
        for char in set(line):
            stats[char][0] += 1
    stats['total'][0] = len(lines)
    return stats


if __name__ == '__main__':
    # Author's note (originally on the '0' key): 238 images, 306 objects in total.
    TRAIN_STATISTICS = count_label_statistics(os.path.join(path, 'train.txt'))
    TEST_STATISTICS = count_label_statistics(os.path.join(path, 'test.txt'))
    print(TRAIN_STATISTICS)
    print(TEST_STATISTICS)