Using the pretrained Xception network in TensorFlow 2 for a multi-output task: classification and regression at the same time

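The data in this article comes from a project I have been working on recently. It is stored in npy files, and each record has the format [[800,640],1,1,1]: an 800*640 image, the key pressed, the x coordinate, and the y coordinate.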
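To make the record layout concrete, here is a minimal sketch that writes and reads back a file with the same structure. The file name `demo.npy`, the tiny 8*6 image size and the sample values are made up for illustration; only the [image, key, x, y] layout comes from the description above.

```python
import os
import tempfile

import numpy as np

# Hypothetical small image size so the demo runs instantly (the real images are 800*640).
DEMO_SHAPE = (8, 6)

def make_record(key, x, y):
    """Build one record laid out as [image, key, x, y]."""
    image = np.zeros(DEMO_SHAPE, dtype=np.uint8)
    return [image, key, x, y]

# Records mix an array with scalars, so they are stored in an object array.
records = np.empty(3, dtype=object)
for idx, (key, x, y) in enumerate([(0, 1, 2), (4, 3, 5), (8, 7, 0)]):
    records[idx] = make_record(key, x, y)

path = os.path.join(tempfile.mkdtemp(), "demo.npy")
np.save(path, records, allow_pickle=True)

# Object arrays are pickled on disk, so allow_pickle=True is required when loading.
loaded = np.load(path, allow_pickle=True)
images = np.array([rec[0] for rec in loaded])  # stacked images, one per record
keys = np.array([rec[1] for rec in loaded])    # integer key labels
```

Each column is pulled out with a list comprehension over the records, which is the same pattern the preprocessing code in this article relies on.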

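The goal is to build a neural network that takes an image as input and predicts the key, the x coordinate and the y coordinate. In other words, it classifies the key and regresses the coordinates at the same time.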

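Xception is a model that ships with TensorFlow 2 and is pretrained on ImageNet, where it reached a validation top-1 accuracy of 0.79 and a top-5 accuracy of 0.945. Its input must be ordered as (height, width, channels), and the default input size is 299*299.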

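Commonly used Xception parameters:

include_top: False excludes the fully connected layer at the top of the network; True includes it.

weights: None for random initialization; 'imagenet' to load the weights pretrained on ImageNet.

input_shape: optional tuple giving the input shape. It can only be specified when include_top=False; otherwise the input must be 299*299. The input width and height must not be smaller than 71.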
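As a rough illustration of how input_shape determines what the headless network returns, the sketch below computes the spatial size of the final feature map by hand. It assumes the downsampling layout of the Keras implementation (a 3*3 stride-2 convolution and a 3*3 convolution, both without padding, followed by four stride-2 pooling stages with 'same' padding). The helper name is made up, and the result is worth checking against model.summary().

```python
import math

def xception_feature_size(n):
    """Spatial size of the last Xception feature map for an input side of n pixels.

    Hypothetical helper: it mirrors the assumed downsampling steps, not the real layers.
    """
    n = (n - 3) // 2 + 1      # 3x3 convolution, stride 2, no padding
    n = n - 2                 # 3x3 convolution, stride 1, no padding
    for _ in range(4):        # four max-pooling stages, stride 2, 'same' padding
        n = math.ceil(n / 2)
    return n

default_size = xception_feature_size(299)  # side length for the default input
min_size = xception_feature_size(71)       # side length for the smallest allowed input
height_400 = xception_feature_size(400)
width_320 = xception_feature_size(320)
```

For a 400*320 input this gives a 13*10 feature map, which agrees with the (None, 13, 10, 2048) output shape in the model summary later in this article.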

OK, let's load the data and preprocess it.

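import numpy as np

# WIDTH is the image width, HEIGHT is the image height; both, along with file_name, must be defined beforehand
train_data = np.load(file_name, allow_pickle=True)
print(file_name, len(train_data))
# Hold out the last 500 records as the test set
train = train_data[:-500]
test = train_data[-500:]
# Swap width and height, make sure the channel dimension is present, then scale pixels to [0, 1].
# After stacking, axis 0 is the sample axis, so width and height are axes 1 and 2
# (this assumes each stored image is laid out as (WIDTH, HEIGHT, 3)).
train_X = np.array([i[0] for i in train]).swapaxes(1, 2).reshape(-1, HEIGHT, WIDTH, 3) / 255

# Key labels stay as integer class ids; coordinates are scaled to [0, 1]
train_Y_key = np.array([i[1] for i in train])
train_Y_x = np.array([i[2] for i in train]) / WIDTH
train_Y_y = np.array([i[3] for i in train]) / HEIGHT

test_X = np.array([i[0] for i in test]).swapaxes(1, 2).reshape(-1, HEIGHT, WIDTH, 3) / 255
test_Y_key = np.array([i[1] for i in test])
test_Y_x = np.array([i[2] for i in test]) / WIDTH
test_Y_y = np.array([i[3] for i in test]) / HEIGHT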
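The effect of the axis swap and the scaling is easy to check on synthetic data. The sketch below assumes each stored image has shape (width, height, 3); the sizes and random values are placeholders, not the project's data. Once the images are stacked, axis 0 indexes the samples, so width and height are axes 1 and 2.

```python
import numpy as np

W, H = 8, 6  # hypothetical small sizes for the demo
rng = np.random.default_rng(0)

# Five fake images stacked as (N, W, H, 3), plus fake pixel coordinates.
raw = rng.integers(0, 256, size=(5, W, H, 3), dtype=np.uint8)
xs = rng.integers(0, W, size=5)
ys = rng.integers(0, H, size=5)

# Swap the width and height axes, then scale pixel values to [0, 1].
batch = raw.swapaxes(1, 2).astype(np.float32) / 255  # now (N, H, W, 3)

# Scale coordinates to [0, 1) so the regression targets share one range.
x_scaled = xs / W
y_scaled = ys / H

# The pixel at (w=1, h=4) of image 2 should now sit at (h=4, w=1).
same_pixel = np.allclose(batch[2, 4, 1], raw[2, 1, 4] / 255)
```

Keeping both coordinates in the same range means neither of the two mean-squared-error terms dominates simply because one image side is longer than the other.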

Build the model. The input is (HEIGHT, WIDTH, 3) and there are three outputs: [out_x, out_y, out_key].

import tensorflow as tf

# Pretrained Xception without its classification head, frozen so only the new layers are trained
xception = tf.keras.applications.Xception(weights='imagenet', include_top=False, input_shape=(HEIGHT, WIDTH, 3))
xception.trainable = False
inputs = tf.keras.layers.Input(shape=(HEIGHT, WIDTH, 3))
x = xception(inputs)

# Collapse the feature map into one 2048-dimensional vector per image
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# Regression branch: shared dense layers, then one linear unit per coordinate
x1 = tf.keras.layers.Dense(512, activation='relu')(x)
x1 = tf.keras.layers.Dense(256, activation='relu')(x1)

out_x = tf.keras.layers.Dense(1, name='out_x')(x1)
out_y = tf.keras.layers.Dense(1, name='out_y')(x1)

# Classification branch: softmax over the 9 key classes
x2 = tf.keras.layers.Dense(512, activation='relu')(x)
out_key = tf.keras.layers.Dense(9, activation='softmax', name='out_key')(x2)

predictions = [out_x, out_y, out_key]

model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
print(model.summary())

Model summary:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 400, 320, 3) 0                                            
__________________________________________________________________________________________________
xception (Model)                (None, 13, 10, 2048) 20861480    input_2[0][0]                    
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 2048)         0           xception[1][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 512)          1049088     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          131328      dense[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 512)          1049088     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
out_x (Dense)                   (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out_y (Dense)                   (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out_key (Dense)                 (None, 9)            4617        dense_2[0][0]                    
==================================================================================================
Total params: 23,096,115
Trainable params: 2,234,635
Non-trainable params: 20,861,480
__________________________________________________________________________________________________
None
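The parameter counts in the summary can be reproduced by hand, which is a quick way to confirm that only the new layers are trainable. The sketch below uses the usual dense-layer formula (inputs * units + units for the biases); the helper name is made up, and the layer sizes and the 20,861,480 figure for the frozen base are read off the summary above.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# Layer sizes taken from the summary: 2048 pooled features feed both branches.
head = {
    "dense": dense_params(2048, 512),
    "dense_1": dense_params(512, 256),
    "dense_2": dense_params(2048, 512),
    "out_x": dense_params(256, 1),
    "out_y": dense_params(256, 1),
    "out_key": dense_params(512, 9),
}

trainable = sum(head.values())       # parameters of the new layers only
frozen_base = 20861480               # Xception base, as reported in the summary
total = trainable + frozen_base
```

The trainable figure equals the sum of the new dense layers alone, so freezing the base leaves roughly a tenth of the parameters to be updated during training.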

Below is the code that compiles the model.

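# One loss per named output; metrics are also given per output, since accuracy only
# makes sense for the key classifier and mean absolute error for the coordinates
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss={'out_x': 'mse',
                    'out_y': 'mse',
                    'out_key': 'sparse_categorical_crossentropy'},
              metrics={'out_x': 'mae',
                       'out_y': 'mae',
                       'out_key': 'acc'}
)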
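With one loss per output, the quantity being minimized is the sum of the three individual losses (each weighted by 1 unless the loss_weights argument of compile is set). The plain-Python sketch below shows that combination on made-up values. The function names, the three-class example and the numbers are illustrative only, and numerical details such as clipping of probabilities are left out.

```python
import math

def mse(y_true, y_pred):
    """Mean squared error over a batch of scalar targets."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def sparse_categorical_crossentropy(labels, probs):
    """Mean negative log-probability assigned to the true integer label."""
    return -sum(math.log(p[k]) for k, p in zip(labels, probs)) / len(labels)

# A made-up batch of two samples (three key classes instead of nine, for brevity).
true_x, pred_x = [0.5, 0.25], [0.5, 0.75]
true_y, pred_y = [0.1, 0.9], [0.1, 0.9]
true_key = [1, 0]
pred_key = [[0.25, 0.5, 0.25], [1.0, 0.0, 0.0]]

loss_x = mse(true_x, pred_x)
loss_y = mse(true_y, pred_y)
loss_key = sparse_categorical_crossentropy(true_key, pred_key)

# With equal weights the total is simply the sum of the three terms.
total_loss = loss_x + loss_y + loss_key
```

Because the cross-entropy term is often larger than the two squared-error terms on [0, 1] targets, loss_weights may be worth tuning if one task lags behind the other.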

The training code is as follows.

model.fit(train_X, [train_Y_x, train_Y_y, train_Y_key],
          epochs=1,
          batch_size=32,
          validation_data=(test_X, [test_Y_x, test_Y_y, test_Y_key]))
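After training, the three outputs have to be mapped back to something usable: the coordinates were scaled to [0, 1], and the key output is a probability vector. The sketch below shows one way to undo that for a single sample. The decode helper and the fake output values are made up, and WIDTH = 320, HEIGHT = 400 are read off the input shape in the model summary.

```python
WIDTH, HEIGHT = 320, 400  # assumed from the (None, 400, 320, 3) input in the summary

def decode(pred_x, pred_y, pred_key):
    """Turn one sample's raw outputs into (key class, x pixel, y pixel)."""
    x_pixel = round(pred_x * WIDTH)    # undo the division by WIDTH
    y_pixel = round(pred_y * HEIGHT)   # undo the division by HEIGHT
    # Index of the largest probability, i.e. an argmax over the 9 key classes.
    key = max(range(len(pred_key)), key=lambda k: pred_key[k])
    return key, x_pixel, y_pixel

# Fake outputs standing in for one row of each array returned by model.predict.
fake_x, fake_y = 0.5, 0.25
fake_key = [0.0, 0.1, 0.0, 0.7, 0.0, 0.1, 0.1, 0.0, 0.0]

decoded = decode(fake_x, fake_y, fake_key)
```

The outputs come back in the same order as the predictions list used to build the model, so the first array holds x, the second y, and the third the key probabilities.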

so easy!

 
