Reproduced from the PyTorch version; see https://github.com/prlz77/ResNeXt.pytorch.
Implemented with TensorFlow's low-level API (tf.nn). TF 2.x is a bit of a trap here, since the high-level APIs changed substantially between versions.
I have not run proper experiments with this network or tuned its hyperparameters, so I make no guarantees about its accuracy; it is only a practice exercise. If you want to use it, modify it yourself or look elsewhere.
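The code below is written against the TF 1.x graph API (tf.placeholder, tf.get_variable, tf.contrib, sessions). If you only have TF 2.x installed, one workaround is the v1 compatibility layer sketched below; note that tf.contrib no longer exists under TF 2.x, so the tf.contrib.layers.l2_regularizer calls would still need a substitute such as tf.keras.regularizers.l2.

import tensorflow.compat.v1 as tf  # stands in for `import tensorflow as tf` under TF 2.x
tf.disable_v2_behavior()           # restores graph mode, placeholders, get_variable, etc.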
import tensorflow as tf
import numpy as np
def Conv(input, filters, kernel_size, stride, bias, weight_decay, name, cardinality=1, padding='SAME'):
    # `bias` is kept for parity with the PyTorch reference but is unused: every conv here is bias-free.
    in_channel = input.shape[-1]
    filter_shape = [kernel_size, kernel_size, int(in_channel // cardinality), filters]
    l2_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
    filter = tf.get_variable(initializer=tf.truncated_normal(filter_shape, mean=0.0, stddev=1.0, dtype=tf.float32),
                             regularizer=l2_regularizer, name=name + '_weights')
    if cardinality != 1:
        # Grouped convolution: split the input and the kernel into `cardinality` groups
        # along the channel axis, convolve each pair independently, then concatenate.
        convolve = lambda i, k: tf.nn.conv2d(i, k, strides=[1, stride, stride, 1], padding=padding)
        input_groups = tf.split(axis=3, num_or_size_splits=cardinality, value=input)
        weights_groups = tf.split(axis=3, num_or_size_splits=cardinality, value=filter)
        output_groups = [convolve(i, k) for i, k in zip(input_groups, weights_groups)]
        x = tf.concat(axis=3, values=output_groups)
        return x
    x = tf.nn.conv2d(input, filter, strides=[1, stride, stride, 1], padding=padding, name=name)
    return x
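Grouped convolution is what makes ResNeXt cheap: with cardinality=8 and 64 input channels, each group convolves 8 channels into 128/8 = 16 outputs, so the 3x3 kernel holds 3*3*8*128 weights instead of the 3*3*64*128 of a dense convolution. A throwaway shape check (my own snippet, not part of the original script):

with tf.Graph().as_default():
    dummy = tf.placeholder(tf.float32, [None, 8, 8, 64])
    out = Conv(dummy, filters=128, kernel_size=3, stride=1, bias=False,
               weight_decay=5e-4, name='demo_group_conv', cardinality=8)
    print(out.shape)  # (?, 8, 8, 128): 8 groups, each mapping 8 in-channels to 16 out-channels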
def Dense(input, out_units, weight_decay, name):
    in_units = int(input.shape[-1])
    l2_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
    W = tf.get_variable(initializer=tf.truncated_normal([in_units, out_units], mean=0.0, stddev=1.0, dtype=tf.float32),
                        regularizer=l2_regularizer, name=name + '_weights')
    # Biases start at zero, the usual convention for a fully connected layer.
    bias = tf.get_variable(shape=[out_units], initializer=tf.zeros_initializer(), dtype=tf.float32, name=name + '_bias')
    return tf.matmul(input, W) + bias
def BN(input, name, is_training=True, moving_decay=0.9, eps=1e-5):
    # Per-channel statistics over the batch and spatial axes.
    batch_mean, batch_var = tf.nn.moments(input, axes=list(range(len(input.shape) - 1)))
    gamma = tf.get_variable(name + '_gamma', input.shape[-1], initializer=tf.constant_initializer(1), dtype=tf.float32)
    beta = tf.get_variable(name + '_beta', input.shape[-1], initializer=tf.constant_initializer(0), dtype=tf.float32)
    # Track the mean and variance with an exponential moving average.
    ema = tf.train.ExponentialMovingAverage(moving_decay)
    def mean_var_with_update():
        ema_apply_op = ema.apply([batch_mean, batch_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)
    # During training, update and use the batch statistics; at test time, fall back
    # to the moving averages saved by the most recent training step.
    mean, var = tf.cond(tf.equal(is_training, True), mean_var_with_update,
                        lambda: (ema.average(batch_mean), ema.average(batch_var)))
    # Finally apply batch normalization.
    return tf.nn.batch_normalization(input, mean, var, beta, gamma, eps, name=name)
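The is_training flag is resolved at run time through tf.cond, so the same graph serves both phases. A quick throwaway check of the two branches (my own snippet, arbitrary data):

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4])
    flag = tf.placeholder(tf.bool, shape=None)
    y = BN(x, name='demo_bn', is_training=flag)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        data = np.random.randn(16, 4).astype(np.float32)
        sess.run(y, feed_dict={x: data, flag: True})   # batch stats, moving averages updated
        sess.run(y, feed_dict={x: data, flag: False})  # reuses the saved moving averages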
class ResNeXt():
    def __init__(self, args):
        self.learning_rate = args.learning_rate
        self.momentum = args.momentum
        self.cardinality = args.cardinality
        self.depth = args.depth
        self.base_width = args.base_width
        self.widen_factor = args.widen_factor
        self.weight_decay = args.weight_decay
        self.class_num = args.class_num
        # Output channels of the stem conv and of the three stages.
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]
        self.pool_stride = [1, 2, 2]
        # Each bottleneck holds 3 convs across 3 blocks, plus the stem conv and the fc layer:
        # depth = 9 * block_depth + 2, so depth 29 gives block_depth 3.
        self.block_depth = (self.depth - 2) // 9
    def _Bottleneck(self, x, in_channels, out_channels, stride, cardinality, base_width, widen_factor):
        _x = x
        # Bottleneck width: D = cardinality * base_width, scaled by the stage's width ratio.
        width_ratio = out_channels / (widen_factor * 64.)
        D = cardinality * int(base_width * width_ratio)
        x = Conv(x, D, kernel_size=1, stride=1, bias=False, weight_decay=self.weight_decay,
                 padding='VALID', name='ord_conv1')
        x = BN(x, is_training=self.trainable, name='bottleneck_bn1')
        x = tf.nn.relu(x, name='bottleneck_relu1')
        # Grouped 3x3 convolution, the core of ResNeXt.
        x = Conv(x, D, kernel_size=3, stride=stride, bias=False, weight_decay=self.weight_decay,
                 cardinality=cardinality, padding='SAME', name='group_conv')
        x = BN(x, is_training=self.trainable, name='bottleneck_bn2')
        x = tf.nn.relu(x, name='bottleneck_relu2')
        x = Conv(x, out_channels, kernel_size=1, stride=1, bias=False, weight_decay=self.weight_decay,
                 padding='VALID', name='ord_conv2')
        # As in the PyTorch reference, the expansion conv is followed by BN;
        # ReLU comes only after the residual addition.
        x = BN(x, is_training=self.trainable, name='bottleneck_bn3')
        # Projection shortcut when the channel count changes.
        if in_channels != out_channels:
            _x = Conv(_x, out_channels, kernel_size=1, stride=stride, bias=False,
                      weight_decay=self.weight_decay, padding='VALID', name='short_conv')
            _x = BN(_x, is_training=self.trainable, name='short_bn')
        x = x + _x
        x = tf.nn.relu(x, name='bottleneck_relu3')
        return x
    # Each block stacks several bottlenecks; this network has 3 blocks with 3 bottlenecks each.
    def _block(self, x, in_channels, out_channels, pool_stride=2):
        for j in range(self.block_depth):
            with tf.variable_scope('Bottleneck_' + str(j + 1)):
                if j == 0:
                    # The first bottleneck changes the channel count and downsamples.
                    x = self._Bottleneck(x, in_channels, out_channels, pool_stride, self.cardinality, self.base_width, self.widen_factor)
                else:
                    x = self._Bottleneck(x, out_channels, out_channels, 1, self.cardinality, self.base_width, self.widen_factor)
        return x
    def build_graph(self):
        # Note: CIFAR-10 images are 32x32x3; adjust this shape to your data. The network is fully
        # convolutional up to the global average pool, so either size builds.
        self.image = tf.placeholder(name='input', shape=[None, 29, 29, 3], dtype=tf.float32)
        self.label = tf.placeholder(name='label', shape=[None, 10], dtype=tf.float32)
        self.trainable = tf.placeholder(name='trainable', shape=None, dtype=bool)
        self.epoch = tf.placeholder(name='schedule', shape=None, dtype=tf.float32)
        with tf.variable_scope('Header'):
            x = Conv(self.image, 64, 3, 1, False, self.weight_decay, padding='SAME', name='conv1')
            x = BN(x, is_training=self.trainable, name='bn1')
            x = tf.nn.relu(x, name='relu1')
        for i in range(3):
            with tf.variable_scope('Block_' + str(i + 1)):
                x = self._block(x, self.stages[i], self.stages[i + 1], self.pool_stride[i])
        with tf.variable_scope('Top'):
            # Global average pooling over whatever spatial size remains.
            x = tf.nn.avg_pool(x, ksize=[1, int(x.shape[1]), int(x.shape[2]), 1], strides=[1, 1, 1, 1], padding='VALID')
            x = tf.squeeze(x, axis=[1, 2], name='squeeze')
            self.logits = Dense(x, self.class_num, self.weight_decay, name='fc')
        self.predictions = tf.nn.softmax(self.logits)
        # Add the L2 terms collected by the regularizers; without this, weight_decay has no effect.
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=self.label)) \
                    + tf.losses.get_regularization_loss()
        self.pre = tf.equal(tf.argmax(self.predictions, 1), tf.argmax(self.label, 1))
        self.acc = tf.reduce_mean(tf.cast(self.pre, tf.float32))
        # Exponential decay: lr roughly halves every 6 epochs (sqrt(0.8)^6 ~= 0.51).
        lr = self.learning_rate * (tf.sqrt(0.8) ** self.epoch)
        self.opt = tf.train.MomentumOptimizer(lr, momentum=self.momentum, use_nesterov=True).minimize(self.loss)
        print('Having built graph!')
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', type=str)
    parser.add_argument('--epoch', default=300, type=int)
    parser.add_argument('--class_num', default=10, type=int)
    parser.add_argument('--batch_size', default=4, type=int)
    # training hyperparameters
    parser.add_argument('--learning_rate', default=1e-2, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--weight_decay', default=5e-4, type=float)
    # architecture (paper defaults for ResNeXt-29 (8x64d): cardinality 8, depth 29, base_width 64, widen_factor 4)
    parser.add_argument('--cardinality', default=8, type=int)
    parser.add_argument('--depth', default=29, type=int)
    parser.add_argument('--base_width', default=32, type=int)  # paper uses 64
    parser.add_argument('--widen_factor', default=4, type=int)
    # data
    parser.add_argument('--train_data', default='./data/train.tfrecord', type=str)
    parser.add_argument('--test_data', default='./data/test.tfrecord', type=str)
    args = parser.parse_args()
    model = ResNeXt(args)
    model.build_graph()
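For completeness, here is a minimal sketch of one training step once the graph is built. The random NumPy batch stands in for the TFRecord input pipeline, which this post does not include, so the shapes and feed values below are illustrative only:

    # Sketch: one training step with random data in place of the real pipeline.
    images = np.random.rand(args.batch_size, 29, 29, 3).astype(np.float32)
    labels = np.eye(args.class_num)[np.random.randint(0, args.class_num, args.batch_size)].astype(np.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _, loss, acc = sess.run([model.opt, model.loss, model.acc],
                                feed_dict={model.image: images, model.label: labels,
                                           model.trainable: True, model.epoch: 0.0})
        print('loss: %.4f, acc: %.4f' % (loss, acc))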