我们首先从 deeplab_demo.ipynb
开始分析
首先 import 必要的库
from io import BytesIO
import os  # path handling (os.path.basename / os.path.join)
import tarfile  # reading tar archives
import tempfile  # creating temporary files and directories

from matplotlib import gridspec  # plotting
from matplotlib import pyplot as plt  # plotting
import numpy as np
from PIL import Image  # image loading
from six.moves import urllib  # downloading
import tensorflow as tf
定义 DeepLabModel 类
包含
1. __init__
初始化计算图
2. run
图片输入计算图,运行计算图
class DeepLabModel(object):
    """Load a frozen DeepLab graph from a tarball and run inference with it."""

    # Class-level configuration.
    INPUT_TENSOR_NAME = 'ImageTensor:0'  # name of the graph's input tensor
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'  # name of the graph's output tensor
    INPUT_SIZE = 513  # the longer image side is resized to this many pixels
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'  # substring identifying the graph file in the tarball

    def __init__(self, tarball_path):
        """Load the pretrained DeepLab model stored in a tar archive.

        Args:
            tarball_path: Path of the tar archive that contains the frozen
                inference graph.

        Raises:
            RuntimeError: If no archive member whose base name contains
                FROZEN_GRAPH_NAME can be read.
        """
        self.graph = tf.Graph()  # start from an empty graph
        graph_def = None

        # Pull the serialized graph out of the archive. The `with` block closes
        # the archive even if reading or parsing the graph raises.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                    file_handle = tar_file.extractfile(tar_info)
                    if file_handle is None:
                        # Not a regular file (e.g. a directory); keep looking.
                        continue
                    graph_def = tf.GraphDef.FromString(file_handle.read())
                    break

        # Still None means the archive held no matching graph file.
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        # Import the serialized graph into self.graph.
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        # The session is bound to self.graph.
        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Run inference on a single image.

        Args:
            image: A PIL.Image object, raw input image.

        Returns:
            resized_image: RGB image resized from original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        # Scale so that the longer side becomes INPUT_SIZE, keeping the aspect ratio.
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Image.LANCZOS is the same filter as the old Image.ANTIALIAS alias,
        # which was removed in Pillow 10.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        # Feed a batch containing the single image and fetch the output tensor.
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]  # only one image in the batch
        return resized_image, seg_map
定义 create_pascal_label_colormap 函数
该函数生成可视化时使用的颜色表(colormap),返回一个形状为 [256, 3] 的数组
def create_pascal_label_colormap():
    """Creates a label colormap used in PASCAL VOC segmentation benchmark.

    Returns:
        A Colormap for visualizing segmentation results: an integer array of
        shape [256, 3] whose row i is the color of label i.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    # Spread the bits of each label index over the three channels: every pass
    # consumes the three lowest bits of the index (one per channel) and writes
    # them at bit position `shift`, starting from the most significant bit.
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap
定义 label_to_color_image 函数
将二维的整数标签矩阵转换为对应的彩色图像(每个标签替换为 colormap 中对应的颜色)
def label_to_color_image(label):
    """Adds color defined by the dataset colormap to the label.

    Args:
        label: A 2D array with integer type, storing the segmentation label.

    Returns:
        result: An integer array with one extra trailing axis of size 3. Each
            element of `label` is replaced by the color it indexes in the
            PASCAL color map.

    Raises:
        ValueError: If label is not of rank 2 or its value is larger than color
            map maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    # np.max raises on a zero-size array; an empty label has no value that
    # could be out of range, so skip the check in that case.
    if label.size and np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]
