Keras实现多任务学习

 

from keras.layers import Input,LSTM,Bidirectional,Dense,Dropout,Concatenate,Embedding,GlobalMaxPool1D
from keras.models import Model
from keras_contrib.layers import CRF
import keras.backend as K
from keras.utils import plot_model

# Reset the backend state so repeated runs start from a fresh graph.
K.clear_session()
maxlen = 40

# Inputs: one (maxlen, 768) float sequence named "sen_emb" and two
# length-maxlen id sequences named "pos_en1_id" / "pos_en2_id".
# NOTE(review): the id inputs look like per-token positions relative to two
# entities — confirm against the data pipeline.
inputs = Input(name="sen_emb", shape=(maxlen, 768))
pos1_en = Input(name="pos_en1_id", shape=(maxlen,))
pos2_en = Input(name="pos_en2_id", shape=(maxlen,))

# Each id sequence is looked up in its own 8-dimensional embedding table.
# NOTE(review): input_dim is maxlen, so ids must lie in [0, maxlen) —
# confirm the id encoding used upstream.
pos1_emb = Embedding(
    input_dim=maxlen,
    output_dim=8,
    input_length=maxlen,
    name="pos_en1_emb",
)(pos1_en)
pos2_emb = Embedding(
    input_dim=maxlen,
    output_dim=8,
    input_length=maxlen,
    name="pos_en2_emb",
)(pos2_en)

# Join the three per-token feature tensors along the last (feature) axis.
x = Concatenate(axis=2)([inputs, pos1_emb, pos2_emb])

# Shared-parameter part: a bidirectional LSTM that returns the full sequence,
# so both task heads below read the same per-token representation.
shared_lstm = LSTM(128, return_sequences=True)
x = Bidirectional(shared_lstm)(x)

# Task 1: 10-class text classification.
# Max-pool the shared sequence over time, then Dense(64, relu) -> Dropout(0.5)
# -> Dense(10, softmax). The final layer is named "out1" so it can be
# addressed by name when compiling.
pooled = GlobalMaxPool1D()(x)
hidden = Dense(64, activation='relu')(pooled)
hidden = Dropout(0.5)(hidden)
out1 = Dense(10, activation='softmax', name="out1")(hidden)

# Task 2: entity recognition.
# A CRF layer with 2 tags and sparse (integer) targets, applied directly to
# the shared per-token sequence. The layer object is kept in `crf` because its
# loss function is needed at compile time.
crf = CRF(2, sparse_target=True, name="crf_output")
crf_output = crf(x)

# Assemble the multi-task model: three inputs feeding two outputs
# (out1 for classification, crf_output for tagging).
model_inputs = [inputs, pos1_en, pos2_en]
model_outputs = [out1, crf_output]
model = Model(inputs=model_inputs, outputs=model_outputs)
model.summary()

# The model is trained with two losses, keyed by output-layer name:
# categorical cross-entropy for the classification head and the CRF layer's
# own loss for the tagging head. Both are weighted equally.
model.compile(
    optimizer='adam',
    loss={
        'out1': 'categorical_crossentropy',
        'crf_output': crf.loss_function,
    },
    loss_weights={'out1': 1, 'crf_output': 1},
    # Metrics are given per output. The generic "acc" shortcut is resolved by
    # Keras from the loss and output shape; for the CRF head (custom loss,
    # sparse integer targets) it would fall back to categorical accuracy and
    # take an argmax over the size-1 label axis, yielding a meaningless
    # number. The CRF layer's own accuracy handles its target format.
    metrics={'out1': 'acc', 'crf_output': crf.accuracy},
)

# Write a diagram of the model graph to model.png.
plot_model(model, to_file="model.png")

 

Keras实现多任务学习_第1张图片

Keras实现多任务学习_第2张图片

 

 

 

你可能感兴趣的:(keras,多任务学习)