Network Architecture: Inception V2

Original article: AIUAI - Network Architecture: Inception V2

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Rethinking the Inception Architecture for Computer Vision

Inception V2 is a variant of the GoogleNet architecture. The main changes are:

For comparison, see Network Architecture: GoogleNet (Inception V1).

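  • [1] - Each 5x5 convolutional layer is replaced by two consecutive 3x3 convolutional layers. This increases the maximum depth of the network by 9 weight layers, the number of parameters by about 25%, and the computational cost by about 30%.

    Two stacked 3x3 convolutions cover the same 5x5 receptive field as a single 5x5 convolution, so they can take its place.
  • [2] - The number of 28x28 Inception modules is increased from 2 to 3.
  • [3] - Inside the Inception modules, both average pooling and max pooling are used (see the pooling entries of the table in the paper). In the Slim code below, the pooling branch uses average pooling in every non-reduction module except Mixed_5c, which uses max pooling; the reduction modules Mixed_4a and Mixed_5a also use max pooling.
  • [4] - There are no pooling layers between two Inception modules; instead, stride-2 convolution/pooling layers are used before the filter concatenation in modules 3c and 4e (named Mixed_4a and Mixed_5a in the Slim code below).
  • [5] - The first convolutional layer uses a separable convolution with depth multiplier 8. This reduces the computational cost but increases memory consumption during training.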
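Bullets [1] and [5] both trade weights for structure, and the arithmetic is easy to check by hand. The sketch below is a minimal illustration that counts weights (biases ignored); the helper names and the fixed channel count C = 64 are assumptions chosen only for the example. Holding channels fixed, two stacked 3x3 layers see the same 5x5 window with 18C² instead of 25C² weights. This per-layer comparison does not contradict the network-wide increase in [1], since that figure covers the whole network, whose layer widths also differ from GoogleNet's. The second part compares a regular 7x7 first layer on an RGB input with a separable one using depth multiplier 8.

```python
def conv_weights(k, c_in, c_out):
    """Weights of a regular k x k convolution (biases ignored)."""
    return k * k * c_in * c_out


def separable_conv_weights(k, c_in, depthwise_multiplier, c_out):
    """Weights of a depthwise-separable k x k convolution (biases ignored).

    Depthwise stage: one k x k filter per (input channel, multiplier) pair.
    Pointwise stage: a 1x1 convolution from c_in * depthwise_multiplier to c_out.
    """
    depthwise = k * k * c_in * depthwise_multiplier
    pointwise = c_in * depthwise_multiplier * c_out
    return depthwise + pointwise


def receptive_field(k, n_layers):
    """Receptive field of n stacked stride-1 k x k convolutions."""
    return 1 + n_layers * (k - 1)


C = 64  # illustrative channel count, same on input and output

one_5x5 = conv_weights(5, C, C)        # 25 * C * C
two_3x3 = 2 * conv_weights(3, C, C)    # 18 * C * C
print(receptive_field(5, 1), receptive_field(3, 2))  # 5 5
print(two_3x3 / one_5x5)                             # 0.72

# Conv2d_1a_7x7: 3 input channels (RGB), 64 outputs, depthwise multiplier 8.
regular_7x7 = conv_weights(7, 3, 64)
separable_7x7 = separable_conv_weights(7, 3, 8, 64)
print(regular_7x7, separable_7x7)                    # 9408 2712
```

The separable layer needs roughly 30% of the weights of the regular one. The extra memory cost mentioned in [5] comes from the intermediate depthwise activations (3 x 8 = 24 channels at full resolution), which this weight count does not capture.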

Inception V2 definition in TensorFlow Slim

"""
Definition of the Inception V2 classification network.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets import inception_utils

slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)


def inception_v2_base(inputs,
                      final_endpoint='Mixed_5c',
                      min_depth=16,
                      depth_multiplier=1.0,
                      use_separable_conv=True,
                      data_format='NHWC',
                      scope=None):
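  """Inception V2 base network definition.

  Constructs an Inception V2 network from the inputs up to the given final
  endpoint. It can build every layer of the paper's table, from the input up
  to inception(5b).

  Args:
    inputs: a Tensor of size [batch_size, height, width, channels].
    final_endpoint: the endpoint at which the network definition stops, i.e.
      how deep the network is built. Can be one of ['Conv2d_1a_7x7',
      'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', 'MaxPool_3a_3x3',
      'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d',
      'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c'].
    min_depth: minimum depth value (number of channels) for all convolution
      ops. Enforced when depth_multiplier < 1, and not an active constraint
      when depth_multiplier >= 1.
    depth_multiplier: float multiplier for the depth (number of channels) of
      all convolution ops. Must be greater than zero. Typically set to a value
      in (0, 1) to reduce the number of parameters or the computation cost of
      the model.
    use_separable_conv: use a separable convolution for the first layer,
      Conv2d_1a_7x7. If False, a regular convolution is used instead.
    data_format: data format of the activations ('NHWC' or 'NCHW').
    scope: optional variable_scope.

  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a set of activations for external use, for example summaries
      or losses.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values,
                or depth_multiplier <= 0, or data_format is not 'NHWC'/'NCHW',
                or data_format is 'NCHW' while use_separable_conv is True.
  """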

  # end_points will collect relevant activations for external use, for
  # example summaries or losses.
  end_points = {}

  # Used to find the thinned depth (number of channels) of each layer.
  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')
  depth = lambda d: max(int(d * depth_multiplier), min_depth)

  if data_format != 'NHWC' and data_format != 'NCHW':
    raise ValueError('data_format must be either NHWC or NCHW.')
  if data_format == 'NCHW' and use_separable_conv:
    raise ValueError(
        'separable convolution only supports NHWC layout. NCHW data format can'
        ' only be used when use_separable_conv is False.'
    )

  concat_dim = 3 if data_format == 'NHWC' else 1
  with tf.variable_scope(scope, 'InceptionV2', [inputs]):
    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1,
                        padding='SAME',
                        data_format=data_format):

      # The sizes in the comments below assume an input size of 224x224.
      # In practice, the input can be of any size greater than 32x32.

      # 224 x 224 x 3
      end_point = 'Conv2d_1a_7x7'

      if use_separable_conv:  # use a separable convolution
        # depthwise_multiplier here is different from depth_multiplier.
        # depthwise_multiplier determines the output channels of the initial
        # depthwise conv (see docs for tf.nn.separable_conv2d), while
        # depth_multiplier controls the # channels of the subsequent 1x1
        # convolution. Must have
        #   in_channels * depthwise_multiplier <= out_channels
        # so that the separable convolution is not overparameterized.
        depthwise_multiplier = min(int(depth(64) / 3), 8)
        net = slim.separable_conv2d(inputs, depth(64), [7, 7],
                                    depth_multiplier=depthwise_multiplier,
                                    stride=2,
                                    padding='SAME',
                                    weights_initializer=trunc_normal(1.0),
                                    scope=end_point)
      else:  # use a regular convolution
        net = slim.conv2d(inputs, depth(64), [7, 7], stride=2,
                          weights_initializer=trunc_normal(1.0),
                          scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 112 x 112 x 64
      end_point = 'MaxPool_2a_3x3'
      net = slim.max_pool2d(net, [3, 3], scope=end_point, stride=2)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 56 x 56 x 64
      end_point = 'Conv2d_2b_1x1'
      net = slim.conv2d(net, depth(64), [1, 1], scope=end_point,
                        weights_initializer=trunc_normal(0.1))
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 56 x 56 x 64
      end_point = 'Conv2d_2c_3x3'
      net = slim.conv2d(net, depth(192), [3, 3], scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 56 x 56 x 192
      end_point = 'MaxPool_3a_3x3'
      net = slim.max_pool2d(net, [3, 3], scope=end_point, stride=2)
      end_points[end_point] = net
      if end_point == final_endpoint: return net, end_points
      # 28 x 28 x 192
      # Inception module.
      end_point = 'Mixed_3b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(64), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(32), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 28 x 28 x 256
      end_point = 'Mixed_3c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 28 x 28 x 320
      end_point = 'Mixed_4a'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_0 = slim.conv2d(branch_0, depth(160), [3, 3], stride=2, scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], stride=2, scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.max_pool2d(net, [3, 3], stride=2, scope='MaxPool_1a_3x3')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(224), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(64), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(96), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(96), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(128), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(96), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4d'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(160), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(160), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(160), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(96), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_4e'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(96), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(192), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(160), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(192), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(192), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(96), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 14 x 14 x 576
      end_point = 'Mixed_5a'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_0 = slim.conv2d(branch_0, depth(192), [3, 3], stride=2, scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(192), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(256), [3, 3], scope='Conv2d_0b_3x3')
          branch_1 = slim.conv2d(branch_1, depth(256), [3, 3], stride=2, scope='Conv2d_1a_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.max_pool2d(net, [3, 3], stride=2, scope='MaxPool_1a_3x3')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 7 x 7 x 1024
      end_point = 'Mixed_5b'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(352), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(192), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(320), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(160), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
      # 7 x 7 x 1024
      end_point = 'Mixed_5c'
      with tf.variable_scope(end_point):
        with tf.variable_scope('Branch_0'):
          branch_0 = slim.conv2d(net, depth(352), [1, 1], scope='Conv2d_0a_1x1')
        with tf.variable_scope('Branch_1'):
          branch_1 = slim.conv2d(net, depth(192), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_1 = slim.conv2d(branch_1, depth(320), [3, 3], scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
          branch_2 = slim.conv2d(net, depth(192), [1, 1],
                                 weights_initializer=trunc_normal(0.09),
                                 scope='Conv2d_0a_1x1')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3], scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
          branch_3 = slim.conv2d(branch_3, depth(128), [1, 1],
                                 weights_initializer=trunc_normal(0.1),
                                 scope='Conv2d_0b_1x1')
        net = tf.concat(axis=concat_dim, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if end_point == final_endpoint: return net, end_points
    raise ValueError('Unknown final endpoint %s' % final_endpoint)


def inception_v2(inputs,
                 num_classes=1000,
                 is_training=True,
                 dropout_keep_prob=0.8,
                 min_depth=16,
                 depth_multiplier=1.0,
                 prediction_fn=slim.softmax,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='InceptionV2',
                 global_pool=False):
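  """Inception v2 model for classification.

  The default image size used to train this network is 224x224.

  Args:
    inputs: a Tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes. If 0 or None, the logits layer
      is omitted and the input features to the logits layer (before dropout)
      are returned instead.
    is_training: whether the network is being trained.
    dropout_keep_prob: the fraction of activation values that are retained.
    min_depth: minimum depth value (number of channels) for all convolution
      ops. Enforced when depth_multiplier < 1, and not an active constraint
      when depth_multiplier >= 1.
    depth_multiplier: float multiplier for the depth (number of channels) of
      all convolution ops. Must be greater than zero. Typically set to a value
      in (0, 1) to reduce the number of parameters or the computation cost of
      the model.
    prediction_fn: a function to get predictions out of logits, e.g. softmax.
    spatial_squeeze: if True, logits is of shape [B, C]; if False, logits is
      of shape [B, 1, 1, C], where B is batch_size and C is the number of
      classes.
    reuse: whether or not the network and its variables should be reused. To
      be able to reuse, 'scope' must be given.
    scope: optional variable_scope.
    global_pool: optional boolean flag controlling the average pooling before
      the logits layer. If False (the default), pooling uses a fixed window
      that reduces default-sized inputs to 1x1, while larger inputs lead to
      larger outputs. If True, any input size is pooled down to 1x1.

  Returns:
    net: a Tensor with the logits (pre-softmax activations) if num_classes is
      a non-zero integer, or the non-dropped-out input to the logits layer if
      num_classes is 0 or None.
    end_points: a dictionary from components of the network to the
      corresponding activation.

  Raises:
    ValueError: if depth_multiplier <= 0.
  """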
  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')

  # Final pooling and prediction
  with tf.variable_scope(scope, 'InceptionV2', [inputs], reuse=reuse) as scope:
    with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
      net, end_points = inception_v2_base(
          inputs, scope=scope, min_depth=min_depth,
          depth_multiplier=depth_multiplier)
      with tf.variable_scope('Logits'):
        if global_pool:
          # Global average pooling.
          net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
          end_points['global_pool'] = net
        else:
          # Pooling with a fixed kernel size.
          kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
          net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                                scope='AvgPool_1a_{}x{}'.format(*kernel_size))
          end_points['AvgPool_1a'] = net
        if not num_classes:
          return net, end_points
        # 1 x 1 x 1024
        net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
        logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='Conv2d_1c_1x1')
        if spatial_squeeze:
          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
      end_points['Logits'] = logits
      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points

inception_v2.default_image_size = 224


def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
  """Define kernel size which is automatically reduced for small input.

  If the shape of the input images is unknown at graph construction time, this
  function assumes that the input images are large enough.

  Args:
    input_tensor: input tensor of size [batch_size, height, width, channels].
    kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]

  Returns:
    a two-element list [kernel_height, kernel_width], each clipped to the
    input's height and width when they are known.

  TODO(jrru): Make this function work with unknown shapes. Theoretically, this
  can be done with the code below. Problems are two-fold: (1) If the shape was
  known, it will be lost. (2) inception.slim.ops._two_element_tuple cannot
  handle tensors that define the kernel size.
      shape = tf.shape(input_tensor)
      return tf.stack([tf.minimum(shape[1], kernel_size[0]),
                       tf.minimum(shape[2], kernel_size[1])])

  """
  shape = input_tensor.get_shape().as_list()
  if shape[1] is None or shape[2] is None:
    kernel_size_out = kernel_size
  else:
    kernel_size_out = [min(shape[1], kernel_size[0]),
                       min(shape[2], kernel_size[1])]
  return kernel_size_out

inception_v2_arg_scope = inception_utils.inception_arg_scope
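The channel counts in comments such as "# 28 x 28 x 256" are just the sums of the branch depths, because tf.concat stacks the branches along the channel axis. The reduction blocks Mixed_4a and Mixed_5a also pass their input through a max-pool branch, so the input depth carries over into their output. The pure-Python sketch below (no TensorFlow needed) re-derives those numbers and shows how depth_multiplier and min_depth thin every layer through the same depth() rule used in inception_v2_base. The names MIXED, make_depth, output_channels and first_layer_depthwise_multiplier are hypothetical helpers for illustration, not part of Slim.

```python
# Branch output depths of every Inception block, copied from the Slim code above.
# 'pool' marks the max-pool branch of a reduction block, which keeps its input depth.
MIXED = [
    ('Mixed_3b', [64, 64, 96, 32]),
    ('Mixed_3c', [64, 96, 96, 64]),
    ('Mixed_4a', [160, 96, 'pool']),
    ('Mixed_4b', [224, 96, 128, 128]),
    ('Mixed_4c', [192, 128, 128, 128]),
    ('Mixed_4d', [160, 160, 160, 96]),
    ('Mixed_4e', [96, 192, 192, 96]),
    ('Mixed_5a', [192, 256, 'pool']),
    ('Mixed_5b', [352, 320, 224, 128]),
    ('Mixed_5c', [352, 320, 224, 128]),
]


def make_depth(depth_multiplier=1.0, min_depth=16):
    """Same rule as the depth() lambda in inception_v2_base."""
    if depth_multiplier <= 0:
        raise ValueError('depth_multiplier is not greater than zero.')
    return lambda d: max(int(d * depth_multiplier), min_depth)


def output_channels(depth_multiplier=1.0, min_depth=16):
    """Channel count after each Mixed_* concatenation."""
    depth = make_depth(depth_multiplier, min_depth)
    channels = depth(192)  # Conv2d_2c_3x3 output; MaxPool_3a_3x3 keeps the depth
    result = {}
    for name, branches in MIXED:
        # Concatenating along the channel axis adds up the branch depths.
        channels = sum(channels if b == 'pool' else depth(b) for b in branches)
        result[name] = channels
    return result


def first_layer_depthwise_multiplier(depth_multiplier=1.0, min_depth=16):
    """depthwise_multiplier picked for Conv2d_1a_7x7 when use_separable_conv=True."""
    depth = make_depth(depth_multiplier, min_depth)
    return min(int(depth(64) / 3), 8)


full = output_channels()
print(full['Mixed_3b'], full['Mixed_3c'], full['Mixed_4a'], full['Mixed_5c'])
# 256 320 576 1024
half = output_channels(depth_multiplier=0.5)
print(half['Mixed_5c'])  # 512
print(first_layer_depthwise_multiplier(), first_layer_depthwise_multiplier(0.25))  # 8 5
```

With depth_multiplier=0.5 every layer stays above min_depth, so all widths are exactly halved; at smaller multipliers min_depth starts to clamp the thinnest branches, and the first-layer depthwise multiplier drops below 8 once depth(64) falls under 24.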

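The global_pool docstring says that, with the fixed-window pool, larger inputs lead to larger outputs. Tracing the spatial size through the five stride-2 layers of inception_v2_base makes this concrete: every other layer has stride 1, and all of them use padding='SAME', so each stride-2 layer maps n to ceil(n / 2). The sketch is pure Python with hypothetical function names. It also assumes that AvgPool_1a, which is created outside the arg_scope that sets stride=1, falls back to slim.avg_pool2d's default stride, taken here to be 2; treat that part as an assumption to verify against your Slim version.

```python
# Layers of inception_v2_base that change the spatial size. All run with
# padding='SAME' (set by the arg_scope), so output = ceil(input / stride).
STRIDED_LAYERS = [
    ('Conv2d_1a_7x7', 2),
    ('MaxPool_2a_3x3', 2),
    ('MaxPool_3a_3x3', 2),
    ('Mixed_4a', 2),
    ('Mixed_5a', 2),
]


def final_feature_size(input_size):
    """Spatial size of the Mixed_5c output for a square input image."""
    size = input_size
    for _, stride in STRIDED_LAYERS:
        size = (size + stride - 1) // stride  # integer ceil for SAME padding
    return size


def reduced_kernel_size(feature_size, kernel=7):
    """Same rule as _reduced_kernel_size_for_small_input for a known square shape."""
    return min(feature_size, kernel)


def logits_grid_size(input_size, global_pool=False, pool_stride=2):
    """Spatial size fed to the 1x1 logits convolution.

    Assumption: the fixed-window AvgPool_1a uses VALID padding and a stride of
    pool_stride (slim.avg_pool2d's default, assumed to be 2).
    """
    size = final_feature_size(input_size)
    if global_pool:
        return 1  # tf.reduce_mean over height and width
    k = reduced_kernel_size(size)
    return (size - k) // pool_stride + 1


for s in (32, 224, 299):
    print(s, final_feature_size(s), logits_grid_size(s), logits_grid_size(s, True))
# 32 1 1 1
# 224 7 1 1
# 299 10 2 1
```

For 224x224 and small inputs the reduced kernel covers the whole map, so the logits are 1x1. For a 299x299 input the map is 10x10 and the fixed 7x7 window leaves a 2x2 grid under these assumptions; since tf.squeeze(logits, [1, 2]) requires those axes to have size 1, this is the case global_pool=True is meant for.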