TensorFlow: Reimplementing ResNeXt

Reimplemented from the PyTorch version; see https://github.com/prlz77/ResNeXt.pytorch

Implemented with TensorFlow's low-level API (tf.nn). TF 2.x is a bit of a trap here, since the high-level APIs changed substantially, so this code targets TF 1.x.

I have not run proper experiments with this network, nor tuned its hyperparameters, so I make no guarantees about its accuracy; it is only a practice exercise. If you want to use it, adapt it yourself or go to the reference implementation instead.

import tensorflow as tf
import numpy as np

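# Conv: plain or grouped 2D convolution (groups = cardinality, as in ResNeXt).
# The bias flag is kept for call-site symmetry, but no bias is added here,
# since every conv in this network is followed by batch normalization.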
def Conv(input, filters, kernel_size, stride, bias, weight_decay, name, cardinality=1, padding='SAME'):
    in_channel = input.shape[-1]

    filter_shape = [kernel_size, kernel_size, int(in_channel//cardinality), filters]
    l2_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
    # He/Kaiming-style initialization, as in the reference implementation; a plain
    # truncated normal with stddev=1.0 would blow activations up at this depth
    filter = tf.get_variable(initializer=tf.variance_scaling_initializer(), shape=filter_shape,
                             regularizer=l2_regularizer, name=name+'_weights')

    if cardinality != 1:
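        # grouped convolution: split the input and the filters into `cardinality`
        # groups along the channel axis, convolve each pair, then concatenate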
        convolve = lambda i, k: tf.nn.conv2d(i, k, strides=[1, stride, stride, 1], padding=padding)
        input_groups = tf.split(axis=3, num_or_size_splits=cardinality, value=input)
        weights_groups = tf.split(axis=3, num_or_size_splits=cardinality, value=filter)
        output_groups = [convolve(i, k) for i, k in zip(input_groups, weights_groups)]
        x = tf.concat(axis=3, values=output_groups)
        return x

    x = tf.nn.conv2d(input, filter, strides=[1, stride, stride, 1], padding=padding, name=name)
    return x

def Dense(input, out_units, weight_decay, name):
    in_units = int(input.shape[-1])
    l2_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
    # same He-style initialization as Conv; the bias starts at zero
    W = tf.get_variable(initializer=tf.variance_scaling_initializer(), shape=[in_units, out_units],
                        regularizer=l2_regularizer, name=name+'_weights')
    bias = tf.get_variable(shape=[out_units], dtype=tf.float32, initializer=tf.zeros_initializer(),
                           name=name+'_bias')
    return tf.matmul(input, W) + bias


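# Batch normalization built from tf.nn primitives: batch statistics are used
# during training, and their exponential moving averages at test time.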
def BN(input, name, is_training=True, moving_decay=0.9, eps=1e-5):
    batch_mean, batch_var = tf.nn.moments(input, axes=list(range(len(input.shape)-1)))

    gamma = tf.get_variable(name+'_gamma', input.shape[-1], initializer=tf.constant_initializer(1), dtype=tf.float32)
    beta = tf.get_variable(name+'_beta', input.shape[-1], initializer=tf.constant_initializer(0), dtype=tf.float32)

    # track the batch mean and variance with a moving average
    ema = tf.train.ExponentialMovingAverage(moving_decay)

    def mean_var_with_update():
        ema_apply_op = ema.apply([batch_mean, batch_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    # during training, update and use the batch statistics;
    # at test time, use the last saved moving averages
    mean, var = tf.cond(tf.equal(is_training, True), mean_var_with_update,
                        lambda: (ema.average(batch_mean), ema.average(batch_var)))

    # finally, apply batch normalization
    return tf.nn.batch_normalization(input, mean, var, beta, gamma, eps, name=name)

class ResNeXt():
    def __init__(self, args):
        self.learning_rate = args.learning_rate
        self.momentum = args.momentum
        self.cardinality = args.cardinality
        self.depth = args.depth
        self.base_width = args.base_width
        self.widen_factor = args.widen_factor
        self.weight_decay = args.weight_decay
        self.class_num = args.class_num
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]
        self.pool_stride=[1, 2, 2]
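        # total depth = 9 * block_depth + 2, so depth 29 gives 3 bottlenecks per block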
        self.block_depth = (self.depth - 2) // 9

    def _Bottleneck(self, x, in_channels, out_channels, stride, cardinality, base_width, widen_factor):

        _x = x

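        # D is the total width of the grouped conv:
        # `cardinality` groups of int(base_width * width_ratio) channels each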
        width_ratio = out_channels / (widen_factor * 64.)
        D = cardinality * int(base_width * width_ratio)

        x = Conv(x, D, kernel_size=1, stride=1, bias=False, weight_decay=self.weight_decay,
                 padding='VALID', name='ord_conv1')
        x = BN(x, is_training=self.trainable, name='bottleneck_bn1')
        x = tf.nn.relu(x, name='bottleneck_relu1')

        #group conv
        x = Conv(x, D, kernel_size=3, stride=stride, bias=False, weight_decay=self.weight_decay,
                 cardinality=cardinality, padding='SAME', name='group_conv')
        x = BN(x, is_training=self.trainable, name='bottleneck_bn2')
        x = tf.nn.relu(x, name='bottleneck_relu2')

        x = Conv(x, out_channels, kernel_size=1, stride=1, bias=False, weight_decay=self.weight_decay,
                 padding='VALID', name='ord_conv2')
        # per the reference implementation, the expanding 1x1 conv is followed by BN;
        # the ReLU comes only after the residual addition below
        x = BN(x, is_training=self.trainable, name='bottleneck_bn3')

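        # project the shortcut with a 1x1 conv when the channel count changes;
        # in this network the stride is 2 only when the channels change as well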
        if in_channels != out_channels:
            _x = Conv(_x, out_channels, kernel_size=1, stride=stride, bias=False,
                             weight_decay=self.weight_decay, padding='VALID', name='short_conv')
            _x = BN(_x, is_training=self.trainable, name='short_bn')

        x = x + _x
        x = tf.nn.relu(x, name='bottleneck_relu3')

        return x

    # each block stacks block_depth bottlenecks; this network has 3 blocks
    # of 3 bottlenecks each, and only the first bottleneck may downsample
    def _block(self, x, in_channels, out_channels, pool_stride=2):
        for j in range(self.block_depth):
            with tf.variable_scope('Bottleneck_'+str(j+1)):
                if j == 0:
                    x = self._Bottleneck(x, in_channels, out_channels, pool_stride, self.cardinality, self.base_width, self.widen_factor)
                else:
                    x = self._Bottleneck(x, out_channels, out_channels, 1, self.cardinality, self.base_width, self.widen_factor)
        return x

    def build_graph(self):

        self.image = tf.placeholder(name='input', shape=[None, 29, 29, 3], dtype=tf.float32)
        self.label = tf.placeholder(name='label', shape=[None, 10], dtype=tf.float32)
        self.trainable = tf.placeholder(name='trainable', shape=None, dtype=tf.bool)
        self.epoch = tf.placeholder(name='schedule', shape=None, dtype=tf.float32)
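        # `trainable` selects the BN branch (batch stats vs. moving averages);
        # `epoch` drives the learning-rate decay below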

        with tf.variable_scope('Header'):
            x = Conv(self.image, 64, 3, 1, False, self.weight_decay, padding='SAME', name='conv1')
            x = BN(x, is_training=self.trainable, name='bn1')
            x = tf.nn.relu(x, name='relu1')

        for i in range(3):
            with tf.variable_scope('Block_'+str(i+1)):
                x = self._block(x, self.stages[i], self.stages[i+1], self.pool_stride[i])

        with tf.variable_scope('Top'):
            x = tf.nn.avg_pool(x, ksize=[1, int(x.shape[1]), int(x.shape[2]), 1], strides=[1, 1, 1, 1], padding='VALID')
            x = tf.squeeze(x, axis=[1,2], name='squeeze')
            self.logits = Dense(x, self.class_num, self.weight_decay, name='fc')
            self.predictions = tf.nn.softmax(self.logits)

        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=self.label))
        # add the L2 terms collected by the per-variable regularizers; without this,
        # the weight_decay argument would have no effect
        self.loss = cross_entropy + tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        self.pre = tf.equal(tf.argmax(self.predictions,1), tf.argmax(self.label,1))
        self.acc = tf.reduce_mean(tf.cast(self.pre, tf.float32))

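        # decay schedule: the lr is multiplied by sqrt(0.8) ~= 0.894 every epoch,
        # i.e. it roughly halves every 6 epochs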
        lr = self.learning_rate*(tf.sqrt(0.8)**self.epoch)
        self.opt = tf.train.MomentumOptimizer(lr, momentum=self.momentum, use_nesterov=True).minimize(self.loss)

        print('Graph built.')


if __name__=='__main__':
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--mode', default='train', type=str)
    parser.add_argument('--epoch', default=300, type=int)
    parser.add_argument('--class_num', default=10, type=int)
    parser.add_argument('--batch_size', default=4, type=int)
    # optimizer hyperparameters
    parser.add_argument('--learning_rate', default=1e-2, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--weight_decay', default=5e-4, type=float)
    #architecture
    parser.add_argument('--cardinality', default=8, type=int)
    parser.add_argument('--depth', default=29, type=int)
    parser.add_argument('--base_width', default=32, type=int)  # the paper's ResNeXt-29 (8x64d) uses 64
    parser.add_argument('--widen_factor', default=4, type=int)
    #data
    parser.add_argument('--train_data', default='./data/train.tfrecord', type=str)
    parser.add_argument('--test_data', default='./data/test.tfrecord', type=str)

    args = parser.parse_args()

    model = ResNeXt(args)
    model.build_graph()
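
    # --- hypothetical smoke test, not part of the original training pipeline ---
    # TF 1.x is assumed, and random arrays stand in for the tfrecord data above.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        images = np.random.rand(args.batch_size, 29, 29, 3).astype(np.float32)
        labels = np.eye(args.class_num)[np.random.randint(0, args.class_num, args.batch_size)].astype(np.float32)
        _, loss, acc = sess.run([model.opt, model.loss, model.acc],
                                feed_dict={model.image: images, model.label: labels,
                                           model.trainable: True, model.epoch: 0.})
        print('smoke-test loss: %.4f, acc: %.4f' % (loss, acc))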




