使用tensorflow预训练模型进行物体识别(1)

 

概述

物体识别作为计算机视觉领域的一个典型任务,已经有很多成熟的理论与解决方案。本文主要介绍tensorflow的物体识别API的使用。该API提供了了很多预训练模型。可以让我们快速构建自己的物体识别系统。本文将分几个部分充分展示tensorflow object detect API的使用。本文主要展示物体识别的demo,后面的文章将展开说明如何使用该API来训练自己的的数据。

1、TensorFlow对象检测API安装

1.1 使用git下载代码库

当然在安装该API的前提需要安装好CPU或GPU版本的tensorflow建议TF版本使用2.2以上。注意下面代码实在jupyter notebook中执行的,获益会有 ! git 的方式来运行shell命令。

import os
import pathlib

# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
    while "models" in pathlib.Path.cwd().parts:
        os.chdir('..')
elif not pathlib.Path('models').exists():
    !git clone --depth 1 https://github.com/tensorflow/models

1.2 安装API

安装该API需要使用到proc库,需要提前安装,ubuntu中的相关shell 代码如下:

! apt install -y protobuf-compiler

如果需要在其他操作系统中安装请参照下面链接

http://google.github.io/proto-lens/installing-protoc.html

执行下面shell代码进行TF object detect的安装

cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .

1.3 验证安装

可执行下面代码验证安装是否成功

python object_detection/builders/model_builder_tf2_test.py

如果看到类似下面的输出则说明安装成功

使用tensorflow预训练模型进行物体识别(1)_第1张图片

2、体验预训练模型的效果

在本文中并不涉及到训练自己的数据,这在后面的章节中会详细说明。下面我们将看一下预训练模型的效果展示。

2.1 下载测试图像

首先,我们将下载将在本教程中使用的图像。下方显示的代码段将从TensorFlow Model Garden下载测试图像 并将其保存在data/images文件夹中。相关下载代码如下

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging (1)
import pathlib
import tensorflow as tf

tf.get_logger().setLevel('ERROR')           # Suppress TensorFlow logging (2)

# Enable GPU dynamic memory allocation
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

def download_images():
    base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/test_images/'
    filenames = ['image1.jpg', 'image2.jpg']
    image_paths = []
    for filename in filenames:
        image_path = tf.keras.utils.get_file(fname=filename,
                                            origin=base_url + filename,
                                            untar=False)
        image_path = pathlib.Path(image_path)
        image_paths.append(str(image_path))
    return image_paths

IMAGE_PATHS = download_images()

注意:有时候由于网络原因我们并不能下载到上面的图像。可以直接使用下面的图像直接另存。

下面有三幅图像名称分别为imge1.jpg、imge2.jpg、imge3.jpg其中第一张与第二张图片是从上面的链接下载的,后面的一张图片是一张关于货架的图片。提供该图片的目的是来看一下预训练模型分泛化能力。因为所有的模型不管模型结果如何都是从COCO数据集训练的来的。该数据集包含的内容比较多但并不包含商品的部分数据。下面三张图像的大小都是1024*1024

使用tensorflow预训练模型进行物体识别(1)_第2张图片

使用tensorflow预训练模型进行物体识别(1)_第3张图片

使用tensorflow预训练模型进行物体识别(1)_第4张图片

如果下载不成功可以直接提供三个图片的地址,替代代码如下:

IMAGE_PATHS=['data/images/imge1.jpg','data/images/imge2.jpg','data/images/imge3.jpg']

2.2 下载模型参数

下面显示的代码段用于下载我们将用于执行推理的预训练对象检测模型。我们将使用的特定检测算法是 CenterNet HourGlass104 1024x1024。在TensorFlow 2 Detection Model Zoo中可以找到更多模型。要使用其他模型,您将需要特定模型的URL名称。可以按照以下步骤进行操作:

  1. 右键单击您要使用的模型的模型名称;

  2. 单击复制链接地址以复制模型的下载链接;

  3. 将链接粘贴到您选择的文本编辑器中。您应该观察到类似于的链接download.tensorflow.org/models/object_detection/tf2/YYYYYYYY/XXXXXXXXX.tar.gz。

  4. 复制XXXXXXXXX链接的一部分,并用它替换MODEL_NAME下面所示代码中变量的值;

  5. 复制YYYYYYYY链接的一部分,并用它替换MODEL_DATE下面所示代码中变量的值

