测试中遇到的问题,直接上代码,核心在53-55行,先通过cvtcolor转换通道,然后用PIL从numpy转换为PIL类型图像
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util
import cv2
from timeit import default_timer as timer
PATH_TO_CKPT = '/media/wxy/0007C67F000018F2/SSD_HHC/output92001/frozen_inference_graph.pb'
PATH_TO_LABELS = '/media/wxy/0007C67F000018F2/SSD_HHC/dataset/pascal_label_map.pbtxt'
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
NUM_CLASSES =2
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
print(category_index)
'''
def load_image_into_numpy_array(image):
#image = Image.merge('RGB',(image,image,image))
img = np.array(image)
#print(image.size)
return img
'''
IMAGE_SIZE = (12, 8)
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
for image_path in os.listdir('SSD_image'):
start = timer()
image = Image.open('./SSD_image'+'/'+image_path)
image_np1 = np.array(image)
image1=cv2.cvtColor(image_np1,cv2.COLOR_GRAY2BGR)
pilimg=Image.fromarray(image1)
image_np = np.array(pilimg)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
end = timer()
vis_util.visualize_boxes_and_labels_on_image_array(
image_np, np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
use_normalized_coordinates=True,
line_thickness=8)
(r, g, b)=cv2.split(image_np)
image_np=cv2.merge([b,g,r])
print(end - start)
cv2.imwrite('./SSD_result_920'+'/'+image_path,image_np)