deeplabv3开源工程详解(1)—— 开源模型测试自己的图片

前言

deeplabv3是当前较为常用的语义分割的神经网络,且整个训练工程已经全部开源,使用公布的模型进行测试或基于自己的训练都可以得到一个较好的结果。

deeplabv3开源工程详解(1)—— 开源模型测试自己的图片
deeplabv3开源工程详解(2)—— 使用自己的数据集进行训练、迁移学习
deeplabv3开源工程(3)—— 报错:2 root error(s) found. (0) Invalid argument: padded_shape[0]=168 is not…

1 工程准备

环境

【TITAN XP】+【Ubuntu】+【tensorflow-gpu-1.14】+【cuda10.2】
(使用1.8.0及以上版本,低版本缺少函数,会报错)

下载包

  • 下载工程并解压
    https://github.com/tensorflow/models
    解压后,deeplabv3的工程在路径【./research/deeplab】下。
  • 下载deeplabv3的开源的权重
    https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
    在这个界面可以看到,基于PASCAL VOC数据集训练的模型、基于Cityscapes数据集训练的模型。本篇使用的是前者模型

    deeplabv3开源工程详解(1)—— 开源模型测试自己的图片_第1张图片
  • 进入 ./research/deeplab 路径下,创建model文件夹,将下载的权重文件解压在该文件夹下。

    deeplabv3开源工程详解(1)—— 开源模型测试自己的图片_第2张图片
    可以看见每个文件夹下有相同命名的三个文件。
    其中,pb文件用于实际测试;ckpt文件用于预训练模型。

    deeplabv3开源工程详解(1)—— 开源模型测试自己的图片_第3张图片

2 deeplabv3开源模型测试

官方给出的测试实际数据的脚本为:deeplab_demo.ipynb,代码易懂。
但该脚本会自己下载模型并进行图片预测,并且读取的模型的压缩包的文件。想要使用到自己的工程中,并不是最简洁的形式。

为了方便实际的使用测试,自己编写了测试脚本。

  • 代码主要分为2个部分:模型的加载和预测、预测结果的可视化
  • 使用具体的模型测试自己的图片,代码中需要修改的内容是模型的路径、图片的路径。

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf


class DeepLabModel(object):
 """Class to load deeplab model and run inference."""

 INPUT_TENSOR_NAME = 'ImageTensor:0'
 OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
 INPUT_SIZE = 513
 FROZEN_GRAPH_NAME = 'frozen_inference_graph'

 def __init__(self, model_path):
   """Creates and loads pretrained deeplab model."""
   self.graph = tf.Graph()

   print(model_path)
   with tf.gfile.FastGFile(model_path, 'rb') as f:
     frozen_graph_def = tf.GraphDef()
     frozen_graph_def.ParseFromString(f.read())
   if frozen_graph_def is None:
     raise RuntimeError('Cannot find inference graph.')

   with self.graph.as_default():
     tf.import_graph_def(frozen_graph_def, name='')
   print('model loaded successfully!')

   self.sess = tf.Session(graph=self.graph)

 def run(self, image):

   width, height = image.size
   resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
   target_size = (int(resize_ratio * width), int(resize_ratio * height))
   resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)

   batch_seg_map = self.sess.run(
       self.OUTPUT_TENSOR_NAME,
       feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
   seg_map = batch_seg_map[0]
   return resized_image, seg_map


def label_to_color_image(label):
 if label.ndim != 2:
   raise ValueError('Expect 2-D input label')
 
 def create_pascal_label_colormap():
   colormap = np.zeros((256, 3), dtype=int)
   ind = np.arange(256, dtype=int)
   for shift in reversed(range(8)):
     for channel in range(3):
       colormap[:, channel] |= ((ind >> channel) & 1) << shift
     ind >>= 3
   return colormap

 colormap = create_pascal_label_colormap()
 if np.max(label) >= len(colormap):
   raise ValueError('label value too large.')

 return colormap[label]


def vis_segmentation(image, seg_map):
 """Visualizes input image, segmentation map and overlay view."""
 plt.figure(figsize=(15, 5))
 grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

 plt.subplot(grid_spec[0])
 plt.imshow(image)
 plt.axis('off')
 plt.title('input image')

 plt.subplot(grid_spec[1])
 seg_image = label_to_color_image(seg_map).astype(np.uint8)
 plt.imshow(seg_image)
 plt.axis('off')
 plt.title('segmentation map')

 plt.subplot(grid_spec[2])
 plt.imshow(image)
 plt.imshow(seg_image, alpha=0.7)
 plt.axis('off')
 plt.title('segmentation overlay')

 unique_labels = np.unique(seg_map)
 ax = plt.subplot(grid_spec[3])
 plt.imshow(
     FULL_COLOR_MAP[unique_labels].astype(np.uint8), >interpolation='nearest')
 ax.yaxis.tick_right()
 plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
 plt.xticks([], [])
 ax.tick_params(width=0.0)
 plt.grid('off')
 plt.show()

####================================= Select a pretrained model ===============================
#deeplab_demo.ipynb 中 提供的训练的模型命名与对应存放的路径
# {
#     'mobilenetv2_coco_voctrainaug':
#         'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
#     'mobilenetv2_coco_voctrainval':
#         'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
#     'xception_coco_voctrainaug':
#         'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
#     'xception_coco_voctrainval':
#         'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
# }


if __name__ == '__main__':

 LABEL_NAMES = np.asarray([
     'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
     'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
     'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'
 ])
 FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
 FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

 model_select = "deeplabv3_pascal_train_aug"
 model_path = "./model/{}/frozen_inference_graph.pb".format(model_select)
 MODEL = DeepLabModel(model_path)

 original_im = Image.open("./1107_3.png")
 resized_im, seg_map = MODEL.run(original_im)
 vis_segmentation(resized_im, seg_map)


预测结果展示:
deeplabv3开源工程详解(1)—— 开源模型测试自己的图片_第4张图片

你可能感兴趣的:(deeplab系列,工程配置)