函数 vis_segmentation
可视化分割图片 一行四列
def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view.

    Draws one row of four panels: the input image, the colored segmentation
    map, the map overlaid on the image, and a legend of the labels that occur
    in `seg_map`.

    Args:
        image: The image to display in the first and third panels.
        seg_map: 2-D integer label array, colored via label_to_color_image.
    """
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    # Panel 1: the input image.
    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    # Panel 2: the colored segmentation map.
    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    # Panel 3: the semi-transparent segmentation map drawn over the image.
    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    # Panel 4: legend listing the color and name of every label present.
    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(
        FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    # Pass a real boolean: the string 'off' is truthy and would turn the grid
    # on in current Matplotlib versions.
    plt.grid(False)
    plt.show()
其他设置:类别名称、颜色表以及模型的下载与加载
# Class names of the PASCAL VOC dataset; the position in the array is the label id.
LABEL_NAMES = np.asarray([
    'background',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tv',
])

# Column vector of label ids, one per row (e.g. 'background' has id 0).
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES))[:, np.newaxis]

# Color assigned to each class: one colormap entry per label id.
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
# Model to run (selectable through the notebook form).
MODEL_NAME = 'mobilenetv2_coco_voctrainaug' # @param ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']

# Common prefix of every model download URL.
_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'

# Archive file name for each selectable model.
_MODEL_URLS = {
    'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
    'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
    'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}
_TARBALL_NAME = 'deeplab_model.tar.gz'

# Local directory that receives the model. For repeated runs, prefer a
# persistent directory of your own, e.g. model_dir = '/path/to/your/dest/'.
model_dir = tempfile.mkdtemp()
tf.gfile.MakeDirs(model_dir)

# Download the selected model archive.
download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
_model_url = _DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME]
urllib.request.urlretrieve(_model_url, download_path)
print('download completed! loading DeepLab model...')

# Create the DeepLab model from the downloaded archive.
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')
运行示例图片的代码;如果使用本地图片,请看下一个代码框
# Name of the sample image to use, and an optional custom image URL.
SAMPLE_IMAGE = 'image2' # @param ['image1', 'image2', 'image3']
IMAGE_URL = '' #@param {type:"string"}

# URL template of the sample images; %s is replaced by the image name.
_SAMPLE_URL = (
    'https://github.com/tensorflow/models/blob/master/research/'
    + 'deeplab/g3doc/img/%s.jpg?raw=true'
)
def run_visualization(url):
    """Inferences DeepLab model and visualizes result.

    Downloads the image at `url`, runs the module-level MODEL on it and shows
    the result with vis_segmentation. If the image cannot be retrieved or
    opened, a message is printed and nothing is visualized.

    Args:
        url: URL of the image to segment.
    """
    # Download the image.
    try:
        response = urllib.request.urlopen(url)
        try:
            jpeg_str = response.read()
        finally:
            # Always release the connection, even if reading fails.
            response.close()
        original_im = Image.open(BytesIO(jpeg_str))
    except IOError:
        print('Cannot retrieve image. Please check url: ' + url)
        return

    # Run inference.
    print('running deeplab on image %s...' % url)
    resized_im, seg_map = MODEL.run(original_im)

    # Visualize with the helper defined earlier in the file.
    vis_segmentation(resized_im, seg_map)
# Use the custom URL when one is given; otherwise fall back to the sample image.
if IMAGE_URL:
    image_url = IMAGE_URL
else:
    image_url = _SAMPLE_URL % SAMPLE_IMAGE
run_visualization(image_url)
如果是本地图片,去掉下载等步骤
# Path of the local image to segment; replace the placeholder with a real path.
IMAGE_PATH = '/path/to/your/image'
def run_visualization(path):
    """Run DeepLab on a local image file and visualize the result.

    Opens the image at `path`, runs the module-level MODEL on it and shows the
    result with vis_segmentation.

    Args:
        path: Path of the local image file to segment.
    """
    # One consistent name: the original assigned `oringnal_im` but then read
    # `orignal_im`, which raised NameError.
    original_im = Image.open(path)
    print('running deeplab on image %s...' % path)
    resized_im, seg_map = MODEL.run(original_im)
    vis_segmentation(resized_im, seg_map)
运行整个过程
# Run inference and visualization on the local image at IMAGE_PATH.
run_visualization(IMAGE_PATH)