本文实验环境:
!pip install -U --pre tensorflow=="2.*"!pip install pycocotools
下载tensorflow的模型
import osimport pathlibif "models" in pathlib.Path.cwd().parts: while "models" in pathlib.Path.cwd().parts: os.chdir('..')elif not pathlib.Path('models').exists(): !git clone --depth 1 https://github.com/tensorflow/models
%%bash
# Generate the Python modules for the object_detection protobuf definitions.
# Run from models/research so the output lands beside the .proto sources.
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
%%bash
# Install the package found in models/research into the current environment.
cd models/research
pip install .
import os
import sys
import tarfile
import zipfile
from collections import defaultdict
from io import StringIO

import numpy as np
import six.moves.urllib as urllib
import tensorflow as tf
from IPython.display import display
from matplotlib import pyplot as plt
from PIL import Image
from object_detection.utils import label_map_util
from object_detection.utils import ops as utils_ops
from object_detection.utils import visualization_utils as vis_util
TensorFlow 2.0 把 tf.gfile 移到了 tf.io.gfile,这里需要给程序打个补丁:
# Patch tf1 into `utils.ops`: point the module's `tf` name at the
# TF1 compatibility namespace.
utils_ops.tf = tf.compat.v1

# Patch the location of gfile: expose `tf.io.gfile` under the old
# `tf.gfile` attribute name.
tf.gfile = tf.io.gfile
def load_model(model_name):
    """Download a detection model archive and load its serving signature.

    The archive ``<model_name>.tar.gz`` is fetched from the TensorFlow
    object-detection download URL with ``tf.keras.utils.get_file`` (which
    also extracts it, because ``untar=True``), and the ``saved_model``
    sub-directory of the result is loaded with ``tf.saved_model.load``.

    Args:
        model_name: Name of the model archive, without the ``.tar.gz``
            suffix.

    Returns:
        The ``'serving_default'`` signature of the loaded SavedModel.
    """
    base_url = 'http://download.tensorflow.org/models/object_detection/'
    model_file = model_name + '.tar.gz'
    model_dir = tf.keras.utils.get_file(
        fname=model_name,
        origin=base_url + model_file,
        untar=True)

    # The SavedModel lives in the `saved_model` folder of the archive.
    model_dir = pathlib.Path(model_dir) / "saved_model"

    model = tf.saved_model.load(str(model_dir))
    model = model.signatures['serving_default']

    return model
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = 'models/research/object_detection/data/mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(
    PATH_TO_LABELS, use_display_name=True)
简单起见,这个文件夹里只放了2个图片进行测试。如果要测试自己的图像,就把路径加到TEST_IMAGE_PATHS里。
这两个原始图片是这样的:
.
# If you want to test the code with your images, just add path to the images
# to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = pathlib.Path('models/research/object_detection/test_images')
# `sorted` accepts any iterable, so the glob generator is passed directly.
TEST_IMAGE_PATHS = sorted(PATH_TO_TEST_IMAGES_DIR.glob("*.jpg"))
# Bare expression: lets the notebook display the collected paths.
TEST_IMAGE_PATHS
# Download (if needed) and load the detection model used for the box demo.
model_name = 'ssd_mobilenet_v1_coco_2017_11_17'
detection_model = load_model(model_name)
看看模型结构,它的输入是 3 通道、uint8 格式的图片:
输出的变量 output_dict 是个字典类型, 包含 :
def run_inference_for_single_image(model, image):
    """Run the detection model on one image and return un-batched results.

    Args:
        model: Callable detection model; it is invoked with a single
            batched tensor and must return a dict of batched tensors that
            includes ``'num_detections'``, ``'detection_classes'`` and
            ``'detection_boxes'``.
        image: The image to analyse; converted with ``np.asarray``.

    Returns:
        A dict of NumPy arrays with the batch dimension removed and each
        output truncated to the first ``num_detections`` entries.
        ``'num_detections'`` is stored as an ``int`` and
        ``'detection_classes'`` is cast to ``np.int64``. When the model
        outputs ``'detection_masks'``, an extra
        ``'detection_masks_reframed'`` entry holds the masks reframed to
        the image size and thresholded at 0.5 as ``uint8``.
    """
    image = np.asarray(image)
    # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
    input_tensor = tf.convert_to_tensor(image)
    # The model expects a batch of images, so add an axis with `tf.newaxis`.
    input_tensor = input_tensor[tf.newaxis, ...]

    # Run inference
    output_dict = model(input_tensor)

    # All outputs are batched tensors.
    # Convert to numpy arrays, and take index [0] to remove the batch dimension.
    # We're only interested in the first num_detections.
    num_detections = int(output_dict.pop('num_detections'))
    output_dict = {
        key: value[0, :num_detections].numpy()
        for key, value in output_dict.items()
    }
    output_dict['num_detections'] = num_detections

    # detection_classes should be ints.
    output_dict['detection_classes'] = (
        output_dict['detection_classes'].astype(np.int64))

    # Handle models with masks:
    if 'detection_masks' in output_dict:
        # Reframe the bbox mask to the image size.
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            output_dict['detection_masks'], output_dict['detection_boxes'],
            image.shape[0], image.shape[1])
        # Binarize the reframed masks.
        detection_masks_reframed = tf.cast(
            detection_masks_reframed > 0.5, tf.uint8)
        output_dict['detection_masks_reframed'] = (
            detection_masks_reframed.numpy())

    return output_dict
def show_inference(model, image_path):
    """Detect objects in the image at ``image_path`` and display the result.

    Args:
        model: Detection model passed through to
            ``run_inference_for_single_image``.
        image_path: Path of the image file to open with ``PIL.Image.open``.
    """
    # The array based representation of the image will be used later in order
    # to prepare the result image with boxes and labels on it.
    image_np = np.array(Image.open(image_path))

    # Run the detection.
    output_dict = run_inference_for_single_image(model, image_np)

    # Visualize the detections. The return value is not used: the helper is
    # relied on to draw onto `image_np` itself, which is displayed below.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        instance_masks=output_dict.get('detection_masks_reframed', None),
        use_normalized_coordinates=True,
        line_thickness=8)

    display(Image.fromarray(image_np))
# Run detection on every test image and display the annotated result.
for image_path in TEST_IMAGE_PATHS:
    show_inference(detection_model, image_path)
.
# Load the instance-segmentation (mask) model; reuse `model_name` rather than
# repeating the literal so the two cannot drift apart.
model_name = "mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28"
masking_model = load_model(model_name)
可以看出resnet_v2有更高的准确度。