This article implements the DeepLab V3+ model (reference: the official DeepLab open-source code).
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 3 17:57:46 2018
@author: shirhe-lyh
Implementation of DeepLab V3+:
Encoder-Decoder with atrous separable convolution for semantic image
segmentation, Liang-Chieh Chen, et al., arXiv:1802.02611v3.
"""
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import nets
import preprocessing
import resnet_v1_beta
slim = tf.contrib.slim
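# Note: this implementation targets TensorFlow 1.x; tf.contrib (and hence
# tf.contrib.slim) was removed in TensorFlow 2.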
class DeepLab(object):
"""Implementation of DeepLab V3+."""
def __init__(self,
is_training,
num_classes=3,
output_stride=16,
atrous_rates=[6, 12, 18], # [12, 24, 36] for output_stride=8
decoder_output_stride=4,
default_image_size=513,
fine_tune_batch_norm=False):
"""Constructor.
Args:
            is_training: A boolean indicating whether the training version
                of the computation graph should be constructed.
            num_classes: The number of classes.
            output_stride: The ratio of input spatial resolution to encoder
                output resolution (16 or 8).
            atrous_rates: A list of atrous convolution rates used in the
                ASPP module.
            decoder_output_stride: The ratio of input spatial resolution to
                decoder output resolution.
            default_image_size: The input size of the model.
            fine_tune_batch_norm: Whether to fine-tune the batch norm
                parameters.
        """
self._is_training = is_training
self._num_classes = num_classes
self._output_stride = output_stride
self._atrous_rates = atrous_rates
self._decoder_output_stride = decoder_output_stride
self._default_image_size = default_image_size
        # When fine_tune_batch_norm=True, use a batch size of at least 12
        # (16 or more is better). Otherwise, use a smaller batch size and
        # set fine_tune_batch_norm=False.
_is_training = is_training and fine_tune_batch_norm
self._batch_norm_params = {'is_training': _is_training,
'epsilon': 1e-5,
'decay': 0.9997,
'scale': True}
@property
def default_image_size(self):
return self._default_image_size
    def preprocess(self, images=None, masks=None):
        """Preprocessing.
        Args:
            images: A float32 tensor with shape [batch_size, height, width,
                3] representing a batch of images. Only passed at test time
                (i.e., during training images=None).
            masks: A float32 tensor with shape [batch_size, height, width, 1]
                representing a batch of groundtruth masks.
        Returns:
            The preprocessed inputs.
        """
        # Note: the full augmentation pipeline (random scale/crop/flip) is
        # assumed to live in the imported `preprocessing` module; here we
        # only normalize image values to [-1, 1] and pass the masks through
        # unchanged.
        images_preprocessed = None
        if images is not None:
            images_preprocessed = self._preprocess_zero_mean_unit_range(images)
        preprocessed_dict = {'images': images_preprocessed,
                             'masks': masks}
        return preprocessed_dict
def _preprocess_zero_mean_unit_range(self, inputs):
"""Map image values from [0, 255] to [-1, 1].
Only for beta version.
"""
return (2.0 / 255.0) * tf.to_float(inputs) - 1.0
def predict(self, preprocessed_inputs):
"""Predict prediction tensors from inputs tensor.
Outputs of this function can be passed to loss or postprocess functions.
Args:
preprocessed_inputs: A 4-D float32 tensor with shape [batch_size,
height, width, channels].
Returns:
The prediction tensors to be passed to the Loss or Postprocess
functions.
"""
# ResNet-50
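        # With output_stride=16, the backbone produces feature maps at 1/16
        # of the input resolution; multi_grid=[1, 2, 4] multiplies the unit
        # atrous rates inside the last ResNet block.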
with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
net, end_points = resnet_v1_beta.resnet_v1_50_beta(
preprocessed_inputs, num_classes=None,
is_training=self._is_training,
multi_grid=[1, 2, 4],
global_pool=False,
output_stride=self._output_stride)
        # Use the same variable scope as ResNet-50.
        scope = 'resnet_v1_50'
# Atrous spatial pyramid pooling
net = self._atrous_spatial_pyramid_pooling(
net, atrous_rates=self._atrous_rates, scope=scope)
        # Refine by decoder. With align_corners=True the feature grid at
        # stride s has size (size - 1) / s + 1 (129 for a 513 input).
        decoder_height = ((self.default_image_size - 1) //
                          self._decoder_output_stride + 1)
        decoder_width = decoder_height
net = self._refine_by_decoder(
net,
end_points,
decoder_height=decoder_height,
decoder_width=decoder_width,
            decoder_use_separable_conv=True,
is_training=self._is_training)
        # Produce per-class logits with a 1x1 convolution, then bilinearly
        # upsample them to the input resolution.
net = self._get_branch_logits(net, self._num_classes,
self._atrous_rates, kernel_size=1)
net = tf.image.resize_bilinear(net, size=[self._default_image_size,
self._default_image_size],
align_corners=True,
name='upsampling_logits')
return net
    def split_separable_conv2d(self,
inputs,
filters,
kernel_size=3,
rate=1,
weight_decay=0.00004,
depthwise_weights_initializer_stddev=0.33,
pointwise_weights_initializer_stddev=0.06,
scope=None):
"""Splits a seperable conv2d into depthwise and pointwise conv2d.
This operation differs from `tf.layers.separable_conv2d` as this
operation applies activation function between depthwise and pointwise
conv2d.
Copy from:
https://github.com/tensorflow/models/blob/master/research/deeplab/
core/utils.py
Args:
inputs: Input tensor with shape [batch, height, width, channels].
filters: Number of filters in the 1x1 pointwise convolution.
            kernel_size: A list of length 2: [kernel_height, kernel_width]
                of the filters. Can be an int if both values are the same.
rate: Atrous convolution rate for the depthwise convolution.
weight_decay: The weight decay to use for regularizing the model.
depthwise_weights_initializer_stddev: The standard deviation of the
truncated normal weight initializer for depthwise convolution.
pointwise_weights_initializer_stddev: The standard deviation of the
truncated normal weight initializer for pointwise convolution.
scope: Optional scope for the operation.
Returns:
Computed features after split separable conv2d.
"""
outputs = slim.separable_conv2d(
inputs,
None,
kernel_size=kernel_size,
depth_multiplier=1,
rate=rate,
weights_initializer=tf.truncated_normal_initializer(
stddev=depthwise_weights_initializer_stddev),
weights_regularizer=None,
scope=scope + '_depthwise')
return slim.conv2d(
outputs,
filters,
1,
weights_initializer=tf.truncated_normal_initializer(
stddev=pointwise_weights_initializer_stddev),
weights_regularizer=slim.l2_regularizer(weight_decay),
scope=scope + '_pointwise')
def _atrous_spatial_pyramid_pooling(self, feature_map, weight_decay=0.0001,
atrous_rates=[12, 24, 36],
scope='resnet_v1_50'):
"""Atrous spatial pyramid pooling for DeepLab v3."""
branch_nets = []
# Convolution
with tf.variable_scope(scope):
with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(
weight_decay),
normalizer_fn=slim.batch_norm,
normalizer_params=self._batch_norm_params):
                depth = 256
# Image pooling feature
shape = tf.shape(feature_map)[1:3]
image_feature = tf.reduce_mean(feature_map, axis=[1, 2],
keep_dims=True)
image_feature = slim.conv2d(image_feature, kernel_size=1,
num_outputs=depth,
scope='global_pool')
image_feature = tf.image.resize_bilinear(image_feature,
size=shape,
align_corners=True)
branch_nets.append(image_feature)
# Employ a 1x1 convolution
branch_nets.append(slim.conv2d(feature_map, kernel_size=1,
num_outputs=depth,
scope='aspp' + str(0)))
# Employ 3x3 convolutions with different atrous rates.
                for i, rate in enumerate(atrous_rates, 1):
                    branch_scope = 'aspp' + str(i)
                    aspp_net = self.split_separable_conv2d(
                        feature_map,
                        filters=depth,
                        rate=rate,
                        weight_decay=weight_decay,
                        scope=branch_scope)
                    branch_nets.append(aspp_net)
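                # Five branches in total: image-level pooling, one 1x1
                # convolution, and three 3x3 atrous convolutions; their
                # concatenation is projected back to `depth` channels below.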
# Concatenation
                net = tf.concat(branch_nets, axis=3, name='aspp_concat')
                net = slim.conv2d(net, depth, kernel_size=1,
                                  scope='concat_projection')
                net = slim.dropout(net, keep_prob=0.9,
                                   is_training=self._is_training,
                                   scope='concat_projection_dropout')
return net
def _refine_by_decoder(self,
feature_map,
end_points,
decoder_height,
decoder_width,
                           decoder_use_separable_conv=False,
weight_decay=0.0001,
reuse=None,
is_training=False,
scope='resnet_v1_50'):
"""Adds the decoder to obtain sharper segmentation results.
Args:
feature_map: A tensor with shape [batch_size, height, width, depth].
end_points: A dictionary from components of the network to the
corresponding activation.
decoder_height: The height of decoder feature maps.
decoder_width: The width of decoder feature maps.
            decoder_use_separable_conv: Whether to employ separable
                convolution for the decoder.
            weight_decay: The weight decay for model variables.
            reuse: Reuse the model variables or not.
            is_training: Is training or not.
Returns:
Decoder output size [batch_size, decoder_height, decoder_width,
decoder_depth].
"""
with tf.variable_scope(scope):
with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(
weight_decay),
normalizer_fn=slim.batch_norm,
normalizer_params=self._batch_norm_params,
reuse=reuse):
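                # Low-level backbone features (output stride 4) used as the
                # decoder skip connection; they are projected to 48 channels
                # before concatenation with the ASPP output.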
feature_list = ['block1/unit_2/bottleneck_v1/conv3']
decoder_features = feature_map
for i, name in enumerate(feature_list):
decoder_features_list = [decoder_features]
                    feature_name = '{}/{}'.format(scope, name)
decoder_features_list.append(
slim.conv2d(end_points[feature_name], 48, 1,
scope='feature_project' + str(i)))
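                    # Resize both the ASPP output and the projected low-level
                    # features to the decoder resolution; set_shape restores
                    # static shape information lost by resize_bilinear.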
for j, feature in enumerate(decoder_features_list):
decoder_features_list[j] = tf.image.resize_bilinear(
feature, [decoder_height, decoder_width],
align_corners=True)
h = (None if isinstance(decoder_height, tf.Tensor)
else decoder_height)
w = (None if isinstance(decoder_width, tf.Tensor)
else decoder_width)
decoder_features_list[j].set_shape([None, h, w, None])
decoder_depth = 256
                    if decoder_use_separable_conv:
                        decoder_features = self.split_separable_conv2d(
                            tf.concat(decoder_features_list, axis=3),
                            filters=decoder_depth,
                            rate=1,
                            weight_decay=weight_decay,
                            scope='decoder_conv0')
                        decoder_features = self.split_separable_conv2d(
                            decoder_features,
                            filters=decoder_depth,
                            rate=1,
                            weight_decay=weight_decay,
                            scope='decoder_conv1')
else:
num_convs = 2
decoder_features = slim.repeat(
tf.concat(decoder_features_list, axis=3),
num_convs,
slim.conv2d,
decoder_depth,
3,
scope='decoder_conv' + str(i))
return decoder_features
def _get_branch_logits(self,
feature_map,
num_classes,
atrous_rates=[12, 24, 36],
kernel_size=1,
weight_decay=0.0001,
reuse=None,
scope_suffix='seg_logits',
scope='resnet_v1_50'):
"""Gets the logits from each model's branch.
The underlying model is branched out in the last layer when atrous
spatial pyramid pooling is employed, and all branches are sum-merged
to form the final logits.
Args:
feature_map: A float32 tensor with shape [batch_size, height,
width, channels].
num_classes: Number of classes to predict.
atrous_rates: A list of atrous convolution rates for last layer.
kernel_size: Kernel size for convolution.
weight_decay: Weight decay for the model variables.
reuse: Reuse model variables or not.
scope_suffix: Scope suffix for the model variables.
Returns:
Merged logits with shape [batch_size, height, width, num_classes].
"""
with tf.variable_scope(scope):
with slim.arg_scope(
[slim.conv2d],
weights_regularizer=slim.l2_regularizer(
weight_decay),
weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
reuse=reuse):
                branch_logits = []
                for i, rate in enumerate(atrous_rates):
                    branch_scope = scope_suffix
                    if i:
                        branch_scope += '_%d' % i
                    branch_logits.append(
                        slim.conv2d(feature_map,
                                    num_classes,
                                    kernel_size=kernel_size,
                                    rate=rate,
                                    activation_fn=None,
                                    normalizer_fn=None,
                                    scope=branch_scope))
return tf.add_n(branch_logits)
def postprocess(self, prediction_tensors):
"""Convert predicted output tensors to final forms.
Args:
prediction_tensors: The prediction tensors.
Returns:
The postprocessed results.
"""
        probabilities = tf.nn.softmax(prediction_tensors, axis=3)
        return probabilities
def loss(self, prediction_tensors, groundtruth_tensors):
"""Compute scalar loss tensors with respect to provided groundtruth."""
logits = tf.reshape(prediction_tensors, shape=[-1, self._num_classes])
        labels = tf.reshape(groundtruth_tensors, shape=[-1])
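        # Map continuous mask values in [0, 1] to three classes (a
        # trimap-style assumption): > 0.8 -> class 1 (foreground),
        # (0, 0.8] -> class 2 (unknown/boundary), 0 -> class 0 (background).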
labels = tf.where(tf.greater(labels, 0.8),
tf.ones_like(labels),
labels)
labels = tf.where(tf.logical_and(tf.less_equal(labels, 0.8),
tf.greater(labels, 0.0)),
2 * tf.ones_like(labels),
labels)
labels = tf.cast(labels, dtype=tf.int32)
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()
        return loss
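Finally, a minimal usage sketch showing how the pieces fit together. The placeholder inputs, optimizer, and learning rate below are illustrative assumptions, not part of the model file above:
if __name__ == '__main__':
    # Build the training graph end to end: preprocess -> predict -> loss.
    model = DeepLab(is_training=True, num_classes=3)
    size = model.default_image_size
    images = tf.placeholder(tf.float32, shape=[None, size, size, 3])
    masks = tf.placeholder(tf.float32, shape=[None, size, size, 1])
    preprocessed_dict = model.preprocess(images=images, masks=masks)
    logits = model.predict(preprocessed_dict['images'])
    loss = model.loss(logits, preprocessed_dict['masks'])
    # Batch norm moving averages are collected in UPDATE_OPS, so run them
    # together with the training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.MomentumOptimizer(
            learning_rate=0.007, momentum=0.9).minimize(loss)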