# Download and extract model
def download_model(model_name, model_date):
    base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
    model_file = model_name + '.tar.gz'
    model_dir = tf.keras.utils.get_file(fname=model_name,
                                        origin=base_url + model_date + '/' + model_file,
                                        untar=True)
    return str(model_dir)

MODEL_DATE = '20200711'
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)

2.3 下载标签数据

标签数据是对分类数据的文本说明。下面显示的库德代码段用于下载标签文件(.pbtxt),该文件包含用于向每个检测(例如人)添加正确标签的字符串列表。由于我们将使用的预训练模型已经在COCO数据集上进行了训练,因此我们需要下载与该数据集相对应的标签文件,名为mscoco_label_map.pbtxt。TensorFlow Models Garden中包含的标签文件的完整列表可在此处https://github.com/tensorflow/models/tree/master/research/object_detection/data 找到。

 def download_labels(filename):
     base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
     label_dir = tf.keras.utils.get_file(fname=filename,
                                         origin=base_url + filename,
                                         untar=False)
     label_dir = pathlib.Path(label_dir)
     return str(label_dir)

LABEL_FILENAME = 'mscoco_label_map.pbtxt'
PATH_TO_LABELS = download_labels(LABEL_FILENAME)

该文件同样会出现由于网络原因无法加载的情况可以直接使用下面文件

