import os
import re
import cv2
import glob
import json
import math
import shutil
import numpy as np
import tensorflow as tf
from PIL import Image
# Location of the frozen TensorFlow inference graph (a serialized GraphDef).
PATH_TO_CKPT = 'model/frozen_inference_graph.pb'

# Build a dedicated graph and import the frozen model into it with an empty
# name prefix, so its tensors can later be fetched by their exported names.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    # Read the raw bytes first; the file handle is released before parsing.
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
# Run the detector over every image in PATH_TO_IMAGES and write, per image, a
# prediction file "2007_<image stem>.txt" holding the single top detection as
# "instrument <score> <left> <top> <right> <bottom>" in pixel coordinates.
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        # Input and output tensors, looked up by their names in the graph.
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        PATH_TO_IMAGES = 'images/'
        PREDICTED_PATH = 'predicted/'
        if not os.path.exists(PREDICTED_PATH):
            os.makedirs(PREDICTED_PATH)

        for filename in os.listdir(PATH_TO_IMAGES):
            if not filename.lower().endswith(
                    ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm',
                     '.ppm', '.tif', '.tiff')):
                continue
            image_name = PATH_TO_IMAGES + filename
            image = cv2.imread(image_name)
            if image is None:
                # cv2.imread signals an unreadable/corrupt file with None
                # instead of raising; skip it rather than crash mid-run.
                continue
            # NOTE(review): cv2.imread yields BGR channel order and the frame
            # is fed to the model unchanged, as before - confirm this matches
            # the channel order the model was trained with.
            image_np = np.array(image).astype(np.uint8)
            # The model input is a batch, so add a leading batch axis.
            image_np_expanded = np.expand_dims(image_np, axis=0)
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})

            # Only the first detection of the batch entry is exported.
            # Box entries 0/2 are scaled by the image height and 1/3 by the
            # width (presumably normalized [ymin, xmin, ymax, xmax] - verify
            # against the model's output spec).
            height, width = image_np.shape[0], image_np.shape[1]
            top_box = boxes[0][0]
            left = int(top_box[1] * width)
            top = int(top_box[0] * height)
            right = int(top_box[3] * width)
            bottom = int(top_box[2] * height)

            # splitext (not a fixed [:-4] slice) so that five-character
            # extensions such as ".jpeg" / ".tiff" yield the same stem as the
            # matching ground-truth file.
            stem = os.path.splitext(filename)[0]
            with open(PREDICTED_PATH + "2007_" + stem + ".txt", "w") as f:
                f.write(" ".join([
                    "instrument",
                    str(round(scores[0][0], 6)),
                    str(left), str(top), str(right), str(bottom),
                ]) + "\n")
# Convert every XML annotation in PATH_TO_XML into a ground-truth text file
# "2007_<stem>.txt" whose lines read "<class> <xmin> <ymin> <xmax> <ymax>".
PATH_TO_XML = 'annotations/'
GROUND_TRUTH_PATH = 'ground-truth/'
if not os.path.exists(GROUND_TRUTH_PATH):
    os.makedirs(GROUND_TRUTH_PATH)

# Each pattern extracts one field from a single XML line; the paired string is
# what follows the value in the output (the last field ends the record).
# NOTE(review): the five patterns were previously the identical ".*(.*) .*",
# whose capture group is always empty, so only blanks were written. The tag
# names below assume Pascal VOC style annotations with one tag per line, in
# the order name, xmin, ymin, xmax, ymax - confirm against the real XML files.
_XML_FIELD_PATTERNS = [
    (re.compile(r".*<name>(.*)</name>.*", re.I), " "),
    (re.compile(r".*<xmin>(.*)</xmin>.*", re.I), " "),
    (re.compile(r".*<ymin>(.*)</ymin>.*", re.I), " "),
    (re.compile(r".*<xmax>(.*)</xmax>.*", re.I), " "),
    (re.compile(r".*<ymax>(.*)</ymax>.*", re.I), "\n"),
]

for filename in os.listdir(PATH_TO_XML):
    if not filename.lower().endswith('.xml'):
        continue
    xml_name = os.path.join(PATH_TO_XML, filename)
    gt_name = GROUND_TRUTH_PATH + "2007_" + os.path.splitext(filename)[0] + ".txt"
    # Both handles are context-managed so they are closed even on error.
    with open(xml_name, "r") as f, open(gt_name, "w") as f1:
        for line in f:
            for pattern, terminator in _XML_FIELD_PATTERNS:
                found = pattern.match(line)
                if found:
                    f1.write(found.group(1) + terminator)
