import tensorflow as tf
def resnet(inputs, num_classes=21, n_blocks=(3, 4, 23, 3),
           filter_list=(64, 128, 256, 512), is_training=False):
    """Build a bottleneck ResNet classifier in TF1 graph mode.

    The default ``n_blocks`` of (3, 4, 23, 3) bottleneck blocks is the
    ResNet-101 layout.

    Args:
        inputs: Image batch tensor. Layers use the default channels-last
            layout, so this is presumably (batch, H, W, C) -- confirm with
            callers.
        num_classes: Number of units in the final dense (logits) layer.
        n_blocks: Number of bottleneck blocks in each stage.
        filter_list: Bottleneck width of each stage; every block of stage
            ``i`` outputs ``4 * filter_list[i]`` channels. Must have at
            least ``len(n_blocks)`` entries.
        is_training: Forwarded as ``training=`` to every batch-normalization
            layer (Python bool or scalar bool tensor). When true, batch
            statistics are used and the moving averages are updated; the
            caller must then run ``tf.GraphKeys.UPDATE_OPS`` with the train
            op. The default False uses the moving statistics only.

    Returns:
        ``(logits, predict, end_points)``:
        - ``logits``: (batch, num_classes) dense-layer output.
        - ``predict``: arg-max class index per example.
        - ``end_points``: dict mapping stage index -> that stage's output.
    """

    def _conv_bn_relu(x, filters, k, s, padding='same', relu=True):
        """Conv2D -> batch norm -> optional ReLU."""
        x = tf.layers.conv2d(x, filters, k, s, padding)
        x = tf.layers.batch_normalization(x, training=is_training)
        return tf.nn.relu(x) if relu else x

    def res_block(x, filters, strides, is_first=False):
        """1x1 -> 3x3 -> 1x1 bottleneck block with a residual connection.

        The first block of a stage uses a 1x1 projection shortcut so the
        shortcut matches both the ``4 * filters`` output channels and the
        ``strides`` downsampling; later blocks use the identity shortcut.
        """
        if is_first:
            shortcut = _conv_bn_relu(x, filters * 4, 1, strides, relu=False)
            net = _conv_bn_relu(x, filters, 1, strides)
        else:
            shortcut = x
            net = _conv_bn_relu(x, filters, 1, 1)
        net = _conv_bn_relu(net, filters, 3, 1)
        net = _conv_bn_relu(net, filters * 4, 1, 1, relu=False)
        return tf.nn.relu(shortcut + net)

    # Stem: 7x7/2 conv + BN + ReLU, then 3x3/2 max-pool. Without BN/ReLU the
    # stem conv would be a purely linear layer.
    x = _conv_bn_relu(inputs, 64, 7, 2)
    x = tf.layers.max_pooling2d(x, 3, 2, 'same')

    end_points = {}
    for i, n in enumerate(n_blocks):
        # Stage 0 keeps the resolution (the max-pool already downsampled);
        # each later stage halves it in its first block.
        strides = 1 if i == 0 else 2
        filters = filter_list[i]
        x = res_block(x, filters, strides, is_first=True)
        for _ in range(1, n):
            x = res_block(x, filters, 1)
        end_points[i] = x

    # Global average pooling over H and W. For 224x224 inputs the final map
    # is 7x7 and this equals a 7x7/7 average pool, but unlike a fixed pool
    # it is correct for any input resolution.
    x = tf.reduce_mean(x, axis=[1, 2])
    logits = tf.layers.dense(x, num_classes)
    predict = tf.argmax(logits, axis=1)
    return logits, predict, end_points
