深度学习笔记——分类模型(六)ResNeXt50

# -*- coding: utf-8 -*-
import datetime
import time

from keras import backend as K
from keras.layers import *
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator


def GroupConv2D(input, filters, kernel_size, strides=(1, 1), padding='valid', cardinality=32, name=""):
    """Apply a grouped 2-D convolution built from per-group ``Conv2D`` layers.

    The input's channels are split into ``cardinality`` equal slices. Each
    slice is convolved by its own ``Conv2D`` with ``filters // cardinality``
    output channels, and the per-group outputs are concatenated back along
    the channel axis.

    Args:
        input: Keras tensor to convolve.
        filters: Total number of output channels; must be divisible by
            ``cardinality``.
        kernel_size: Kernel size forwarded to every per-group ``Conv2D``.
        strides: Strides forwarded to every per-group ``Conv2D``.
        padding: Padding mode forwarded to every per-group ``Conv2D``.
        cardinality: Number of groups.
        name: Prefix for the per-group layer names
            (``<name>_GroupConv2D_<i>``, with i starting at 1).

    Returns:
        Keras tensor with ``filters`` channels.

    Raises:
        ValueError: If ``cardinality`` is not positive, or if ``filters`` or
            the input's channel count (when statically known) is not
            divisible by ``cardinality``.
    """
    channels_last = K.image_data_format() == 'channels_last'
    channel_axis = -1 if channels_last else 1

    if cardinality <= 0:
        raise ValueError("cardinality must be positive, got %r" % (cardinality,))
    if filters % cardinality != 0:
        raise ValueError("filters (%d) must be divisible by cardinality (%d)"
                         % (filters, cardinality))
    out_per_group = filters // cardinality

    # Slice the input by its own channel count. Fall back to the output width
    # when the static shape is unknown.
    in_channels = K.int_shape(input)[channel_axis]
    if in_channels is None:
        in_channels = filters
    if in_channels % cardinality != 0:
        raise ValueError("input channels (%d) must be divisible by cardinality (%d)"
                         % (in_channels, cardinality))
    in_per_group = in_channels // cardinality

    group_list = []
    for c in range(cardinality):
        start = c * in_per_group
        stop = start + in_per_group
        # Bind the slice bounds as default arguments. A plain closure over `c`
        # would read the loop's final value whenever Keras re-executes the
        # Lambda (e.g. when the functional model is called again), making
        # every group see the last slice.
        if channels_last:
            slicer = lambda z, s=start, e=stop: z[:, :, :, s:e]
        else:
            slicer = lambda z, s=start, e=stop: z[:, s:e, :, :]
        x = Lambda(slicer)(input)

        x = Conv2D(out_per_group, kernel_size, strides=strides, padding=padding,
                   name=name + "_GroupConv2D_" + str(c + 1))(x)

        group_list.append(x)

    # Concatenate along the same axis the input was split on.
    return concatenate(group_list, axis=channel_axis)


def identity_block(input_tensor, kernel_size, filters, stage, block):
    """ResNeXt bottleneck block whose shortcut is the unchanged input.

    Branch: 1x1 conv -> BN -> ReLU -> grouped ``kernel_size`` conv ('same'
    padding, via ``GroupConv2D`` with its default cardinality of 32) -> BN ->
    ReLU -> 1x1 conv -> BN. The branch output is added to ``input_tensor``
    and passed through a final ReLU.

    Args:
        input_tensor: Keras tensor. It is added to the branch output
            unchanged, so it must already have ``filters[2]`` channels.
        kernel_size: Kernel size of the middle grouped convolution.
        filters: Three ints ``(filters1, filters2, filters3)``: the widths of
            the first 1x1 conv, the grouped conv, and the last 1x1 conv.
        stage: Stage number, used only to build layer names.
        block: Block label (e.g. 'b'), used only to build layer names.

    Returns:
        Keras tensor with the same shape as ``input_tensor``.
    """
    filters1, filters2, filters3 = filters

    # Follow the backend's data format instead of hard-coding axis 3, matching
    # the bn_axis choice made in ResNeXt50.
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = GroupConv2D(x, filters2, kernel_size, padding='same', name=conv_name_base + '2b')
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """ResNeXt bottleneck block with a projection (1x1 conv + BN) shortcut.

    Branch: 1x1 conv (with ``strides``) -> BN -> ReLU -> grouped
    ``kernel_size`` conv ('same' padding, via ``GroupConv2D`` with its default
    cardinality of 32) -> BN -> ReLU -> 1x1 conv -> BN. The shortcut is a
    strided 1x1 conv + BN projecting ``input_tensor`` to ``filters[2]``
    channels. The two are added and passed through a final ReLU.

    Args:
        input_tensor: Keras tensor.
        kernel_size: Kernel size of the middle grouped convolution.
        filters: Three ints ``(filters1, filters2, filters3)``: the widths of
            the first 1x1 conv, the grouped conv, and the last 1x1 conv (and
            of the shortcut projection).
        stage: Stage number, used only to build layer names.
        block: Block label (e.g. 'a'), used only to build layer names.
        strides: Strides of the first 1x1 conv and of the shortcut
            projection; ``(2, 2)`` halves the spatial size.

    Returns:
        Keras tensor with ``filters[2]`` channels.
    """
    filters1, filters2, filters3 = filters

    # Follow the backend's data format instead of hard-coding axis 3, matching
    # the bn_axis choice made in ResNeXt50.
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = GroupConv2D(x, filters2, kernel_size, padding='same', name=conv_name_base + '2b')
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # Project the shortcut so its channel count and spatial size match the
    # branch output before the addition.
    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = add([x, shortcut])
    x = Activation('relu')(x)
    return x


