VGG-16获取并设置参数

获取预训练 VGG-16 网络各卷积层的参数,将第一层卷积核沿输入通道维度平铺(广播)以适配多通道输入,然后构建一个新的网络并把这些参数载入其中。

from keras import layers

from keras.layers import  Conv2D, MaxPooling2D,ZeroPadding2D,GlobalAvgPool2D
from keras.models import Sequential

import numpy
def vgg(img_w,img_h,n_channels,weights_path=None):
    """Build the 13-layer convolutional part of a VGG-16 style network.

    The model is a ``Sequential`` stack of five stages. Each stage is a run
    of ``ZeroPadding2D((1, 1))`` + 3x3 ReLU ``Conv2D`` pairs followed by a
    2x2, stride-2 ``MaxPooling2D``. The convolution layers are named
    ``conv1`` .. ``conv13`` so their weights can be looked up by name.
    No pooling or dense head is added after the last max-pooling layer.

    Args:
        img_w: First spatial dimension handed to ``input_shape``.
        img_h: Second spatial dimension handed to ``input_shape``.
        n_channels: Number of input channels (last entry of ``input_shape``).
        weights_path: Optional path of a weights file passed to
            ``model.load_weights``; skipped when falsy.

    Returns:
        The assembled ``Sequential`` model.

    NOTE(review): ``input_shape`` is given as ``(img_w, img_h, n_channels)``
    in exactly that order; confirm this matches the axis order the caller's
    data uses when images are not square.
    """
    # (number of filters, number of conv layers) for each stage.
    stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    model = Sequential()
    conv_index = 0
    for filters, depth in stages:
        for _ in range(depth):
            conv_index += 1
            if conv_index == 1:
                # Only the very first layer declares the input shape.
                padding = ZeroPadding2D(
                    (1, 1), input_shape=(img_w, img_h, n_channels))
            else:
                padding = ZeroPadding2D((1, 1))
            model.add(padding)
            model.add(Conv2D(filters, (3, 3), activation='relu',
                             name="conv%d" % conv_index))
        # Each stage ends by halving the spatial resolution.
        model.add(MaxPooling2D((2, 2), strides=(2, 2)))

    if weights_path:
        model.load_weights(weights_path)

    return model
def load_save_weight(img_w,img_h,n_channels,
                     weights_path='./vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'):
    """Build a VGG model for ``n_channels`` inputs seeded with pretrained weights.

    A 32x32x3 source model is created with ``vgg`` and loaded from
    ``weights_path``. A second model is created for the requested input
    shape. The first convolution kernel of the source model is repeated
    along its input-channel axis until it covers ``n_channels`` channels
    (channel ``c`` of the new kernel is a copy of source channel
    ``c % source_channels``); ``conv2`` .. ``conv13`` are copied unchanged.

    Previously the first-layer kernel was always built with a hard-coded
    1500 input channels, so any other ``n_channels`` failed when the
    weights were assigned. For ``n_channels == 1500`` the channel order
    is the same as before.

    Args:
        img_w: First spatial dimension forwarded to ``vgg``.
        img_h: Second spatial dimension forwarded to ``vgg``.
        n_channels: Number of input channels of the new model.
        weights_path: Weights file loaded into the source model.

    Returns:
        The new model with its convolution weights set.

    NOTE(review): the repeated kernels are copied as-is and not rescaled,
    so first-layer pre-activations are sums over more channels than in the
    source model; confirm this is intended.
    """
    model = vgg(32, 32, 3, weights_path)
    new_model = vgg(img_w, img_h, n_channels)

    # get_weights() returns the kernel first; anything after it (the bias)
    # is independent of the input-channel count and is reused unchanged.
    kernel, *rest = model.get_layer('conv1').get_weights()

    # Axis 2 of the kernel is the one the original code tiled, i.e. the
    # input-channel axis. Cycle through the source channels to fill it.
    src_channels = kernel.shape[2]
    channel_map = numpy.arange(n_channels) % src_channels
    tiled_kernel = kernel[:, :, channel_map, :]
    new_model.get_layer('conv1').set_weights([tiled_kernel] + rest)

    # The remaining layers have identical shapes in both models.
    for index in range(2, 14):
        layer_name = 'conv%d' % index
        new_model.get_layer(layer_name).set_weights(
            model.get_layer(layer_name).get_weights())
    return new_model
if __name__ == "__main__":
    # Script entry point: build the 32x32 model with 1500 input channels
    # and write its weights to disk.
    output_path = "./my_vgg_weights.h5"
    my_model = load_save_weight(img_w=32, img_h=32, n_channels=1500)
    my_model.save_weights(output_path)

你可能感兴趣的:(keras)