item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
item {
  name: "/m/0k4j"
  id: 3
  display_name: "car"
}
item {
  name: "/m/04_sv"
  id: 4
  display_name: "motorcycle"
}
item {
  name: "/m/05czz6l"
  id: 5
  display_name: "airplane"
}
item {
  name: "/m/01bjv"
  id: 6
  display_name: "bus"
}
item {
  name: "/m/07jdr"
  id: 7
  display_name: "train"
}
item {
  name: "/m/07r04"
  id: 8
  display_name: "truck"
}
item {
  name: "/m/019jd"
  id: 9
  display_name: "boat"
}
item {
  name: "/m/015qff"
  id: 10
  display_name: "traffic light"
}
item {
  name: "/m/01pns0"
  id: 11
  display_name: "fire hydrant"
}
item {
  name: "/m/02pv19"
  id: 13
  display_name: "stop sign"
}
item {
  name: "/m/015qbp"
  id: 14
  display_name: "parking meter"
}
item {
  name: "/m/0cvnqh"
  id: 15
  display_name: "bench"
}
item {
  name: "/m/015p6"
  id: 16
  display_name: "bird"
}
item {
  name: "/m/01yrx"
  id: 17
  display_name: "cat"
}
item {
  name: "/m/0bt9lr"
  id: 18
  display_name: "dog"
}
item {
  name: "/m/03k3r"
  id: 19
  display_name: "horse"
}
item {
  name: "/m/07bgp"
  id: 20
  display_name: "sheep"
}
item {
  name: "/m/01xq0k1"
  id: 21
  display_name: "cow"
}
item {
  name: "/m/0bwd_0j"
  id: 22
  display_name: "elephant"
}
item {
  name: "/m/01dws"
  id: 23
  display_name: "bear"
}
item {
  name: "/m/0898b"
  id: 24
  display_name: "zebra"
}
item {
  name: "/m/03bk1"
  id: 25
  display_name: "giraffe"
}
item {
  name: "/m/01940j"
  id: 27
  display_name: "backpack"
}
item {
  name: "/m/0hnnb"
  id: 28
  display_name: "umbrella"
}
item {
  name: "/m/080hkjn"
  id: 31
  display_name: "handbag"
}
item {
  name: "/m/01rkbr"
  id: 32
  display_name: "tie"
}
item {
  name: "/m/01s55n"
  id: 33
  display_name: "suitcase"
}
item {
  name: "/m/02wmf"
  id: 34
  display_name: "frisbee"
}
item {
  name: "/m/071p9"
  id: 35
  display_name: "skis"
}
item {
  name: "/m/06__v"
  id: 36
  display_name: "snowboard"
}
item {
  name: "/m/018xm"
  id: 37
  display_name: "sports ball"
}
item {
  name: "/m/02zt3"
  id: 38
  display_name: "kite"
}
item {
  name: "/m/03g8mr"
  id: 39
  display_name: "baseball bat"
}
item {
  name: "/m/03grzl"
  id: 40
  display_name: "baseball glove"
}
item {
  name: "/m/06_fw"
  id: 41
  display_name: "skateboard"
}
item {
  name: "/m/019w40"
  id: 42
  display_name: "surfboard"
}
item {
  name: "/m/0dv9c"
  id: 43
  display_name: "tennis racket"
}
item {
  name: "/m/04dr76w"
  id: 44
  display_name: "bottle"
}
item {
  name: "/m/09tvcd"
  id: 46
  display_name: "wine glass"
}
item {
  name: "/m/08gqpm"
  id: 47
  display_name: "cup"
}
item {
  name: "/m/0dt3t"
  id: 48
  display_name: "fork"
}
item {
  name: "/m/04ctx"
  id: 49
  display_name: "knife"
}
item {
  name: "/m/0cmx8"
  id: 50
  display_name: "spoon"
}
item {
  name: "/m/04kkgm"
  id: 51
  display_name: "bowl"
}
item {
  name: "/m/09qck"
  id: 52
  display_name: "banana"
}
item {
  name: "/m/014j1m"
  id: 53
  display_name: "apple"
}
item {
  name: "/m/0l515"
  id: 54
  display_name: "sandwich"
}
item {
  name: "/m/0cyhj_"
  id: 55
  display_name: "orange"
}
item {
  name: "/m/0hkxq"
  id: 56
  display_name: "broccoli"
}
item {
  name: "/m/0fj52s"
  id: 57
  display_name: "carrot"
}
item {
  name: "/m/01b9xk"
  id: 58
  display_name: "hot dog"
}
item {
  name: "/m/0663v"
  id: 59
  display_name: "pizza"
}
item {
  name: "/m/0jy4k"
  id: 60
  display_name: "donut"
}
item {
  name: "/m/0fszt"
  id: 61
  display_name: "cake"
}
item {
  name: "/m/01mzpv"
  id: 62
  display_name: "chair"
}
item {
  name: "/m/02crq1"
  id: 63
  display_name: "couch"
}
item {
  name: "/m/03fp41"
  id: 64
  display_name: "potted plant"
}
item {
  name: "/m/03ssj5"
  id: 65
  display_name: "bed"
}
item {
  name: "/m/04bcr3"
  id: 67
  display_name: "dining table"
}
item {
  name: "/m/09g1w"
  id: 70
  display_name: "toilet"
}
item {
  name: "/m/07c52"
  id: 72
  display_name: "tv"
}
item {
  name: "/m/01c648"
  id: 73
  display_name: "laptop"
}
item {
  name: "/m/020lf"
  id: 74
  display_name: "mouse"
}
item {
  name: "/m/0qjjc"
  id: 75
  display_name: "remote"
}
item {
  name: "/m/01m2v"
  id: 76
  display_name: "keyboard"
}
item {
  name: "/m/050k8"
  id: 77
  display_name: "cell phone"
}
item {
  name: "/m/0fx9l"
  id: 78
  display_name: "microwave"
}
item {
  name: "/m/029bxz"
  id: 79
  display_name: "oven"
}
item {
  name: "/m/01k6s3"
  id: 80
  display_name: "toaster"
}
item {
  name: "/m/0130jx"
  id: 81
  display_name: "sink"
}
item {
  name: "/m/040b_t"
  id: 82
  display_name: "refrigerator"
}
item {
  name: "/m/0bt_c3"
  id: 84
  display_name: "book"
}
item {
  name: "/m/01x3z"
  id: 85
  display_name: "clock"
}
item {
  name: "/m/02s195"
  id: 86
  display_name: "vase"
}
item {
  name: "/m/01lsmm"
  id: 87
  display_name: "scissors"
}
item {
  name: "/m/0kmg4"
  id: 88
  display_name: "teddy bear"
}
item {
  name: "/m/03wvsk"
  id: 89
  display_name: "hair drier"
}
item {
  name: "/m/012xff"
  id: 90
  display_name: "toothbrush"
}

