seq2seq (TensorFlow 2.0)

import tensorflow as tf
'''
In an LSTM, each cell's state is a pair (c_state, hidden_state); the output is the
hidden_state taken from the state of the cell at the last time step (the last word).
'''
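A minimal sketch (not part of the original code) illustrating the point above: with return_sequences=True and return_state=True, tf.keras.layers.LSTM returns the full output sequence plus the final hidden and cell states, and the final hidden state equals the last time step of the output sequence.

demo_lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
demo_inputs = tf.random.normal([2, 5, 3])         # (batch, time, features)
demo_outputs, demo_h, demo_c = demo_lstm(demo_inputs)
print(demo_outputs.shape)                         # (2, 5, 4)
print(demo_h.shape, demo_c.shape)                 # (2, 4) (2, 4)
# demo_h is identical to demo_outputs[:, -1, :]
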
embedding_units = 256
units = 1024
input_vocab_size = len(input_tokenizer.word_index) +1
output_vocab_size = len(output_tokenizer.word_index) +1
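
input_tokenizer and output_tokenizer are assumed to have been fit on the source and target corpora earlier in the pipeline. A minimal sketch of how that preprocessing step might look, assuming input_texts and output_texts are lists of already space-separated sentences (these names are illustrative, not from the original post):

# hypothetical preprocessing: fit tokenizers and pad the id sequences
input_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
input_tokenizer.fit_on_texts(input_texts)
input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
    input_tokenizer.texts_to_sequences(input_texts), padding='post')

output_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
output_tokenizer.fit_on_texts(output_texts)
output_tensor = tf.keras.preprocessing.sequence.pad_sequences(
    output_tokenizer.texts_to_sequences(output_texts), padding='post')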

# Encoder
class Encoder(tf.keras.Model):
    def __init__(self,vocab_size,embedding_units,encoding_units,batch_size):
        super(Encoder,self).__init__()
        self.batch_size = batch_size
        self.encoding_units = encoding_units
        self.embedding = tf.keras.layers.Embedding(vocab_size,embedding_units)
        self.gru = tf.keras.layers.GRU(self.encoding_units,return_sequences=True,
                                       return_state=True,recurrent_initializer='glorot_uniform')

    def call(self,x,hidden):
        # look up the embedding for each input token
        x = self.embedding(x)
        # run the GRU; it returns the full output sequence and the final hidden state
        output,state = self.gru(x,initial_state=hidden)
        return output,state
    def initialize_hidden_state(self):
        return tf.zeros([self.batch_size,self.encoding_units])
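
Before wiring up the attention layer, the encoder can be run once on a dummy batch to produce the sample_output and sample_hidden tensors used below; a minimal sketch, assuming an illustrative batch_size of 64 and sequence length of 16:

batch_size = 64
encoder = Encoder(input_vocab_size, embedding_units, units, batch_size)
sample_hidden = encoder.initialize_hidden_state()
# hypothetical random batch of token ids, just to check shapes
sample_input = tf.random.uniform([batch_size, 16], maxval=input_vocab_size, dtype=tf.int32)
sample_output, sample_hidden = encoder(sample_input, sample_hidden)
print(sample_output.shape)   # (64, 16, 1024)
print(sample_hidden.shape)   # (64, 1024)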


# Attention mechanism (Bahdanau additive attention)
class BahdanauAttention(tf.keras.Model):
    def __init__(self,units):
        super(BahdanauAttention,self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)
    def call(self,decoder_hidden,encoder_outputs):
        # decoder_hidden.shape:[batch_size,units]
        # encoder_outputs.shape:[batch_size,length,units]
        decoder_hidden_with_time_axis = tf.expand_dims(decoder_hidden,1)
        # before V:(batch_size,length,units)
        # after V:(batch_size,length,1)
        score = self.V(tf.nn.tanh(self.W1(encoder_outputs) + self.W2(decoder_hidden_with_time_axis)))
        # shape:(batch_size,length,1)
        attention_weights = tf.nn.softmax(score,axis=1)
        # shape:(batch_size,length,units)
        context_vector = attention_weights * encoder_outputs
        # shape:(batch_size,units)
        context_vector = tf.reduce_sum(context_vector,axis=1)
        return context_vector,attention_weights

attention_model = BahdanauAttention(units=10)
attention_result,attention_weights = attention_model(sample_hidden,sample_output)
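
With the illustrative shapes from the encoder sketch above, the attention outputs can be sanity-checked:

print(attention_result.shape)    # (64, 1024)  -- the context vector
print(attention_weights.shape)   # (64, 16, 1) -- one weight per encoder time step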
