Implementing attention mechanisms for neural machine translation (NMT) with Python + TensorFlow

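This post implements the two attention mechanisms most commonly used in machine translation:

(1) additive attention  (2) multiplicative attention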
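
As a rough guide to the code below, the two score functions can be written out as formulas. This is a sketch under a few assumptions: h is the decoder hidden state, e_1..e_T are the encoder outputs (stacked as the rows of E), W_1, W_2, W_3 and v stand for the kernels of the corresponding Dense layers, and the Dense bias terms are left out for brevity. The symbols h, e_j, E, s, \alpha and c are notation introduced here, not names from the code.

\[
\begin{aligned}
\text{additive:}\quad & s_j = v^{\top}\tanh\left(W_1^{\top} e_j + W_2^{\top} h\right) \\
\text{multiplicative:}\quad & s = W_3^{\top}(E h), \qquad (E h)_j = e_j^{\top} h \\
\text{both:}\quad & \alpha = \operatorname{softmax}(s), \qquad c = \sum_{j=1}^{T} \alpha_j e_j
\end{aligned}
\]

With W_3 equal to the identity matrix, the multiplicative score reduces to the plain dot product e_j^{\top} h.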

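import tensorflow as tf

def attention(hidden,enc_output,W1,W2,W3,V,att_select=None):
    '''
    input: decoder hidden state tensor, encoder output tensor for every time step,
           weight layers W1, W2, W3, V, and the attention type att_select
    'additive': needs three weight layers (W1, W2, V), 'multiplicate': needs only W3
    output: context vector (attention-weighted encoder outputs) [batch_size,1,hidden_size]
            and the attention weights [batch_size,text_len,1]
    '''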
    if att_select =='additive':
        # hidden: [batch_size,hidden_size] ----> [batch_size,1,hidden_size]
        hidden_with_time_axis= tf.expand_dims(hidden,axis=1)
        # temp: [batch_size, text_len, hidden_size]
        temp= tf.nn.tanh(W1(enc_output) + W2(hidden_with_time_axis))
        # score: [batch_size, text_len, 1]
        score= V(temp)
        # softmax over the time axis, so the weights sum to 1 across encoder steps
        attention_weights= tf.nn.softmax(score,axis=1)
        # weight the encoder outputs and sum over the time axis:
        # [batch_size, text_len, hidden_size] ---> [batch_size, hidden_size]
        context_vector= attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector,axis=1)
        # context_vector: [batch_size, 1, hidden_size]
        context_vector = tf.expand_dims(context_vector,axis=1)
        
    elif att_select == 'multiplicate':
        # hidden: [batch_size,hidden_size] ----> [batch_size,hidden_size,1]
        hidden_with_time_axis = tf.expand_dims(hidden, axis =2)
        # enc_output: [batch_size, text_len, hidden_size]
        #   ----> transposed: [batch_size, hidden_size, text_len]
        # W3 acts on the last (time) axis, so temp keeps the transposed shape
        temp = W3(tf.transpose(enc_output,perm=[0,2,1]))

        # elementwise product, then transpose: [batch_size, text_len, hidden_size]
        score = temp * hidden_with_time_axis
        score = tf.transpose(score,perm=[0,2,1])
        # sum over the hidden axis: score = [batch_size, text_len, 1]
        score = tf.expand_dims(tf.reduce_sum(score,axis=2),axis=2)

        attention_weights = tf.nn.softmax(score,axis =1)
        # same as in the additive branch
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector,axis=1)
        # context_vector: [batch_size, 1, hidden_size], matching the additive branch
        context_vector = tf.expand_dims(context_vector, axis=1)

    else:
        raise ValueError("att_select must be 'additive' or 'multiplicate'")
        
    return context_vector,attention_weights

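A note on the function's inputs:

hidden  shape = [batch_size, hidden_size]   (the function adds the time axis itself)

enc_output   shape = [batch_size, text_len, hidden_size]

W1 = [hidden_size, hidden_size]   --------- implemented as tf.keras.layers.Dense(hidden_size)

W2 = same shape as W1

W3 = [text_len, text_len]   --------- a Dense(text_len) layer; it acts on the time axis, so text_len must be fixed

V = [hidden_size, 1]   --------- a Dense(1) layer

att_select = 'additive' or 'multiplicate'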
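
To make the shapes above concrete, here is a minimal NumPy sketch that mirrors both branches without TensorFlow. It is an illustration under assumptions, not the post's implementation: plain random matrices stand in for the Dense kernels (biases omitted), the sizes 2, 5 and 4 are arbitrary, and the helper names softmax, additive and multiplicative are hypothetical. Both helpers return the context vector as [batch_size, 1, hidden_size], as the docstring states.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Arbitrary illustrative sizes (assumptions, not values from the post).
batch_size, text_len, hidden_size = 2, 5, 4
rng = np.random.default_rng(0)

hidden = rng.normal(size=(batch_size, hidden_size))
enc_output = rng.normal(size=(batch_size, text_len, hidden_size))

# Random matrices standing in for the Dense kernels (biases omitted).
W1 = rng.normal(size=(hidden_size, hidden_size))
W2 = rng.normal(size=(hidden_size, hidden_size))
V = rng.normal(size=(hidden_size, 1))
W3 = rng.normal(size=(text_len, text_len))

def additive(hidden, enc_output):
    h = hidden[:, None, :]                                # [B, 1, H]
    score = np.tanh(enc_output @ W1 + h @ W2) @ V         # [B, T, 1]
    weights = softmax(score, axis=1)                      # sums to 1 over T
    context = (weights * enc_output).sum(axis=1, keepdims=True)  # [B, 1, H]
    return context, weights

def multiplicative(hidden, enc_output):
    temp = enc_output.transpose(0, 2, 1) @ W3             # [B, H, T]
    # Multiply by the hidden state and sum over the hidden axis.
    score = (temp * hidden[:, :, None]).sum(axis=1)[:, :, None]  # [B, T, 1]
    weights = softmax(score, axis=1)
    context = (weights * enc_output).sum(axis=1, keepdims=True)  # [B, 1, H]
    return context, weights

for fn in (additive, multiplicative):
    context, weights = fn(hidden, enc_output)
    print(fn.__name__, context.shape, weights.shape)
```

In both cases the weights have one entry per encoder time step and sum to 1 along that axis, which is what makes the context vector a weighted average of the encoder outputs.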
