resnet-tf

import tensorflow as tf
def resnet(inputs, num_classes=21, n_blocks=(3, 4, 23, 3), filter_list=(64, 128, 256, 512),
           training=False):
    """Build a bottleneck-ResNet classifier graph with the TF1 ``tf.layers`` API.

    The default ``n_blocks`` of (3, 4, 23, 3) is the ResNet-101 stage layout.

    Args:
        inputs: Image batch tensor. Assumed NHWC (the ``__main__`` demo uses
            ``[None, 224, 224, 3]``) — TODO confirm for other callers.
        num_classes: Number of output logits.
        n_blocks: Number of bottleneck blocks in each stage.
        filter_list: Bottleneck width of each stage; every block in a stage
            outputs ``4 * filters`` channels. Indexed by stage, so it must be
            at least as long as ``n_blocks``.
        training: Python bool or boolean tensor forwarded to every
            ``tf.layers.batch_normalization`` call. When true, batch norm
            normalizes with batch statistics and registers moving-average
            updates in ``tf.GraphKeys.UPDATE_OPS``; the caller must run those
            ops together with the train op (e.g. via
            ``tf.control_dependencies``). Defaults to False, i.e. batch norm
            uses its moving statistics.

    Returns:
        Tuple ``(logits, predict, end_points)``:
        ``logits`` is the ``[batch, num_classes]`` dense output,
        ``predict`` is ``argmax(logits, axis=1)``, and
        ``end_points`` maps the 0-based stage index to that stage's output.
    """

    def _conv_bn_relu(x, filters, k, s, padding='same', relu=True):
        """Apply a k x k conv with stride s, then batch norm, then optional ReLU."""
        x = tf.layers.conv2d(x, filters, k, s, padding)
        # ``training`` must be forwarded: left at its default (False), BN stays
        # in inference mode and never uses or learns real batch statistics.
        x = tf.layers.batch_normalization(x, training=training)
        return tf.nn.relu(x) if relu else x

    def res_block(x, filters, strides, is_first=False):
        """Bottleneck block: 1x1 -> 3x3 -> 1x1 (4 * filters) plus a shortcut.

        The first block of a stage uses a strided 1x1 projection shortcut so it
        matches the main path's channel count and resolution; later blocks use
        the identity shortcut (``strides`` is then unused).
        """
        if is_first:
            shortcut = _conv_bn_relu(x, filters * 4, 1, strides, relu=False)
            net = _conv_bn_relu(x, filters, 1, strides)
        else:
            shortcut = x
            net = _conv_bn_relu(x, filters, 1, 1)
        net = _conv_bn_relu(net, filters, 3, 1)
        # No ReLU before the addition; it is applied after the residual sum.
        net = _conv_bn_relu(net, filters * 4, 1, 1, relu=False)
        return tf.nn.relu(shortcut + net)

    # Stem: 7x7/2 conv -> BN -> ReLU, then 3x3/2 max pool.
    x = _conv_bn_relu(inputs, 64, 7, 2)
    x = tf.layers.max_pooling2d(x, 3, 2, 'same')

    end_points = {}
    for i, n in enumerate(n_blocks):
        # Stage 0 keeps the stem's resolution; each later stage halves it in
        # its first block.
        strides = 1 if i == 0 else 2
        filters = filter_list[i]
        x = res_block(x, filters, strides, is_first=True)
        for _ in range(n - 1):
            x = res_block(x, filters, 1)
        end_points[i] = x

    # Global average pool over the spatial axes (1, 2 under the default
    # channels_last layout). Unlike a fixed 7x7 pool, this covers the whole
    # feature map for any input resolution and keeps the channel dim static.
    x = tf.reduce_mean(x, axis=[1, 2])
    logits = tf.layers.dense(x, num_classes)
    predict = tf.argmax(logits, axis=1)
    return logits, predict, end_points