class ResNet:
    """Callable bottleneck-ResNet classifier builder for TF1 graph mode.

    Calling an instance adds the network's layers to the default graph and
    returns ``(logits, predict)``. The built tensors are also kept on the
    instance as ``logits``, ``predict`` and ``end_points``.
    """

    def __init__(self, num_classes, n_blocks=(3, 4, 23, 3),
                 filter_list=(64, 128, 256, 512)):
        """Store the network configuration.

        Args:
            num_classes: Number of units in the final dense (logits) layer.
            n_blocks: Number of bottleneck blocks per stage; the default
                (3, 4, 23, 3) is the ResNet-101 layout.
            filter_list: Bottleneck width of each stage; blocks of stage
                ``i`` output ``4 * filter_list[i]`` channels.
        """
        self.num_classes = num_classes
        self.n_blocks = n_blocks
        self.filter_list = filter_list

    def _conv_bn_relu(self, inputs, filters, k, s, padding='same', relu=True,
                      is_training=False):
        """Conv2D -> batch norm -> optional ReLU.

        ``is_training`` is forwarded as ``training=`` to batch normalization
        so BN uses batch statistics and updates its moving averages while
        training.
        """
        x = tf.layers.conv2d(inputs, filters, k, s, padding)
        x = tf.layers.batch_normalization(x, training=is_training)
        return tf.nn.relu(x) if relu else x

    def res_block(self, inputs, filters, strides, is_first=False,
                  is_training=False):
        """1x1 -> 3x3 -> 1x1 bottleneck block with a residual connection.

        The first block of a stage uses a 1x1 projection shortcut so the
        shortcut matches both the ``4 * filters`` output channels and the
        ``strides`` downsampling; later blocks use the identity shortcut.
        """
        if is_first:
            shortcut = self._conv_bn_relu(inputs, filters * 4, 1, strides,
                                          relu=False, is_training=is_training)
            net = self._conv_bn_relu(inputs, filters, 1, strides,
                                     is_training=is_training)
        else:
            shortcut = inputs
            net = self._conv_bn_relu(inputs, filters, 1, 1,
                                     is_training=is_training)
        net = self._conv_bn_relu(net, filters, 3, 1, is_training=is_training)
        net = self._conv_bn_relu(net, filters * 4, 1, 1, relu=False,
                                 is_training=is_training)
        return tf.nn.relu(shortcut + net)

    def __call__(self, inputs, is_training=False):
        """Build the network on ``inputs`` and return ``(logits, predict)``.

        Args:
            inputs: Image batch tensor; presumably (batch, H, W, C) since the
                layers use the default channels-last layout -- confirm.
            is_training: Python bool or scalar bool tensor forwarded to every
                batch-norm layer. When true, run ``tf.GraphKeys.UPDATE_OPS``
                with the train op so the moving statistics are updated.

        Returns:
            ``(logits, predict)``: (batch, num_classes) logits and the arg-max
            class index per example. Per-stage outputs are stored in
            ``self.end_points`` keyed by stage index.
        """
        # Stem: 7x7/2 conv + BN + ReLU, then 3x3/2 max-pool.
        x = self._conv_bn_relu(inputs, 64, 7, 2, is_training=is_training)
        x = tf.layers.max_pooling2d(x, 3, 2, 'same')
        self.end_points = {}
        for i, n in enumerate(self.n_blocks):
            # Stage 0 keeps the resolution (the max-pool already
            # downsampled); each later stage halves it in its first block.
            strides = 1 if i == 0 else 2
            filters = self.filter_list[i]
            x = self.res_block(x, filters, strides, is_first=True,
                               is_training=is_training)
            for _ in range(1, n):
                x = self.res_block(x, filters, 1, is_training=is_training)
            self.end_points[i] = x
        # Global average pooling over H and W; equals a 7x7/7 pool for
        # 224x224 inputs but works for any input resolution.
        x = tf.reduce_mean(x, axis=[1, 2])
        self.logits = tf.layers.dense(x, self.num_classes)
        self.predict = tf.argmax(self.logits, axis=1)
        return self.logits, self.predict
if __name__ == '__main__':
    # Smoke test: build the graph on a 224x224, 3-channel placeholder and
    # print the returned (logits, predict, end_points) tensors.
    # The class-based builder is used the same way: ResNet(21)(images).
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    outputs = resnet(images)
    print(outputs)