def voc_ap(rec, prec):
    """Return the average precision (area under the precision/recall curve).

    The recall sequence is padded with 0.0 and 1.0 and the precision sequence
    with 0.0 on both ends. Precision is then made monotonically non-increasing
    from right to left (each value becomes the maximum of itself and everything
    after it), and the area is summed over the points where recall changes.

    Args:
        rec: Sequence of recall values, one per ranked detection.
        prec: Sequence of precision values, same length as ``rec``.

    Returns:
        The average precision as a float. Empty inputs give 0.0.

    The inputs are not modified; padded copies are built internally.
    """
    mrec = [0.0] + list(rec) + [1.0]
    mpre = [0.0] + list(prec) + [0.0]

    # Precision envelope: sweep right-to-left carrying the running maximum.
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])

    # Sum rectangle areas only where recall actually changes.
    ap = 0.0
    for i in range(1, len(mrec)):
        if mrec[i] != mrec[i - 1]:
            ap += (mrec[i] - mrec[i - 1]) * mpre[i]
    return ap
# Scratch directory for the intermediate JSON files used by the evaluation.
TEMP_FILES_PATH = ".temp_files"
if not os.path.exists(TEMP_FILES_PATH):
    os.makedirs(TEMP_FILES_PATH)

# Parse every ground-truth file ("<class> <left> <top> <right> <bottom>" per
# line), count objects and images per class, and dump each file's boxes to
# "<file_id>_ground_truth.json" in the scratch directory.
ground_truth_files_list = glob.glob('ground-truth/*.txt')
ground_truth_files_list.sort()
gt_counter_per_class = {}      # class name -> number of ground-truth objects
counter_images_per_class = {}  # class name -> number of files containing it
for txt_file in ground_truth_files_list:
    file_id = txt_file.split(".txt", 1)[0]
    file_id = os.path.basename(os.path.normpath(file_id))
    with open(txt_file) as f:
        lines_list = [x.strip() for x in f]
    bounding_boxes = []
    already_seen_classes = set()
    for line in lines_list:
        if not line:
            # Blank lines (e.g. a trailing newline) carry no annotation.
            continue
        class_name, left, top, right, bottom = line.split()
        bbox = left + " " + top + " " + right + " " + bottom
        # "used" is flipped later once a detection has been matched to the box.
        bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False})
        gt_counter_per_class[class_name] = gt_counter_per_class.get(class_name, 0) + 1
        if class_name not in already_seen_classes:
            # Count each class at most once per file.
            counter_images_per_class[class_name] = (
                counter_images_per_class.get(class_name, 0) + 1)
            already_seen_classes.add(class_name)
    # Write this file's ground-truth boxes to a temporary .json file.
    with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
        json.dump(bounding_boxes, outfile)

gt_classes = sorted(gt_counter_per_class)
n_classes = len(gt_classes)
# Collect the predictions ("<class> <confidence> <left> <top> <right>
# <bottom>" per line) and, for every ground-truth class, write them sorted by
# descending confidence to "<class>_predictions.json" in the scratch directory.
predicted_files_list = glob.glob('predicted/*.txt')
predicted_files_list.sort()

# Read each prediction file once and group its boxes by class, instead of
# re-reading every file for every class.
predictions_by_class = {}
for txt_file in predicted_files_list:
    file_id = txt_file.split(".txt", 1)[0]
    file_id = os.path.basename(os.path.normpath(file_id))
    with open(txt_file) as f:
        lines = [x.strip() for x in f]
    for line in lines:
        if not line:
            # Blank lines carry no prediction.
            continue
        tmp_class_name, confidence, left, top, right, bottom = line.split()
        bbox = left + " " + top + " " + right + " " + bottom
        predictions_by_class.setdefault(tmp_class_name, []).append(
            {"confidence": confidence, "file_id": file_id, "bbox": bbox})

for class_index, class_name in enumerate(gt_classes):
    bounding_boxes = predictions_by_class.get(class_name, [])
    # The sort is stable, so equal confidences keep file/line order.
    bounding_boxes.sort(key=lambda x: float(x['confidence']), reverse=True)
    with open(TEMP_FILES_PATH + "/" + class_name + "_predictions.json", 'w') as outfile:
        json.dump(bounding_boxes, outfile)
MINOVERLAP = 0.5  # IoU threshold a detection must reach to be a true positive
sum_AP = 0.0
results_files_path = "results"
# Start from an empty results directory on every run.
if os.path.exists(results_files_path):
    shutil.rmtree(results_files_path)
os.makedirs(results_files_path)

