TensorFlow 实现语义分割模型:DeepLab V3+(占坑。因 TensorFlow 2.0 改版很大,以前的很多 API 都将被取消,所以博主已停更,但仍欢迎多多交流)

本文将实现 DeepLab V3+ 模型(参考:DeepLab 官方开源代码)。

# -*- coding: utf-8 -*-
"""
Created on Mon Dec  3 17:57:46 2018

@author: shirhe-lyh


Implementation of DeepLab V3+:
    Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation, Liang-Chieh Chen et al., arXiv:1802.02611v3.
"""

import numpy as np
import tensorflow as tf

from tensorflow.contrib.slim import nets

import preprocessing
import resnet_v1_beta

slim = tf.contrib.slim


class DeepLab(object):
    """Implementation of DeepLab V3+ on a ResNet-50 (beta) backbone.

    The graph is built with TF1 `tf.contrib.slim`. It consists of:
    - a dilated ResNet-50 feature extractor;
    - an atrous spatial pyramid pooling (ASPP) module;
    - a decoder that fuses one low-level backbone feature;
    - a per-pixel classifier whose logits are resized to
      `default_image_size`.
    """

    def __init__(self,
                 is_training,
                 num_classes=3,
                 output_stride=16,
                 atrous_rates=(6, 12, 18),  # (12, 24, 36) for output_stride=8
                 decoder_output_stride=4,
                 default_image_size=513,
                 fine_tune_batch_norm=False):
        """Constructor.

        Args:
            is_training: A boolean indicating whether the training version of
                the computation graph should be constructed.
            num_classes: The number of classes.
            output_stride: Passed to the backbone as its `output_stride`.
            atrous_rates: A sequence of atrous rates used by the ASPP branches
                and by the logits branches. A copy is stored, so the caller's
                sequence is never aliased.
            decoder_output_stride: Ratio of the input size to the decoder
                feature map size.
            default_image_size: The (square) input size of the model. The
                output logits are resized to this size.
            fine_tune_batch_norm: Whether batch norm layers (backbone and
                heads) run in training mode when `is_training` is True.
        """
        self._is_training = is_training
        self._num_classes = num_classes
        self._output_stride = output_stride
        self._atrous_rates = list(atrous_rates)
        self._decoder_output_stride = decoder_output_stride
        self._default_image_size = default_image_size

        # When fine_tune_batch_norm=True, use a batch size of at least 12
        # (larger than 16 is better). Otherwise, one could use a smaller
        # batch size and set fine_tune_batch_norm=False.
        batch_norm_training = is_training and fine_tune_batch_norm
        self._batch_norm_params = {'is_training': batch_norm_training,
                                   'epsilon': 1e-5,
                                   'decay': 0.9997,
                                   'scale': True}

    @property
    def default_image_size(self):
        """The (square) input size of the model."""
        return self._default_image_size

    @staticmethod
    def _scale_dimension(dim, scale):
        """Scales a static spatial dimension the way DeepLab does.

        Uses (dim - 1) * scale + 1, which keeps the corner pixels of the
        input and scaled grids aligned when resizing with
        align_corners=True (e.g. 513 with scale 1/4 gives 129, not 128).

        Args:
            dim: A Python int, the dimension to scale.
            scale: A float scale factor.

        Returns:
            The scaled dimension as a Python int.
        """
        return int((float(dim) - 1.0) * scale + 1.0)

    def preprocess(self, images=None, masks=None):
        """Preprocessing.

        Args:
            images: A tensor with shape [batch_size, height, width, 3]
                representing a batch of images, assumed to hold values in
                [0, 255]. Only passed in case of test (i.e., in the training
                case images=None).
            masks: A float32 tensor with shape [batch_size, height, width, 1]
                representing a batch of groundtruth masks.

        Returns:
            A dict with key 'images' (images mapped to [-1, 1], or None when
            `images` is None) and key 'masks' (`masks` returned unchanged).
        """
        # The backbone is the ResNet "beta" variant, whose companion
        # normalization is `_preprocess_zero_mean_unit_range`.
        images_preprocessed = None
        if images is not None:
            images_preprocessed = self._preprocess_zero_mean_unit_range(images)
        # NOTE(review): masks are passed through as-is; `loss` assumes values
        # roughly in [0, 1] (thresholds at 0 and 0.8) — confirm upstream.
        return {'images': images_preprocessed, 'masks': masks}

    def _preprocess_zero_mean_unit_range(self, inputs):
        """Map image values from [0, 255] to [-1, 1].

        Only for beta version.
        """
        return (2.0 / 255.0) * tf.to_float(inputs) - 1.0

    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A 4-D float32 tensor with shape [batch_size,
                height, width, channels].

        Returns:
            Logits with shape [batch_size, default_image_size,
            default_image_size, num_classes], to be passed to the loss or
            postprocess functions.
        """
        # ResNet-50 (beta). Its batch norm follows the same fine-tuning
        # policy as the heads, so that fine_tune_batch_norm=False really
        # freezes batch norm statistics for small-batch training.
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, end_points = resnet_v1_beta.resnet_v1_50_beta(
                preprocessed_inputs, num_classes=None,
                is_training=self._batch_norm_params['is_training'],
                multi_grid=[1, 2, 4],
                global_pool=False,
                output_stride=self._output_stride)

        # Use the same scope with ResNet-50
        scope = 'resnet_v1_50'

        # Atrous spatial pyramid pooling
        net = self._atrous_spatial_pyramid_pooling(
            net, atrous_rates=self._atrous_rates, scope=scope)

        # Refine by decoder, at a size aligned with the input grid.
        decoder_size = self._scale_dimension(
            self._default_image_size, 1.0 / self._decoder_output_stride)
        net = self._refine_by_decoder(
            net,
            end_points,
            decoder_height=decoder_size,
            decoder_width=decoder_size,
            decoder_use_seperable_conv=True,
            is_training=self._is_training)

        # Convolution
        net = self._get_branch_logits(net, self._num_classes,
                                      self._atrous_rates, kernel_size=1)
        net = tf.image.resize_bilinear(net, size=[self._default_image_size,
                                                  self._default_image_size],
                                       align_corners=True,
                                       name='upsampling_logits')
        return net

    def split_seperable_conv2d(self,
                               inputs,
                               filters,
                               kernel_size=3,
                               rate=1,
                               weight_decay=0.00004,
                               depthwise_weights_initializer_stddev=0.33,
                               pointwise_weights_initializer_stddev=0.06,
                               scope=None):
        """Splits a separable conv2d into depthwise and pointwise conv2d.

        This operation differs from `tf.layers.separable_conv2d` as this
        operation applies activation function between depthwise and pointwise
        conv2d.

        Copy from:
            https://github.com/tensorflow/models/blob/master/research/deeplab/
            core/utils.py

        Args:
            inputs: Input tensor with shape [batch, height, width, channels].
            filters: Number of filters in the 1x1 pointwise convolution.
            kernel_size: A list of length 2: [kernel_height, kernel_width] of
                the filters. Can be an int if both values are the same.
            rate: Atrous convolution rate for the depthwise convolution.
            weight_decay: The weight decay to use for regularizing the model.
            depthwise_weights_initializer_stddev: The standard deviation of the
                truncated normal weight initializer for depthwise convolution.
            pointwise_weights_initializer_stddev: The standard deviation of the
                truncated normal weight initializer for pointwise convolution.
            scope: Scope prefix for the operation. It must be a string in
                practice: '_depthwise' and '_pointwise' are appended to it.

        Returns:
            Computed features after split separable conv2d.
        """
        # Depthwise only (num_outputs=None); no weight decay on depthwise
        # weights, matching the reference implementation.
        outputs = slim.separable_conv2d(
            inputs,
            None,
            kernel_size=kernel_size,
            depth_multiplier=1,
            rate=rate,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=depthwise_weights_initializer_stddev),
            weights_regularizer=None,
            scope=scope + '_depthwise')
        return slim.conv2d(
            outputs,
            filters,
            1,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=pointwise_weights_initializer_stddev),
            weights_regularizer=slim.l2_regularizer(weight_decay),
            scope=scope + '_pointwise')

    def _atrous_spatial_pyramid_pooling(self, feature_map, weight_decay=0.0001,
                                        atrous_rates=(12, 24, 36),
                                        scope='resnet_v1_50'):
        """Atrous spatial pyramid pooling for DeepLab v3.

        Args:
            feature_map: A 4-D tensor [batch_size, height, width, channels].
            weight_decay: L2 weight decay for the convolutions.
            atrous_rates: Rates of the 3x3 separable-convolution branches.
            scope: Variable scope that all ASPP variables are created under.

        Returns:
            A tensor with 256 channels at the spatial size of `feature_map`.
        """
        depth = 256
        with tf.variable_scope(scope):
            with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                                weights_regularizer=slim.l2_regularizer(
                                    weight_decay),
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=self._batch_norm_params):
                # Image pooling feature, broadcast back to the input size.
                shape = tf.shape(feature_map)[1:3]
                image_feature = tf.reduce_mean(feature_map, axis=[1, 2],
                                               keepdims=True)
                image_feature = slim.conv2d(image_feature, kernel_size=1,
                                            num_outputs=depth,
                                            scope='global_pool')
                image_feature = tf.image.resize_bilinear(image_feature,
                                                         size=shape,
                                                         align_corners=True)
                branch_nets = [image_feature]

                # Employ a 1x1 convolution.
                branch_nets.append(slim.conv2d(feature_map, kernel_size=1,
                                               num_outputs=depth,
                                               scope='aspp0'))

                # Employ 3x3 convolutions with different atrous rates. Each
                # branch gets its own scope ('aspp1', 'aspp2', ...); the
                # enclosing `scope` argument is never overwritten.
                for i, rate in enumerate(atrous_rates, 1):
                    branch_nets.append(self.split_seperable_conv2d(
                        feature_map,
                        filters=depth,
                        rate=rate,
                        weight_decay=weight_decay,
                        scope='aspp' + str(i)))

                # Concatenation and 1x1 projection. The projection is kept
                # inside the scopes so it gets batch norm and weight decay
                # like the branches.
                net = tf.concat(branch_nets, axis=3, name='aspp_concate')
                net = slim.conv2d(net, depth, kernel_size=1,
                                  scope='concat_projection')
                net = slim.dropout(net, keep_prob=0.9,
                                   is_training=self._is_training,
                                   scope='concat_projection_dropout')
        return net

    def _refine_by_decoder(self,
                           feature_map,
                           end_points,
                           decoder_height,
                           decoder_width,
                           decoder_use_seperable_conv=False,
                           weight_decay=0.0001,
                           reuse=None,
                           is_training=False,
                           scope='resnet_v1_50'):
        """Adds the decoder to obtain sharper segmentation results.

        Args:
            feature_map: A tensor with shape [batch_size, height, width, depth].
            end_points: A dictionary from components of the network to the
                corresponding activation.
            decoder_height: The height of decoder feature maps (int or
                tf.Tensor).
            decoder_width: The width of decoder feature maps (int or
                tf.Tensor).
            decoder_use_seperable_conv: Employ separable convolution for
                decoder or not.
            weight_decay: The weight decay for model variables.
            reuse: Reuse the model variables or not.
            is_training: Unused; batch norm mode comes from the constructor.
            scope: Variable scope for the decoder variables.

        Returns:
            Decoder output size [batch_size, decoder_height, decoder_width,
            decoder_depth].
        """
        # Static sizes can be recorded on the resized tensors; dynamic
        # (tf.Tensor) sizes are left unknown.
        h = None if isinstance(decoder_height, tf.Tensor) else decoder_height
        w = None if isinstance(decoder_width, tf.Tensor) else decoder_width
        decoder_depth = 256

        with tf.variable_scope(scope):
            with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                                weights_regularizer=slim.l2_regularizer(
                                    weight_decay),
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=self._batch_norm_params,
                                reuse=reuse):
                feature_list = ['block1/unit_2/bottleneck_v1/conv3']
                decoder_features = feature_map
                for i, name in enumerate(feature_list):
                    # Low-level backbone feature, projected to 48 channels.
                    feature_name = '{}/{}'.format('resnet_v1_50', name)
                    low_level = slim.conv2d(end_points[feature_name], 48, 1,
                                            scope='feature_project' + str(i))
                    resized_features = []
                    for feature in (decoder_features, low_level):
                        resized = tf.image.resize_bilinear(
                            feature, [decoder_height, decoder_width],
                            align_corners=True)
                        resized.set_shape([None, h, w, None])
                        resized_features.append(resized)
                    fused = tf.concat(resized_features, axis=3)
                    if decoder_use_seperable_conv:
                        decoder_features = self.split_seperable_conv2d(
                            fused,
                            filters=decoder_depth,
                            rate=1,
                            weight_decay=weight_decay,
                            scope='decoder_conv0')
                        decoder_features = self.split_seperable_conv2d(
                            decoder_features,
                            filters=decoder_depth,
                            rate=1,
                            weight_decay=weight_decay,
                            scope='decoder_conv1')
                    else:
                        num_convs = 2
                        decoder_features = slim.repeat(
                            fused,
                            num_convs,
                            slim.conv2d,
                            decoder_depth,
                            3,
                            scope='decoder_conv' + str(i))
                return decoder_features

    def _get_branch_logits(self,
                           feature_map,
                           num_classes,
                           atrous_rates=(12, 24, 36),
                           kernel_size=1,
                           weight_decay=0.0001,
                           reuse=None,
                           scope_suffix='seg_logits',
                           scope='resnet_v1_50'):
        """Gets the logits from each model's branch.

        The underlying model is branched out in the last layer when atrous
        spatial pyramid pooling is employed, and all branches are sum-merged
        to form the final logits. With kernel_size == 1, dilation has no
        effect, so a single branch is used instead of redundant copies.

        Args:
            feature_map: A float32 tensor with shape [batch_size, height,
                width, channels].
            num_classes: Number of classes to predict.
            atrous_rates: A list of atrous convolution rates for last layer.
            kernel_size: Kernel size for convolution.
            weight_decay: Weight decay for the model variables.
            reuse: Reuse model variables or not.
            scope_suffix: Scope suffix for the model variables.
            scope: Enclosing variable scope.

        Returns:
            Merged logits with shape [batch_size, height, width, num_classes].
        """
        if kernel_size == 1:
            # A dilated 1x1 conv equals a plain 1x1 conv; summing several of
            # them only adds redundant parameters.
            atrous_rates = [1]
        with tf.variable_scope(scope):
            with slim.arg_scope(
                [slim.conv2d],
                weights_regularizer=slim.l2_regularizer(
                    weight_decay),
                weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                reuse=reuse):
                branch_logits = []
                for i, rate in enumerate(atrous_rates):
                    branch_scope = scope_suffix
                    if i:
                        branch_scope += '_%d' % i
                    branch_logits.append(
                        slim.conv2d(feature_map,
                                    num_classes,
                                    kernel_size=kernel_size,
                                    rate=rate,
                                    activation_fn=None,
                                    normalizer_fn=None,
                                    scope=branch_scope))
            return tf.add_n(branch_logits)

    def postprocess(self, prediction_tensors):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_tensors: Logits with shape [batch_size, height, width,
                num_classes].

        Returns:
            Per-pixel class probabilities (softmax over the last axis).
        """
        probabilities = tf.nn.softmax(prediction_tensors, axis=3)
        return probabilities

    def loss(self, prediction_tensors, groundtruth_tensors):
        """Compute the total scalar loss with respect to provided groundtruth.

        Groundtruth values are mapped to class ids before the cross entropy:
        values > 0.8 become 1, values in (0, 0.8] become 2, and the
        remaining values are cast to int32 as-is (so 0 stays class 0).

        Args:
            prediction_tensors: Logits whose last dimension is num_classes.
            groundtruth_tensors: Float groundtruth with one value per pixel.

        Returns:
            The total loss from `tf.losses.get_total_loss()`, which by default
            also adds the regularization losses collected in the graph.
        """
        logits = tf.reshape(prediction_tensors, shape=[-1, self._num_classes])
        labels = tf.reshape(groundtruth_tensors, shape=[-1])
        labels = tf.where(tf.greater(labels, 0.8),
                          tf.ones_like(labels),
                          labels)
        labels = tf.where(tf.logical_and(tf.less_equal(labels, 0.8),
                                         tf.greater(labels, 0.0)),
                          2 * tf.ones_like(labels),
                          labels)
        labels = tf.cast(labels, dtype=tf.int32)
        # slim.losses (tf.contrib.losses) is deprecated. tf.losses adds to the
        # same LOSSES collection but takes (labels, logits), hence keywords.
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        loss = tf.losses.get_total_loss()
        return loss

你可能感兴趣的:(TensorFlow 实现语义分割模型:DeepLab V3+(占坑,因 TensorFlow 2.0 改版很大,以前很多 API 都将取消,所以博主停更了,但仍欢迎多多交流))