使用 SqueezeNet 作为基础网络的 SSD（Keras 实现）

该网络的优缺点还有待实验验证。

# coding=utf-8
"""Keras implementation of SSD."""

import keras.backend as K
from keras.layers import Activation
from keras.layers import AtrousConvolution2D
from keras.layers import Conv2D
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import GlobalAveragePooling2D
from keras.layers import Input
from keras.layers import MaxPooling2D
from keras.layers import merge
from keras.layers import Reshape
from keras.layers import ZeroPadding2D
from keras.models import Model, Sequential

from ssd_layers import Normalize
from ssd_layers import PriorBox

from keras.models import Sequential
from keras.layers import Dense, Flatten, Dropout, Concatenate
from keras.layers.convolutional import Conv2D, MaxPooling2D
import numpy as np


def _fire_module(x, name, squeeze_filters, expand1_filters, expand3_filters):
    """Build one SqueezeNet Fire module on top of tensor ``x``.

    A 1x1 "squeeze" convolution feeds two parallel "expand" convolutions
    (1x1 and 3x3, both 'same' padded) whose outputs are concatenated along
    axis 3 (the channel axis, since everything here is channels_last).

    Layer names are ``<name>_squeeze``, ``<name>_expand1`` and
    ``<name>_expand2``, matching the original unrolled implementation so that
    weights can still be loaded by layer name.

    Returns the concatenated output tensor.
    """
    # A 1x1, stride-1 convolution gives the same output for 'same' and
    # 'valid' padding, so using 'same' uniformly does not change the model.
    squeeze = Conv2D(
        squeeze_filters, (1, 1), activation='relu',
        kernel_initializer='glorot_uniform', padding='same',
        name=name + '_squeeze', data_format='channels_last')(x)
    expand1 = Conv2D(
        expand1_filters, (1, 1), activation='relu',
        kernel_initializer='glorot_uniform', padding='same',
        name=name + '_expand1', data_format='channels_last')(squeeze)
    expand2 = Conv2D(
        expand3_filters, (3, 3), activation='relu',
        kernel_initializer='glorot_uniform', padding='same',
        name=name + '_expand2', data_format='channels_last')(squeeze)
    return Concatenate(axis=3)([expand1, expand2])


def _mbox_head(x, prefix, num_priors, nb_classes, img_size, min_size,
               max_size=None, aspect_ratios=None):
    """Attach SSD localisation / confidence / prior-box outputs to ``x``.

    Creates, in this order: a 3x3 conv with ``num_priors * 4`` filters
    (``<prefix>_mbox_loc``) and its Flatten; a 3x3 conv with
    ``num_priors * nb_classes`` filters (``<prefix>_mbox_conf``, suffixed
    with ``_<nb_classes>`` when nb_classes != 21) and its Flatten; and a
    ``PriorBox`` layer (``<prefix>_mbox_priorbox``).

    ``num_priors`` must equal the number of boxes PriorBox generates per
    feature-map location for the given ``aspect_ratios`` / ``max_size``.

    Returns ``(loc_flat, conf_flat, priorbox)``.
    """
    loc = Conv2D(
        num_priors * 4, (3, 3), name=prefix + '_mbox_loc', padding='same',
        data_format='channels_last')(x)
    loc_flat = Flatten()(loc)

    conf_name = prefix + '_mbox_conf'
    if nb_classes != 21:
        # Non-default class counts get distinct layer names (presumably so
        # 21-class weights loaded by name skip these layers).
        conf_name += '_{}'.format(nb_classes)
    conf = Conv2D(
        num_priors * nb_classes, (3, 3), name=conf_name, padding='same',
        data_format='channels_last')(x)
    conf_flat = Flatten()(conf)

    priorbox_kwargs = {
        'aspect_ratios': aspect_ratios,
        'variances': [0.1, 0.1, 0.2, 0.2],
        'name': prefix + '_mbox_priorbox',
    }
    # Only pass max_size when given, exactly as the original call sites did.
    if max_size is not None:
        priorbox_kwargs['max_size'] = max_size
    priorbox = PriorBox(img_size, min_size, **priorbox_kwargs)(x)
    return loc_flat, conf_flat, priorbox


