multi-head attention理解加代码

(multi-head attention 用于CNN相关理解)

饭前小菜

在早期的Machine Translation(机器翻译)中,Attention机制与RNN的结合。机器翻译解决的是输入是一串在某种语言中的一句话,输出是目标语言相对应的话的问题,如将中文翻译成英文。通常的配置是encoder-decoder结构,即encoder读入输入的句子并将其转换为一个固定长度的向量,然后decoder再将这个向量翻译成对应的目标语言的文字。
存在的问题:RNN机制实际上存在长梯度消失的问题,对于较长的句子,我们很难寄希望于将输入的序列转化为定长的向量而保存所有的有效信息,所有随着翻译句子的长度的增加,这种结构的效果会显著下降。
解决办法:那当然就是我们的attention啦!
multi-head attention 是继self-attention之后又一重大研究成果,其出发点是在transformer模型上,改进之前使用的传统attention。本人是将multi-head attention 用于CNN模型当中,踩了不少坑,但是复现代码的人确实是大牛。相关参考参考代码

传统attention

举个例子:
翻译’knowledge’时,只需要将注意力放在源句子中“知识”的部分,当翻译“power”时,只需要将注意力集中在“力量”。这样,当我们的decoder预测目标翻译的时候就可以看到encoder的所有信息,而不仅局限于原来模型中定长的隐向量,并且不会丧失较长的信息
multi-head attention理解加代码_第1张图片

transformer中的attention

multi-head attention理解加代码_第2张图片

问题来了

当然,既然attention机制如此的有效,那可不可以去掉模型中的RNN的部分,仅仅利用attention呢?答案是当然可以啦!

// attention
 def _attention(self,inputs, attention_size, time_major=False, return_alphas=False):
        if isinstance(inputs, tuple):
        # In case of Bi-RNN, concatenate the forward and the backward RNN outputs.
            inputs = tf.concat(inputs, 2)

        if time_major:
        # (T,B,D) => (B,T,D)
            inputs = tf.array_ops.transpose(inputs, [1, 0, 2])

        hidden_size = inputs.shape[2].value  # D value - hidden size of the RNN layer

        # Trainable parameters
        W_omega = tf.Variable(tf.random_normal([hidden_size, attention_size], stddev=0.1))
        b_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1))
        u_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1))

        # Applying fully connected layer with non-linear activation to each of the B*T timestamps;
        #  the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size
        #v = tf.tanh(tf.tensordot(inputs, W_omega, axes=1) + b_omega)
        v = tf.sigmoid(tf.tensordot(inputs, W_omega, axes=1) + b_omega)
        # For each of the timestamps its vector of size A from `v` is reduced with `u` vector
        vu = tf.tensordot(v, u_omega, axes=1)   # (B,T) shape
        alphas = tf.nn.softmax(vu)              # (B,T) shape also
        
        # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape
        output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)

        if not return_alphas:
            return output
        else:
            return output, alphas

####self-attention
这里我就直接把做好的PPT粘过来
multi-head attention理解加代码_第3张图片
multi-head attention理解加代码_第4张图片
#~
multi-head attention理解加代码_第5张图片

multi-head attrntion

multi-head attention理解加代码_第6张图片
multi-head attention理解加代码_第7张图片
在得到多个Z向量后,最后一步就是将多个Z需要映射成我们之前的大小 :we need a way to condense these eight down into a single matrix
multi-head attention理解加代码_第8张图片

下面就来看看如何用代码来实现吧!


def multihead_attention(query_antecedent,# a Tensor with shape [batch, length_q, channels]
                        total_key_depth,# an integer
                        total_value_depth,# an integer多个V矩阵的最后一个维度,该设置和attention的输入数据维度有关系,不能随意设置,虽然对程序运行没有什么影响,但是对性能有影响
                        output_depth,# an integer 是将多个Z向量映射成单个向量的W矩阵的最后一个维度
                        num_heads,# an integer dividing total_key_depth and total_value_depth
                        memory_antecedent=None,
                        attention_type="dot_product",
                        q_filter_width=7,#该参数应该重点理解,原来是1,在计算q,k,v的时候,若用一维卷积计算,则代表卷积核的大小,可以按照需求设置
                        kv_filter_width=7,#同上
                        q_padding="SAME",
                        kv_padding="SAME",
                        vars_3d=False,
                        training=True):
    q, k, v = compute_qkv(query_antecedent, memory_antecedent,
                        total_key_depth, total_value_depth, q_filter_width,
                        kv_filter_width, q_padding, kv_padding,)
    q = split_heads(q, num_heads)
    k = split_heads(k, num_heads)
    v = split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
        q *= key_depth_per_head ** -0.5
    if attention_type == "dot_product":
        x = dot_product_attention(q, k, v)
    x = combine_heads(x)
    x = tf.layers.dense(x, output_depth, use_bias=False, name="output_transform" )
    return x
    

所需要调用的模块

def compute_qkv(query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width=None,
                kv_filter_width=None,
                q_padding="VALID",
                kv_padding="VALID"):
    if memory_antecedent is None:
        memory_antecedent = query_antecedent
    q = compute_attention_component(
          query_antecedent,
          total_key_depth,
          q_filter_width,
          q_padding,
          "q")
    k = compute_attention_component(
          memory_antecedent,
          total_key_depth,
          kv_filter_width,
          kv_padding,
          "k",)
    v = compute_attention_component(
          memory_antecedent,
          total_value_depth,
          kv_filter_width,
          kv_padding,
          "v")
    return q, k, v
#
def compute_attention_component(antecedent,
                                total_depth,
                                filter_width=None,
                                padding="SAME",
                                name="c"):
    if filter_width == 1:
        output = tf.layers.dense(antecedent, units=total_depth, use_bias=True, name=name)
    else:
        output = tf.layers.conv1d(antecedent, filters=total_depth, kernel_size=filter_width, strides=2, padding=padding, name=name)
    return output

def dot_product_attention(q,k,v):
    logits = tf.reduce_sum(tf.matmul(q, k, transpose_b= True), axis=3)#原本是没有求和的,这里做了个求和,所以在下面用了tf.expand_dims
    weights = tf.nn.softmax(logits, name="attention_weights")
    output =v * tf.expand_dims(weights, -1)
    return output

小菜鸟踩坑记~~~
第一次博客打卡,哈哈
: 有关attention的使用,之前一直有个疑惑,那就是attention是否只能用在LSTM或者RNN后,对于这个疑问,我之前一直搞不清楚,直到我一步步跟踪数据,才知道attention也是可以脱离RNN和LSTM,所以我将multi-head用在CNN上了,由于最近CNN超级受欢迎,不仅在于它的参数相对于RNN和LSTM少了很多,而且效果也是很好的,代码中有些参数说明不是很清楚,经过大量实验,我暂且用钟式大白话注释~~haha

你可能感兴趣的:(神经网络)