def ResNeXt50(input_shape=(224, 224, 3), classes=1000):
    """Build a ResNeXt-50 classification model (stages of 3, 4, 6 and 3 blocks).

    Args:
        input_shape: Shape of one input image, without the batch dimension.
            The default is channels-last; pass a channels-first shape when
            the backend's image data format is 'channels_first'.
        classes: Number of output classes. ``2`` yields a single sigmoid unit
            (binary classification); any other value yields a softmax over
            ``classes`` units.

    Returns:
        An uncompiled keras ``Model`` named 'resnext50'.
    """
    img_input = Input(shape=input_shape)

    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1

    # Stem: 7x7/2 conv, BN, ReLU, 3x3/2 max-pool.
    x = Conv2D(
        64, (7, 7), strides=(2, 2), padding='same', name='conv1')(img_input)
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # (stage number, block filters, number of blocks, strides of block 'a').
    stage_specs = (
        (2, [128, 128, 256], 3, (1, 1)),
        (3, [256, 256, 512], 4, (2, 2)),
        (4, [512, 512, 1024], 6, (2, 2)),
        (5, [1024, 1024, 2048], 3, (2, 2)),
    )
    for stage, stage_filters, num_blocks, first_strides in stage_specs:
        # Block 'a' projects the shortcut (and downsamples from stage 3 on);
        # the remaining blocks are labelled 'b', 'c', ... so layer names are
        # unchanged.
        x = conv_block(x, 3, stage_filters, stage=stage, block='a',
                       strides=first_strides)
        for i in range(1, num_blocks):
            x = identity_block(x, 3, stage_filters, stage=stage,
                               block=chr(ord('a') + i))

    # Global pooling works for any input size. For a 224x224 input the final
    # feature map is 7x7, so this equals AveragePooling2D((7, 7)) + Flatten.
    x = GlobalAveragePooling2D(name='avg_pool')(x)

    if classes == 2:
        x = Dense(1, activation='sigmoid', name='fc1000')(x)
    else:
        x = Dense(classes, activation='softmax', name='fc1000')(x)

    return Model(img_input, x, name='resnext50')


def main(*, data_dir="datasets/train", width=224, height=224, batch_size=8,
         epochs=10, validation_split=0.2):
    """Train ResNeXt50 as a binary classifier on an image-folder dataset.

    ``data_dir`` is read with ``flow_from_directory`` and
    ``class_mode="binary"``, so it is expected to contain one sub-folder per
    class. Presumably there are exactly two, to match the single sigmoid
    output; verify against the dataset. A ``validation_split`` fraction of
    the images is held out for validation.

    Args:
        data_dir: Root folder of the class sub-folders.
        width: Image width the inputs are resized to.
        height: Image height the inputs are resized to.
        batch_size: Images per batch.
        epochs: Number of training epochs.
        validation_split: Fraction of the images held out for validation.

    Returns:
        The ``History`` object returned by ``fit_generator``.
    """
    # Keras expects target_size as (height, width).
    target_size = (height, width)

    # Only the training subset is augmented. The validation generator uses the
    # same validation_split, so flow_from_directory holds out the same files,
    # but they are not randomly flipped. Pixels are scaled from 0-255 to 0-1.
    train_datagen = ImageDataGenerator(rescale=1. / 255,
                                       horizontal_flip=True,
                                       vertical_flip=True,
                                       validation_split=validation_split)
    val_datagen = ImageDataGenerator(rescale=1. / 255,
                                     validation_split=validation_split)

    train_generator = train_datagen.flow_from_directory(directory=data_dir,
                                                        target_size=target_size,
                                                        batch_size=batch_size,
                                                        class_mode="binary",
                                                        subset="training")

    val_generator = val_datagen.flow_from_directory(directory=data_dir,
                                                    target_size=target_size,
                                                    batch_size=batch_size,
                                                    class_mode="binary",
                                                    subset="validation")

    # Channels-last RGB input, matching ResNeXt50's default layout.
    model = ResNeXt50(input_shape=(height, width, 3), classes=2)

    model.summary()
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model.fit_generator(train_generator, validation_data=val_generator,
                               epochs=epochs, verbose=1)


if __name__ == '__main__':
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way a datetime.now() difference can. The
    # result is wrapped in a timedelta to keep the original printed format.
    start = time.perf_counter()
    main()
    elapsed = datetime.timedelta(seconds=time.perf_counter() - start)
    print("\nThis model took ", elapsed)
 

你可能感兴趣的:(深度神经网络)