CNN模型搭建

model.py

from keras.layers import Activation, Conv2D, Convolution2D, Dropout, Dense, Flatten
from keras.layers.advanced_activations import PReLU
from keras.layers import AveragePooling2D, BatchNormalization
from keras.models import Sequential

def simple_CNN(input_shape, num_classes):
    """Build a small sequential CNN classifier.

    The network has three convolutional stages, each made of
    Conv2D -> PReLU -> BatchNormalization -> AveragePooling2D -> Dropout.
    These are followed by two fully connected PReLU layers with dropout
    and a softmax output layer.

    Layers use the Keras 2 argument names (``Conv2D`` with a
    ``kernel_size`` tuple and ``padding=``), not the Keras 1 form
    ``Convolution2D(n, k, k, border_mode=...)``. The old form only works
    through a legacy-argument shim. Without that shim, the third positional
    argument of a conv layer is ``strides``, not the second kernel dimension.

    Args:
        input_shape: Shape of a single input sample, given to the first
            convolution, e.g. ``(48, 48, 1)``.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = Sequential()

    # Stage 1: 16 filters with a 7x7 kernel. 'same' padding on the stride-1
    # convolution keeps the spatial size; the stride-2 pooling halves it.
    model.add(Conv2D(16, (7, 7), padding='same', input_shape=input_shape))
    # NOTE(review): a default PReLU learns one slope per activation element
    # (e.g. 48*48*16 weights here); PReLU(shared_axes=[1, 2]) would share
    # them spatially if that was the intent.
    model.add(PReLU())
    model.add(BatchNormalization())
    model.add(AveragePooling2D(pool_size=(5, 5), strides=(2, 2), padding='same'))
    model.add(Dropout(.5))

    # Stage 2: 32 filters with a 5x5 kernel.
    model.add(Conv2D(32, (5, 5), padding='same'))
    model.add(PReLU())
    model.add(BatchNormalization())
    model.add(AveragePooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'))
    model.add(Dropout(.5))

    # Stage 3: 32 filters with a 3x3 kernel.
    model.add(Conv2D(32, (3, 3), padding='same'))
    model.add(PReLU())
    model.add(BatchNormalization())
    model.add(AveragePooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'))
    model.add(Dropout(.5))

    # Classifier head: two fully connected layers, then softmax over classes.
    model.add(Flatten())
    # NOTE(review): 1028 units is kept unchanged; 1024 may have been intended.
    model.add(Dense(1028))
    model.add(PReLU())
    model.add(Dropout(0.5))
    model.add(Dense(1028))
    model.add(PReLU())
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))

    return model

if __name__ == "__main__":
    # Build the model for 48x48 single-channel inputs with 7 output classes
    # and print its layer summary. The shape is held in `input_shape` and
    # passed through, rather than defining one shape and using another.
    input_shape = (48, 48, 1)
    num_classes = 7

    model = simple_CNN(input_shape, num_classes)
    model.summary()

CNN网络结构（`model.summary()` 的输出，输入形状为 (48, 48, 1)，类别数为 7）

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 48, 48, 16)        800       
_________________________________________________________________
p_re_lu_1 (PReLU)            (None, 48, 48, 16)        36864     
_________________________________________________________________
batch_normalization_1 (Batch (None, 48, 48, 16)        64        
_________________________________________________________________
average_pooling2d_1 (Average (None, 24, 24, 16)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 24, 24, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 32)        12832     
_________________________________________________________________
p_re_lu_2 (PReLU)            (None, 24, 24, 32)        18432     
_________________________________________________________________
batch_normalization_2 (Batch (None, 24, 24, 32)        128       
_________________________________________________________________
average_pooling2d_2 (Average (None, 12, 12, 32)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 12, 12, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 12, 12, 32)        9248      
_________________________________________________________________
p_re_lu_3 (PReLU)            (None, 12, 12, 32)        4608      
_________________________________________________________________
batch_normalization_3 (Batch (None, 12, 12, 32)        128       
_________________________________________________________________
average_pooling2d_3 (Average (None, 6, 6, 32)          0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 6, 6, 32)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1152)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1028)              1185284   
_________________________________________________________________
p_re_lu_4 (PReLU)            (None, 1028)              1028      
_________________________________________________________________
dropout_4 (Dropout)          (None, 1028)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1028)              1057812   
_________________________________________________________________
p_re_lu_5 (PReLU)            (None, 1028)              1028      
_________________________________________________________________
dropout_5 (Dropout)          (None, 1028)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 7)                 7203      
_________________________________________________________________
activation_1 (Activation)    (None, 7)                 0         
=================================================================
Total params: 2,335,459
Trainable params: 2,335,299
Non-trainable params: 160
_________________________________________________________________

你可能感兴趣的:(ai)