DeepLab 源码分析之 deeplab_demo.ipynb

我们首先从 deeplab_demo.ipynb 开始分析

首先 import 必要的库

import os
from io import BytesIO  
import tarfile   # 处理tar压缩包
import tempfile  # 用于创建临时文件
from six.moves import urllib  # 下载

from matplotlib import gridspec  # 绘图
from matplotlib import pyplot as plt # 绘图
import numpy as np
from PIL import Image  # 读图

import tensorflow as tf

定义 DeepLabModel 类
包含两个方法:
1. __init__ 从 tar 包中加载冻结的计算图并创建会话
2. run 将图片输入计算图并运行,得到分割结果

class DeepLabModel(object):
  """Loads a frozen DeepLab inference graph and runs segmentation on images."""

  # Name of the graph's input tensor; fed a batch holding one RGB image.
  INPUT_TENSOR_NAME = 'ImageTensor:0'
  # Name of the graph's output tensor (the segmentation predictions).
  OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
  # The longer image side is resized to this many pixels before inference.
  INPUT_SIZE = 513
  # Substring identifying the frozen graph file inside the tar archive.
  FROZEN_GRAPH_NAME = 'frozen_inference_graph'

  def __init__(self, tarball_path):
    """Loads a pretrained DeepLab model from a tar archive.

    Args:
      tarball_path: Path to a tar archive containing a member whose basename
        includes FROZEN_GRAPH_NAME (a serialized GraphDef).

    Raises:
      RuntimeError: If no such graph file is found in the archive.
    """
    self.graph = tf.Graph()  # empty graph that will receive the model

    graph_def = None
    # `with` guarantees the archive is closed even if reading/parsing fails.
    with tarfile.open(tarball_path) as tar_file:
      for tar_info in tar_file.getmembers():
        if self.FROZEN_GRAPH_NAME not in os.path.basename(tar_info.name):
          continue
        file_handle = tar_file.extractfile(tar_info)
        # extractfile() returns None for non-regular members (e.g. dirs).
        if file_handle is None:
          continue
        graph_def = tf.GraphDef.FromString(file_handle.read())
        break

    if graph_def is None:
      raise RuntimeError('Cannot find inference graph in tar archive.')

    # Import the model into self.graph rather than the global default graph.
    with self.graph.as_default():
      tf.import_graph_def(graph_def, name='')
    # The session is bound to self.graph.
    self.sess = tf.Session(graph=self.graph)

  def run(self, image):
    """Runs inference on a single image.

    Args:
      image: A PIL.Image object, raw input image.

    Returns:
      resized_image: RGB image resized from the original input so that its
        longer side is (about) INPUT_SIZE pixels, keeping the aspect ratio.
      seg_map: Segmentation map of `resized_image`.
    """
    width, height = image.size
    resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
    # Keep the aspect ratio; clamp to 1 px so very thin images do not
    # collapse to a zero-sized side.
    target_size = (max(1, int(resize_ratio * width)),
                   max(1, int(resize_ratio * height)))
    # LANCZOS is the filter that ANTIALIAS aliased; ANTIALIAS was removed
    # in Pillow 10.
    resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
    batch_seg_map = self.sess.run(
        self.OUTPUT_TENSOR_NAME,
        feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
    # The batch contains exactly one image.
    seg_map = batch_seg_map[0]
    return resized_image, seg_map

定义 create_pascal_label_colormap 函数
该函数生成可视化所用的颜色表,返回一个形状为 [256, 3] 的 colormap(每个类别 id 对应一种颜色)

def create_pascal_label_colormap():
  """Builds the label colormap used by the PASCAL VOC segmentation benchmark.

  Each label id's bits are distributed round-robin over the R, G and B
  channels, filling each channel from its most significant bit downwards.

  Returns:
    An integer array of shape [256, 3]; row k is the RGB color of label k.
  """
  labels = np.arange(256, dtype=int)
  channels = [np.zeros(256, dtype=int) for _ in range(3)]

  for bit_position in range(7, -1, -1):
    for c, channel_values in enumerate(channels):
      # Bit c of the (shifted) label becomes this channel's current bit.
      channel_values |= ((labels >> c) & 1) << bit_position
    # Three bits consumed; move on to the next group.
    labels = labels >> 3

  return np.stack(channels, axis=1)

定义 label_to_color_image 函数
将二维的整数标签矩阵映射为彩色图像:每个像素的类别 id 被替换为 colormap 中对应的颜色,结果形状为 [H, W, 3]

def label_to_color_image(label, colormap=None):
  """Adds color defined by the dataset colormap to the label.

  Args:
    label: A 2D array (or array-like) with integer type, storing the
      segmentation label.
    colormap: Optional array of shape [num_colors, 3]; row k is the color of
      label k. Defaults to the PASCAL VOC colormap returned by
      create_pascal_label_colormap().

  Returns:
    result: An array of shape label.shape + (3,) with the colormap's dtype.
      Element [i, j] is the color indexed by label[i, j] in the colormap.

  Raises:
    ValueError: If label is not of rank 2, or contains a value that is
      negative or not smaller than the number of colormap entries.
  """
  label = np.asarray(label)
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')

  if colormap is None:
    colormap = create_pascal_label_colormap()

  # np.max / np.min raise on empty arrays, so only range-check non-empty input.
  if label.size:
    if np.max(label) >= len(colormap):
      raise ValueError('label value too large.')
    # Negative indices would silently wrap to the end of the colormap.
    if np.min(label) < 0:
      raise ValueError('label value must be non-negative.')

  return colormap[label]

函数 vis_segmentation
以一行四列的布局可视化分割结果:原图、分割图、原图与分割图的半透明叠加图,以及图中出现的类别的颜色图例

def vis_segmentation(image, seg_map):
  """Visualizes input image, segmentation map and overlay view.

  Draws one row of four panels: the input image, the color-coded
  segmentation map, the segmentation blended over the input, and a legend
  listing the labels present in `seg_map`. Relies on the module-level
  FULL_COLOR_MAP and LABEL_NAMES tables for the legend.

  Args:
    image: Image to display (e.g. the resized image from DeepLabModel.run).
    seg_map: 2-D array of label ids for `image`.
  """
  plt.figure(figsize=(15, 5))
  # Three equally wide image panels plus a narrow legend column.
  grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

  # Panel 1: input image.
  plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  plt.title('input image')

  # Panel 2: color-coded segmentation map.
  plt.subplot(grid_spec[1])
  seg_image = label_to_color_image(seg_map).astype(np.uint8)
  plt.imshow(seg_image)
  plt.axis('off')
  plt.title('segmentation map')

  # Panel 3: segmentation drawn semi-transparently over the input.
  plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(seg_image, alpha=0.7)
  plt.axis('off')
  plt.title('segmentation overlay')

  # Panel 4: legend with one colored row per label present in seg_map.
  unique_labels = np.unique(seg_map)
  ax = plt.subplot(grid_spec[3])
  plt.imshow(
      FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
  ax.yaxis.tick_right()
  plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
  plt.xticks([], [])
  ax.tick_params(width=0.0)
  # Pass a real boolean: on Matplotlib versions that no longer parse 'on'/'off'
  # strings, grid('off') is a truthy value and would turn the grid ON.
  plt.grid(False)
  plt.show()

定义类别名称与颜色表,并下载、加载预训练模型

# PASCAL VOC class names; a class's integer id is its index in this array
# (e.g. 'background' is id 0).
LABEL_NAMES = np.array([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv',
])
# Every class id as a column of shape [num_classes, 1], i.e. a 2-D array
# acceptable to label_to_color_image.
FULL_LABEL_MAP = np.arange(LABEL_NAMES.size)[:, np.newaxis]
# Color of every class id, shape [num_classes, 1, 3]; used for the legend.
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
# Pretrained checkpoint to download (Colab form field).
MODEL_NAME = 'mobilenetv2_coco_voctrainaug'  # @param ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']

# Base URL shared by every pretrained tarball listed below.
_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
# MODEL_NAME -> tarball file name, resolved against _DOWNLOAD_URL_PREFIX.
_MODEL_URLS = {
    'mobilenetv2_coco_voctrainaug':
        'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
    'mobilenetv2_coco_voctrainval':
        'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
    'xception_coco_voctrainaug':
        'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval':
        'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}
# File name the downloaded tarball is saved under.
_TARBALL_NAME = 'deeplab_model.tar.gz'

# Directory holding the model. A fresh temporary directory is created here;
# for repeated runs consider a fixed local directory instead, e.g.
# model_dir = '/path/to/your/dest/' (this cell still downloads every time).
model_dir = tempfile.mkdtemp()
# Ensure the directory exists (matters when model_dir is a custom path).
tf.gfile.MakeDirs(model_dir)

# Fetch the selected tarball into model_dir.
download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
urllib.request.urlretrieve(
    _DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME],
    filename=download_path,
)
print('download completed! loading DeepLab model...')

# Build the DeepLab model from the downloaded archive.
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')

在示例图片上运行模型(示例代码);如需使用本地图片,请参看下一个代码框。

# Which bundled sample image to segment (Colab form field).
SAMPLE_IMAGE = 'image2'  # @param ['image1', 'image2', 'image3']
# Optional URL of your own image; when non-empty it is used instead of
# the SAMPLE_IMAGE sample.
IMAGE_URL = ''  #@param {type:"string"}

# URL template of the sample images; '%s' receives the SAMPLE_IMAGE name.
_SAMPLE_URL = (
    'https://github.com/tensorflow/models/blob/master/research/'
    'deeplab/g3doc/img/%s.jpg?raw=true'
)


def run_visualization(url):
  """Downloads an image, runs DeepLab on it and visualizes the result.

  Args:
    url: URL of the image to segment.

  If the image cannot be downloaded or opened (an IOError is raised), an
  error message is printed and the function returns None without running
  the model.
  """
  try:
    response = urllib.request.urlopen(url)
    try:
      jpeg_str = response.read()
    finally:
      # Release the connection even if reading the body fails.
      response.close()
    original_im = Image.open(BytesIO(jpeg_str))
  except IOError:
    print('Cannot retrieve image. Please check url: ' + url)
    return

  print('running deeplab on image %s...' % url)
  resized_im, seg_map = MODEL.run(original_im)
  vis_segmentation(resized_im, seg_map)


# A user-supplied URL wins; otherwise fall back to the selected sample image.
if IMAGE_URL:
  image_url = IMAGE_URL
else:
  image_url = _SAMPLE_URL % SAMPLE_IMAGE
run_visualization(image_url)

如果使用本地图片,可以去掉下载等步骤,直接从本地路径读取图片:

# Path of the local image to segment; replace with your own file.
IMAGE_PATH = '/path/to/your/image'


def run_visualization(path):
  """Runs DeepLab on a local image file and visualizes the result.

  Args:
    path: Filesystem path of the image to segment.

  If the image cannot be opened (an IOError is raised), an error message is
  printed and the function returns None without running the model.
  """
  try:
    original_im = Image.open(path)
  except IOError:
    print('Cannot open image. Please check path: ' + path)
    return

  print('running deeplab on image %s...' % path)
  resized_im, seg_map = MODEL.run(original_im)
  vis_segmentation(resized_im, seg_map)
运行整个过程
# Segment and display the local image at IMAGE_PATH.
run_visualization(path=IMAGE_PATH)

你可能感兴趣的:(DeepLab,Segmentation,DeepLab,Segmentation)