class ResNet:
    """Class-based bottleneck-ResNet builder using the TF1 ``tf.layers`` API.

    Calling an instance builds the graph, stores the outputs on
    ``self.logits`` and ``self.predict``, and returns them.
    """

    def __init__(self, num_classes, n_blocks=(3, 4, 23, 3), filter_list=(64, 128, 256, 512)):
        """Store the network configuration.

        Args:
            num_classes: Number of output logits.
            n_blocks: Number of bottleneck blocks per stage; the default is the
                ResNet-101 stage layout.
            filter_list: Bottleneck width per stage; each block outputs
                ``4 * filters`` channels. Indexed by stage, so it must be at
                least as long as ``n_blocks``.
        """
        self.num_classes = num_classes
        self.n_blocks = n_blocks
        self.filter_list = filter_list

    def _conv_bn_relu(self, inputs, filters, k, s, padding='same', relu=True, training=False):
        """Apply a k x k conv with stride s, then batch norm, then optional ReLU.

        ``training`` is forwarded to batch norm; left False, BN stays in
        inference mode and uses its moving statistics.
        """
        x = tf.layers.conv2d(inputs, filters, k, s, padding)
        x = tf.layers.batch_normalization(x, training=training)
        return tf.nn.relu(x) if relu else x

    def res_block(self, inputs, filters, strides, is_first=False, training=False):
        """Bottleneck block: 1x1 -> 3x3 -> 1x1 (4 * filters) plus a shortcut.

        The first block of a stage uses a strided 1x1 projection shortcut so it
        matches the main path's channel count and resolution; later blocks use
        the identity shortcut (``strides`` is then unused).
        """
        if is_first:
            shortcut = self._conv_bn_relu(inputs, filters * 4, 1, strides,
                                          relu=False, training=training)
            net = self._conv_bn_relu(inputs, filters, 1, strides, training=training)
        else:
            shortcut = inputs
            net = self._conv_bn_relu(inputs, filters, 1, 1, training=training)
        net = self._conv_bn_relu(net, filters, 3, 1, training=training)
        # No ReLU before the addition; it is applied after the residual sum.
        net = self._conv_bn_relu(net, filters * 4, 1, 1, relu=False, training=training)
        return tf.nn.relu(shortcut + net)

    def __call__(self, inputs, training=False):
        """Build the network on ``inputs`` and return ``(logits, predict)``.

        Args:
            inputs: Image batch tensor; assumed NHWC (the ``__main__`` demo uses
                ``[None, 224, 224, 3]``) — TODO confirm for other callers.
            training: Python bool or boolean tensor forwarded to every batch
                norm layer. When true, BN uses batch statistics and adds
                moving-average updates to ``tf.GraphKeys.UPDATE_OPS``, which the
                caller must run with the train op. Defaults to False.

        Returns:
            ``(logits, predict)``: the ``[batch, num_classes]`` dense output
            and ``argmax(logits, axis=1)``. Both are also stored on ``self``.
        """
        # Stem: 7x7/2 conv -> BN -> ReLU, then 3x3/2 max pool.
        x = self._conv_bn_relu(inputs, 64, 7, 2, training=training)
        x = tf.layers.max_pooling2d(x, 3, 2, 'same')
        for i, n in enumerate(self.n_blocks):
            # Stage 0 keeps the stem's resolution; later stages halve it.
            strides = 1 if i == 0 else 2
            filters = self.filter_list[i]
            x = self.res_block(x, filters, strides, is_first=True, training=training)
            for _ in range(n - 1):
                x = self.res_block(x, filters, 1, training=training)

        # Global average pool over the spatial axes: works for any input
        # resolution, unlike a fixed 7x7 pool that only fits 224x224 inputs.
        x = tf.reduce_mean(x, axis=[1, 2])
        self.logits = tf.layers.dense(x, self.num_classes)
        self.predict = tf.argmax(self.logits, axis=1)
        return self.logits, self.predict


if __name__ == '__main__':
    # Smoke check: build the functional graph for a batch of 224x224, 3-channel
    # images and print the (logits, predict, end_points) tuple it returns.
    # The class-based builder can be tried instead with: ResNet(21)(images)
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    outputs = resnet(images)
    print(outputs)

你可能感兴趣的:(tf)