def SqueezeNet(inputs, nb_classes=21):
    """Build an SSD detector on a SqueezeNet backbone (arXiv 1602.07360).

    Arguments:
    inputs -- input image shape as (rows, cols, channels); every layer uses
              data_format="channels_last".
    nb_classes -- number of classes (including background) predicted per box.

    Returns a keras ``Model`` whose single output concatenates, along the
    last axis, the box offsets (4 values), the softmax class scores
    (nb_classes values) and the ``PriorBox`` output for every default box.

    For a 300x300 input the detection feature maps are 38x38 (fire4),
    19x19 (fire8), 10x10 (fire9), 5x5 (fire10), 3x3 (conv12_2) and
    1x1 (conv13_2).
    """
    # NOTE(review): assumes ssd_layers.PriorBox reads img_size as
    # (width, height), as in rykov8/ssd_keras's SSD300. `inputs` is
    # (rows, cols, channels), so swap. The previous (rows, cols) order
    # was only correct for square inputs.
    img_size = (inputs[1], inputs[0])
    input_img = Input(shape=inputs)

    # ---- SqueezeNet backbone -------------------------------------------
    conv1 = Conv2D(
        64, (3, 3), activation='relu', kernel_initializer='glorot_uniform',
        strides=(1, 1), name='conv1', padding='same',
        data_format='channels_last')(input_img)
    maxpool1 = MaxPooling2D(
        pool_size=(2, 2), strides=(2, 2), name='maxpool1',
        data_format='channels_last')(conv1)

    merge1 = _fire_module(maxpool1, 'fire1', 15, 49, 53)
    maxpool2 = MaxPooling2D(
        pool_size=(2, 2), strides=(2, 2), name='maxpool2',
        data_format='channels_last')(merge1)

    merge2 = _fire_module(maxpool2, 'fire2', 15, 54, 52)
    maxpool3 = MaxPooling2D(
        pool_size=(3, 3), strides=(2, 2), padding='same', name='maxpool3',
        data_format='channels_last')(merge2)

    merge3 = _fire_module(maxpool3, 'fire3', 29, 92, 94)
    merge4 = _fire_module(merge3, 'fire4', 29, 90, 83)
    maxpool4 = MaxPooling2D(
        pool_size=(2, 2), strides=(2, 2), name='maxpool4', padding='same',
        data_format='channels_last')(merge4)

    merge5 = _fire_module(maxpool4, 'fire5', 44, 166, 161)
    merge6 = _fire_module(merge5, 'fire6', 45, 155, 146)
    merge7 = _fire_module(merge6, 'fire7', 49, 163, 171)
    merge8 = _fire_module(merge7, 'fire8', 25, 29, 54)
    maxpool9 = MaxPooling2D(
        pool_size=(3, 3), strides=(2, 2), padding='same', name='maxpool9',
        data_format='channels_last')(merge8)

    merge9 = _fire_module(maxpool9, 'fire9', 37, 45, 56)
    maxpool10 = MaxPooling2D(
        pool_size=(2, 2), strides=(2, 2), name='maxpool10',
        data_format='channels_last')(merge9)

    merge10 = _fire_module(maxpool10, 'fire10', 38, 41, 44)

    # ---- Extra SSD feature layers ----------------------------------------
    conv12_1 = Conv2D(
        51, (3, 3), activation='relu', kernel_initializer='glorot_uniform',
        padding='same', strides=(2, 2), name='conv12_1',
        data_format='channels_last')(merge10)
    # Explicit 1-pixel padding followed by a 'valid' 3x3 conv keeps the
    # spatial size unchanged.
    padding_1 = ZeroPadding2D((1, 1), data_format='channels_last')(conv12_1)
    conv12_2 = Conv2D(
        46, (3, 3), activation='relu', kernel_initializer='glorot_uniform',
        name='conv12_2', data_format='channels_last')(padding_1)

    conv13_1 = Conv2D(
        55, (3, 3), activation='relu', kernel_initializer='glorot_uniform',
        padding='same', name='conv13_1',
        data_format='channels_last')(conv12_2)
    # 'valid' 3x3 conv: shrinks each spatial dimension by 2 (3x3 -> 1x1 for
    # a 300x300 input).
    conv13_2 = Conv2D(
        85, (3, 3), activation='relu', kernel_initializer='glorot_uniform',
        name='conv13_2', data_format='channels_last')(conv13_1)

    # ---- Detection heads ---------------------------------------------------
    # 3 priors: aspect_ratios=[2] without max_size; 6 priors: [2, 3] with
    # max_size. Must match what PriorBox generates per location.
    fire4_norm = Normalize(20, name='fire4_norm')(merge4)
    heads = [
        _mbox_head(fire4_norm, 'fire4_norm', 3, nb_classes, img_size, 30.0,
                   aspect_ratios=[2]),
        _mbox_head(merge8, 'fire8', 6, nb_classes, img_size, 60.0,
                   max_size=114.0, aspect_ratios=[2, 3]),
        _mbox_head(merge9, 'fire9', 6, nb_classes, img_size, 114.0,
                   max_size=168.0, aspect_ratios=[2, 3]),
        _mbox_head(merge10, 'fire10', 6, nb_classes, img_size, 168.0,
                   max_size=222.0, aspect_ratios=[2, 3]),
    ]
    # pool_size=(1, 1) max pooling is an identity op. It is kept so the layer
    # graph matches the original model.
    conv12_maxpool = MaxPooling2D(
        pool_size=(1, 1), data_format='channels_last')(conv12_2)
    heads.append(_mbox_head(conv12_maxpool, 'conv12', 6, nb_classes,
                            img_size, 222.0, max_size=276.0,
                            aspect_ratios=[2, 3]))
    conv13_maxpool = MaxPooling2D(
        pool_size=(1, 1), data_format='channels_last')(conv13_2)
    heads.append(_mbox_head(conv13_maxpool, 'conv13', 6, nb_classes,
                            img_size, 276.0, max_size=330.0,
                            aspect_ratios=[2, 3]))

    # ---- Gather all predictions -------------------------------------------
    locs, confs, priorboxes = zip(*heads)
    mbox_loc = Concatenate(axis=1)(list(locs))
    mbox_conf = Concatenate(axis=1)(list(confs))
    mbox_priorbox = Concatenate(axis=1)(list(priorboxes))

    # Older Keras attaches `_keras_shape` to tensors; otherwise fall back to
    # the backend helper, which works for any Keras tensor.
    if hasattr(mbox_loc, '_keras_shape'):
        num_boxes = mbox_loc._keras_shape[-1] // 4
    else:
        num_boxes = K.int_shape(mbox_loc)[-1] // 4

    mbox_loc_final = Reshape((num_boxes, 4), name='mbox_loc_final')(mbox_loc)
    mbox_conf_logits = Reshape(
        (num_boxes, nb_classes), name='mbox_conf_logits')(mbox_conf)
    mbox_conf_final = Activation(
        'softmax', name='mbox_conf_final')(mbox_conf_logits)

    # The misspelled layer name 'preditions' is kept so existing code that
    # looks the layer up by name keeps working.
    predictions = Concatenate(axis=2, name='preditions')(
        [mbox_loc_final, mbox_conf_final, mbox_priorbox])

    return Model(inputs=input_img, outputs=predictions)

测试代码：

    if __name__ == '__main__':
        # Smoke test: build the detector for 300x300 RGB images with 4 classes
        # and print the layer summary.
        model = SqueezeNet((300, 300, 3), 4)
        # NOTE: this compile call only checks that the graph compiles. The
        # model output concatenates box offsets, class scores and prior boxes,
        # so categorical_crossentropy/accuracy are not meaningful for training;
        # an SSD multibox loss is needed for that.
        model.compile(loss='categorical_crossentropy', optimizer='sgd',
                      metrics=['accuracy'])
        model.summary()

你可能感兴趣的:(使用Squeeze-Net 为基础网络的ssd(keras))