1、FCN_TensorFlow——VGG16_FCN8s构造代码分析

首先,感谢Marvin Teichmann分享的KittiSeg代码,源码见其GitHub主页

先贴一张全连接的VGG16模型,如图1:

                               1、FCN_TensorFlow——VGG16_FCN8s构造代码分析_第1张图片

图 1

1、全卷积神经网络(FCN)是在图1的基础上,将全连接层改为卷积替代并将其用于语义分割上,详情见论文《Fully Convolutional Networks for Semantic Segmentation》

1、FCN_TensorFlow——VGG16_FCN8s构造代码分析_第2张图片

图 2 将全连接层修改为卷积层使得分类网络的输出变为一个热点图

1、FCN_TensorFlow——VGG16_FCN8s构造代码分析_第3张图片

图 3 全卷积网络可以有效的进行像素级任务的稠密预测,例如语义分割

1、FCN_TensorFlow——VGG16_FCN8s构造代码分析_第4张图片

图 4 黑线32s;虚线16s;点线8s

2、VGG16_FCN8s代码如下

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
from math import ceil
import sys

import numpy as np
import tensorflow as tf

VGG_MEAN = [103.939, 116.779, 123.68]


class FCN8VGG:

    def __init__(self, vgg16_npy_path=None):
        if vgg16_npy_path is None:
            path = sys.modules[self.__class__.__module__].__file__
            # print path
            path = os.path.abspath(os.path.join(path, os.pardir))
            # print path
            path = os.path.join(path, "vgg16.npy")
            vgg16_npy_path = path
            logging.info("Load npy file from '%s'.", vgg16_npy_path)
        if not os.path.isfile(vgg16_npy_path):
            logging.error(("File '%s' not found. Download it from "
                           "ftp://mi.eng.cam.ac.uk/pub/mttt2/"
                           "models/vgg16.npy"), vgg16_npy_path)
            sys.exit(1)

        self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
        self.wd = 5e-4
        print("npy file loaded")

    def build(self, rgb, train=False, num_classes=20, random_init_fc8=False,
              debug=False, use_dilated=False):
        """
        Build the VGG model using loaded weights
        Parameters
        ----------
        rgb: image batch tensor
            Image in rgb shap. Scaled to Intervall [0, 255]
        train: bool
            Whether to build train or inference graph
        num_classes: int
            How many classes should be predicted (by fc8)
        random_init_fc8 : bool
            Whether to initialize fc8 layer randomly.
            Finetuning is required in this case.
        debug: bool
            Whether to print additional Debug Information.
        """
        # Convert RGB to BGR

        with tf.name_scope('Processing'):

            red, green, blue = tf.split(rgb, 3, 3)
            # assert red.get_shape().as_list()[1:] == [224, 224, 1]
            # assert green.get_shape().as_list()[1:] == [224, 224, 1]
            # assert blue.get_shape().as_list()[1:] == [224, 224, 1]
            bgr = tf.concat([
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], 3)

            if debug:
                bgr = tf.Print(bgr, [tf.shape(bgr)],
                               message='Shape of input image: ',
                               summarize=4, first_n=1)

        self.conv1_1 = self._conv_layer(bgr, "conv1_1")
        self.conv1_2 = self._conv_layer(self.conv1_1, "conv1_2&

你可能感兴趣的:(1、FCN_TensorFlow——VGG16_FCN8s构造代码分析)