2.4  加载模型参数并运行

加载模型参数

import time
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils

PATH_TO_SAVED_MODEL = PATH_TO_MODEL_DIR + "/saved_model"

print('Loading model...', end='')
start_time = time.time()

# Load saved model and build the detection function
detect_fn = tf.saved_model.load(PATH_TO_SAVED_MODEL)

end_time = time.time()
elapsed_time = end_time - start_time
print('Done! Took {} seconds'.format(elapsed_time))

加载标签图数据(用于绘图)

category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS,use_display_name=True)

进行物体检测

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')   # Suppress Matplotlib warnings
%matplotlib notebook
def load_image_into_numpy_array(path):
    """Load an image from file into a numpy array.

    Puts image into numpy array to feed into tensorflow graph.
    Note that by convention we put it into a numpy array with shape
    (height, width, channels), where channels=3 for RGB.

    Args:
      path: the file path to the image

    Returns:
      uint8 numpy array with shape (img_height, img_width, 3)
    """
    return np.array(Image.open(path))


for image_path in IMAGE_PATHS:

    print('Running inference for {}... '.format(image_path), end='')

    image_np = load_image_into_numpy_array(image_path)

    # Things to try:
    # Flip horizontally
    # image_np = np.fliplr(image_np).copy()

    # Convert image to grayscale
    # image_np = np.tile(
    #     np.mean(image_np, 2, keepdims=True), (1, 1, 3)).astype(np.uint8)

    # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
    input_tensor = tf.convert_to_tensor(image_np)
    # The model expects a batch of images, so add an axis with `tf.newaxis`.
    input_tensor = input_tensor[tf.newaxis, ...]

    # input_tensor = np.expand_dims(image_np, 0)
    detections = detect_fn(input_tensor)

    # All outputs are batches tensors.
    # Convert to numpy arrays, and take index [0] to remove the batch dimension.
    # We're only interested in the first num_detections.
    num_detections = int(detections.pop('num_detections'))
    detections = {key: value[0, :num_detections].numpy()
                   for key, value in detections.items()}
    detections['num_detections'] = num_detections

    # detection_classes should be ints.
    detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

    image_np_with_detections = image_np.copy()

    viz_utils.visualize_boxes_and_labels_on_image_array(
          image_np_with_detections,
          detections['detection_boxes'],
          detections['detection_classes'],
          detections['detection_scores'],
          category_index,
          use_normalized_coordinates=True,
          max_boxes_to_draw=200,
          min_score_thresh=.30,
          agnostic_mode=False)

    plt.figure()
    plt.imshow(image_np_with_detections)
    print('Done')
plt.show()

效果如下

使用tensorflow预训练模型进行物体识别(1)_第5张图片使用tensorflow预训练模型进行物体识别(1)_第6张图片

使用tensorflow预训练模型进行物体识别(1)_第7张图片

大家注意在第三幅图片中程序识别出了大多数的瓶子,说明模型具有一定的泛化能力。我们可以在该模型的基础上训练自己的模型。下一篇文章将详细说明具体步骤

你可能感兴趣的:(机器学习与人工智能,新版tensorflow,人工智能,tensorflow,物体检测,计算机视觉)