with open(results_files_path + "/results.txt", 'w') as results_file:
    results_file.write("# AP and precision/recall per class\n")
    count_true_positives = {}
    # Ground-truth boxes per file_id, loaded from the scratch JSON on first
    # use. Keeping them in memory preserves the "used" flags between
    # detections without re-reading and re-writing the file every time.
    ground_truth_cache = {}
    for class_index, class_name in enumerate(gt_classes):
        count_true_positives[class_name] = 0
        # Load this class's predictions (already sorted by confidence).
        predictions_file = TEMP_FILES_PATH + "/" + class_name + "_predictions.json"
        with open(predictions_file) as infile:
            predictions_data = json.load(infile)
        nd = len(predictions_data)
        tp = [0] * nd
        fp = [0] * nd
        for idx, prediction in enumerate(predictions_data):
            file_id = prediction["file_id"]
            if file_id not in ground_truth_cache:
                gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
                with open(gt_file) as infile:
                    ground_truth_cache[file_id] = json.load(infile)
            ground_truth_data = ground_truth_cache[file_id]

            ovmax = -1
            gt_match = None
            # Predicted box as [left, top, right, bottom].
            bb = [float(x) for x in prediction["bbox"].split()]
            for obj in ground_truth_data:
                if obj["class_name"] != class_name:
                    continue
                bbgt = [float(x) for x in obj["bbox"].split()]
                # Intersection rectangle; the +1 treats coordinates as
                # inclusive pixel indices.
                bi = [max(bb[0], bbgt[0]), max(bb[1], bbgt[1]),
                      min(bb[2], bbgt[2]), min(bb[3], bbgt[3])]
                iw = bi[2] - bi[0] + 1
                ih = bi[3] - bi[1] + 1
                if iw > 0 and ih > 0:
                    # IoU = intersection / (area_pred + area_gt - intersection)
                    ua = ((bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1)
                          + (bbgt[2] - bbgt[0] + 1) * (bbgt[3] - bbgt[1] + 1)
                          - iw * ih)
                    ov = iw * ih / ua
                    if ov > ovmax:
                        ovmax = ov
                        gt_match = obj

            # A detection is a true positive only if it overlaps enough with a
            # ground-truth box that no higher-ranked detection has claimed.
            if ovmax >= MINOVERLAP and not gt_match["used"]:
                tp[idx] = 1
                gt_match["used"] = True
                count_true_positives[class_name] += 1
            else:
                fp[idx] = 1

        # Turn the per-detection flags into running totals.
        running = 0
        for idx, val in enumerate(fp):
            running += val
            fp[idx] = running
        running = 0
        for idx, val in enumerate(tp):
            running += val
            tp[idx] = running

        # Recall and precision after each ranked detection.
        rec = [float(hits) / gt_counter_per_class[class_name] for hits in tp]
        prec = [float(hits) / (misses + hits) for hits, misses in zip(tp, fp)]

        # Compute AP; copies are passed so rec/prec stay intact for reporting.
        ap = voc_ap(rec[:], prec[:])
        sum_AP += ap
        text = "{0:.2f}%".format(ap * 100) + " = " + class_name + " AP "
        # Write the per-class results to results.txt.
        rounded_prec = ['%.2f' % elem for elem in prec]
        rounded_rec = ['%.2f' % elem for elem in rec]
        results_file.write(text + "\n Precision: " + str(rounded_prec)
                           + "\n Recall :" + str(rounded_rec) + "\n\n")

    results_file.write("\n# mAP of all classes\n")
    # Compute mAP; with no ground-truth classes it is reported as 0.
    mAP = sum_AP / n_classes if n_classes else 0.0
    text = "mAP = {0:.2f}%".format(mAP * 100)
    results_file.write(text + "\n")

# The scratch JSON files are no longer needed.
shutil.rmtree(TEMP_FILES_PATH)
# Count predicted objects per class and append the ground-truth and
# prediction summaries to results.txt.
pred_counter_per_class = {}
for txt_file in predicted_files_list:
    with open(txt_file) as f:
        # Use this file's own lines; a leftover variable from an earlier
        # section must not be read here.
        lines_list = [x.strip() for x in f]
    for line in lines_list:
        if not line:
            # Blank lines carry no prediction.
            continue
        class_name = line.split()[0]
        pred_counter_per_class[class_name] = pred_counter_per_class.get(class_name, 0) + 1
pred_classes = list(pred_counter_per_class.keys())

# Classes that were predicted but have no ground truth have zero true
# positives by definition.
for class_name in pred_classes:
    if class_name not in gt_classes:
        count_true_positives[class_name] = 0

with open(results_files_path + "/results.txt", 'a') as results_file:
    results_file.write("\n# Number of ground-truth objects per class\n")
    for class_name in sorted(gt_counter_per_class):
        results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")

    results_file.write("\n# Number of predicted objects per class\n")
    for class_name in sorted(pred_classes):
        n_pred = pred_counter_per_class[class_name]
        n_tp = count_true_positives[class_name]
        results_file.write("{0}: {1} (tp:{2}, fp:{3})\n".format(
            class_name, n_pred, n_tp, n_pred - n_tp))
# Code file download address (the link itself is not included in this copy).