seq2seq attention implementation / using attention_decoder
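
The snippet below is an excerpt from a TensorFlow 1.x model class: a multi-layer bidirectional LSTM encoder built with bidirectional_dynamic_rnn, whose outputs feed the legacy attention_decoder from tf.contrib.legacy_seq2seq.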

        # TensorFlow 1.x imports; rnn provides bidirectional_dynamic_rnn used below
        import tensorflow as tf
        from tensorflow.python.ops import rnn
        from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib


        # Build one fresh LSTM cell per layer; reusing the same cell object in
        # MultiRNNCell shares variables and raises errors in recent TF 1.x versions.
        def _lstm_cell():
            cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, initializer=_initializer)
            if is_training:
                cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=(1 - self.dropout_rate))
            return cell

        lstm_cell_fw = tf.nn.rnn_cell.MultiRNNCell([_lstm_cell() for _ in range(self.num_layers)])
        lstm_cell_bw = tf.nn.rnn_cell.MultiRNNCell([_lstm_cell() for _ in range(self.num_layers)])

        # Padding id is 0, so tf.sign marks real tokens; summing per row gives each sequence's true length
        self.sequence_len = tf.reduce_sum(tf.sign(self.inputs), axis=1)
        self.sequence_len = tf.cast(self.sequence_len, tf.int32)
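
        # Quick sanity check (illustrative values only, not part of the model):
        #   ids     = [[4, 7, 2, 0, 0], [9, 3, 0, 0, 0]]
        #   lengths = tf.reduce_sum(tf.sign(ids), axis=1)  ->  [3, 2]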

        # forward and backward
        self.encoder_outputs, self.encoder_output_states = rnn.bidirectional_dynamic_rnn(
            lstm_cell_fw,
            lstm_cell_bw,
            self.input_emb,
            dtype=tf.float32,
            sequence_length=self.sequence_len
        )
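
        # encoder_outputs is a tuple (output_fw, output_bw), each [batch, num_steps, hidden_dim];
        # encoder_output_states is a tuple of per-layer LSTMStateTuples for each direction.
        # A possible variant (a sketch, not used below): attend over both directions by
        # concatenating the outputs along the feature axis, e.g.
        #   attention_memory = tf.concat(self.encoder_outputs, axis=2)  # [batch, num_steps, 2*hidden_dim]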

        # Top-layer LSTM states of each direction (each an LSTMStateTuple of (c, h))
        state_fw = self.encoder_output_states[0][-1]
        state_bw = self.encoder_output_states[1][-1]
        # Concatenate (c, h) within each direction, then both directions: [batch, 4 * hidden_dim]
        encoder_state = tf.concat([tf.concat(state_fw, 1),
                                   tf.concat(state_bw, 1)], 1)
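
        # Hypothetical bridge (an assumption, not part of the original): to start the decoder from
        # the full bidirectional summary instead of state_fw, encoder_state could be projected back
        # down to hidden_dim and wrapped into an LSTMStateTuple, e.g.
        #   bridge = tf.layers.dense(encoder_state, self.hidden_dim, activation=tf.tanh)
        #   decoder_init_state = tf.nn.rnn_cell.LSTMStateTuple(c=bridge, h=bridge)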

        decoder_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, initializer=_initializer)
        
        if is_training:
            decoder_cell = tf.nn.rnn_cell.DropoutWrapper(decoder_cell, output_keep_prob=(1 - self.dropout_rate))

        # Legacy attention decoder: feed a zero vector at every step (the decoder inputs carry no
        # information here), start from the forward encoder state, and attend over the forward outputs.
        decoder_output_list, _ = seq2seq_lib.attention_decoder(
            decoder_inputs=[tf.zeros(shape=tf.shape(state_fw[0]))] * self.num_steps,
            initial_state=state_fw,
            attention_states=self.encoder_outputs[0],
            cell=decoder_cell)

        # Stack the per-step decoder outputs into [batch, num_steps, hidden_dim]
        decoder_outputs = tf.stack(decoder_output_list, axis=1)
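
To turn these decoder outputs into a trainable model, one common continuation (a sketch under assumptions: self.targets holding the gold label ids and self.num_classes giving the output vocabulary/tag count are not part of the original) is to project each step to logits and apply a length-masked sequence loss:

        # Per-step projection to logits, mask out padding positions, and average the cross-entropy
        logits = tf.layers.dense(decoder_outputs, self.num_classes)  # [batch, num_steps, num_classes]
        mask = tf.sequence_mask(self.sequence_len, self.num_steps, dtype=tf.float32)
        self.loss = tf.contrib.seq2seq.sequence_loss(logits, self.targets, weights=mask)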
