seq2seq的实现方式(3)

书接上文

这里实现的是方式(4):在解码器的每个时间步上引入 attention 机制。


    def build_model(self):
        """Build and compile the attention-based seq2seq training model.

        The encoder embeds the source tokens and runs them through an LSTM
        that returns its full output sequence plus its final hidden/cell
        state. The decoder is unrolled explicitly for ``self.output_seq_len``
        steps: at each step one embedded target token is fed to a shared LSTM
        whose state is carried over from the previous step (the encoder's
        final state for step 0), ``AttentionLayer_3`` is applied to the
        encoder outputs together with the decoder output, and the attention
        result is concatenated with the decoder output. The per-step features
        are stacked along axis 1 and projected onto the target vocabulary by
        a softmax ``Dense`` layer.

        Reads ``self.input_seq_len``, ``self.output_seq_len``,
        ``self.en_word_id_dict``, ``self.ch_word_id_dict``,
        ``self.encode_embeding_len``, ``self.decode_embeding_len`` and
        ``self.encode_lstm_hidden_len``.

        Returns:
            A compiled Keras ``Model`` taking ``[encoder_input, decoder_input]``
            and producing the ``decoder_out`` softmax (optimizer ``rmsprop``,
            loss ``categorical_crossentropy``). The model summary is printed
            as a side effect.
        """
        # ---- Encoder: token ids -> embeddings -> LSTM outputs + final state.
        encoder_input = layers.Input(shape=(self.input_seq_len,))
        encoder_embeding = layers.Embedding(input_dim=len(self.en_word_id_dict),
                                            output_dim=self.encode_embeding_len,
                                            # mask_zero=True
                                            )(encoder_input)
        encoder_lstm, state_h, state_c = layers.LSTM(units=self.encode_lstm_hidden_len,
                                                     return_state=True, return_sequences=True)(encoder_embeding)

        # ---- Decoder input and embedding.
        decoder_input = layers.Input(shape=(self.output_seq_len,), name="decoder_input")
        decoder_embeding = layers.Embedding(input_dim=len(self.ch_word_id_dict),
                                            output_dim=self.decode_embeding_len,
                                            # mask_zero=True
                                            name="decoder_embeding"
                                            )(decoder_input)

        # The decoder LSTM is seeded with the encoder's final state, so its
        # state size has to equal the encoder's. Sizing it by
        # ``decode_embeding_len`` only built when both values happened to match.
        decoder_lstm_layer = layers.LSTM(self.encode_lstm_hidden_len,
                                         return_state=True, name="decoder_lstm_layer")
        # RepeatVector(1) turns a (batch, dim) tensor into (batch, 1, dim).
        add_time_axis_layer = layers.RepeatVector(1, name="decoder_embeding_slice_repeat_layer")
        attention_layer = AttentionLayer_3()
        concat_layer = layers.Concatenate(axis=-1, name="concat_layer")

        outputs = []
        for t in range(self.output_seq_len):
            # One Lambda per step with the index passed via ``arguments``: a
            # single shared Lambda closing over ``t`` would pick up the *last*
            # value of ``t`` whenever the layer function is invoked again
            # after the loop has finished.
            slice_layer = layers.Lambda(lambda x, step: x[:, step, :],
                                        arguments={"step": t},
                                        name="decoder_embeding_slice_layer_{}".format(t))
            step_embeding = add_time_axis_layer(slice_layer(decoder_embeding))

            # Run a single decoder step and carry its state into the next one.
            decoder_lstm, state_h, state_c = decoder_lstm_layer(step_embeding,
                                                                initial_state=[state_h, state_c])

            # NOTE(review): assumes AttentionLayer_3 returns a (batch, 1, dim)
            # tensor so that it can be concatenated on the last axis with the
            # re-expanded decoder output -- confirm against its definition.
            decoder_lstm_att = attention_layer([encoder_lstm, decoder_lstm])
            decoder_lstm = add_time_axis_layer(decoder_lstm)
            outputs.append(concat_layer([decoder_lstm_att, decoder_lstm]))

        # Stack the per-step features along axis 1. Some Keras versions reject
        # a Concatenate call on a one-element list, so a single step is passed
        # through unchanged.
        if len(outputs) == 1:
            outputs = outputs[0]
        else:
            outputs = layers.Concatenate(axis=1)(outputs)

        decoder_out = layers.Dense(len(self.ch_word_id_dict), activation="softmax", name="decoder_out")(outputs)
        model = Model([encoder_input, decoder_input], decoder_out)
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

        model.summary()
        return model

什么是爱? 一分也是爱 !
在这里插入图片描述

你可能感兴趣的:(自然语言处理)