Attention算法学习记录

博文目录:

0. Attention的模型解释

1. pytorch下attention算法的实现

2. 深度学习中的直觉与模型

 

0. Attention的模型解释

Attention Mechanism目前非常流行,广泛应用于机器翻译、语音识别、图像标注(Image Caption)等很多领域,之所以它这么受欢迎,是因为Attention给模型赋予了区分辨别的能力,例如,在机器翻译、语音识别应用中,为句子中的每个词赋予不同的权重,使神经网络模型的学习变得更加灵活(soft),同时Attention本身可以做为一种对齐关系,解释翻译输入/输出句子之间的对齐关系,解释模型到底学到了什么知识,为我们打开深度学习的黑箱,提供了一个窗口,如图1所示。

Attention算法学习记录_第1张图片

图1 NLP中的attention可视化

又比如在图像标注应用中,可以解释图片不同的区域对于输出Text序列的影响程度。

Attention算法学习记录_第2张图片

图2 图像标注中的attention可视化

通过上述Attention Mechanism在图像标注应用的case可以发现,Attention Mechanism与人类对外界事物的观察机制很类似,当人类观察外界事物的时候,一般不会把事物当成一个整体去看,往往倾向于根据需要选择性的去获取被观察事物的某些重要部分,比如我们看到一个人时,往往先Attention到这个人的脸,然后再把不同区域的信息组合起来,形成一个对被观察事物的整体印象。因此,Attention Mechanism可以帮助模型对输入的X每个部分赋予不同的权重,抽取出更加关键及重要的信息,使模型做出更加准确的判断,同时不会对模型的计算和存储带来更大的开销,这也是Attention Mechanism应用如此广泛的原因。

 

1. pytorch下attention算法的实现

import torch
import torch.nn as nn


class Attention(nn.Module):
    """ Applies attention mechanism on the `context` using the `query`.

    **Thank you** to IBM for their initial implementation of :class:`Attention`. Here is
    their `License
    `__.

    Args:
        dimensions (int): Dimensionality of the query and context.
        attention_type (str, optional): How to compute the attention score:

            * dot: :math:`score(H_j,q) = H_j^T q`
            * general: :math:`score(H_j, q) = H_j^T W_a q`

    Example:

         >>> attention = Attention(256)
         >>> query = torch.randn(5, 1, 256)
         >>> context = torch.randn(5, 5, 256)
         >>> output, weights = attention(query, context)
         >>> output.size()
         torch.Size([5, 1, 256])
         >>> weights.size()
         torch.Size([5, 1, 5])
    """

    def __init__(self, dimensions, attention_type='general'):
        super(Attention, self).__init__()

        if attention_type not in ['dot', 'general']:
            raise ValueError('Invalid attention type selected.')

        self.attention_type = attention_type
        if self.attention_type == 'general':
            self.linear_in = nn.Linear(dimensions, dimensions, bias=False)

        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

[docs]    def forward(self, query, context):
        """
        Args:
            query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
                queries to query the context.
            context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
                overwhich to apply the attention mechanism.

        Returns:
            :class:`tuple` with `output` and `weights`:
            * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]):
              Tensor containing the attended features.
            * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
              Tensor containing attention weights.
        """
        batch_size, output_len, dimensions = query.size()
        query_len = context.size(1)

        if self.attention_type == "general":
            query = query.reshape(batch_size * output_len, dimensions)
            query = self.linear_in(query)
            query = query.reshape(batch_size, output_len, dimensions)

        # TODO: Include mask on PADDING_INDEX?

        # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, query_len)
        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())

        # Compute weights across every context sequence
        attention_scores = attention_scores.view(batch_size * output_len, query_len)
        attention_weights = self.softmax(attention_scores)
        attention_weights = attention_weights.view(batch_size, output_len, query_len)

        # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, dimensions)
        mix = torch.bmm(attention_weights, context)

        # concat -> (batch_size * output_len, 2*dimensions)
        combined = torch.cat((mix, query), dim=2)
        combined = combined.view(batch_size * output_len, 2 * dimensions)

        # Apply linear_out on every 2nd dimension of concat
        # output -> (batch_size, output_len, dimensions)
        output = self.linear_out(combined).view(batch_size, output_len, dimensions)
        output = self.tanh(output)

        return output, attention_weights

2. 深度学习中的直觉与模型

  • 3 X 1 and 1 X 3 代替 3 X 3

 

  • LSTM中的门设计

 

  • 生成对抗网络

Attention机制的本质来自于人类视觉注意力机制。人们视觉在感知东西的时候一般不会是一个场景从到头看到尾每次全部都看,而往往是根据需求观察注意特定的一部分。而且当人们发现一个场景经常在某部分出现自己想观察的东西时,人们会进行学习在将来再出现类似场景时把注意力放到该部分上。:

将更多的注意力聚焦到有用的部分,Attention的本质就是加权。但值得注意的是,同一张图片,人在做不同任务的时候,注意力的权重分布应该是不同的。

 

基于以上的直觉,Attention可以用于:

  1. 学习权重分布:

    • 这个加权可以是保留所有分量均做加权(即soft attention);也可以是在分布中以某种采样策略选取部分分量(即hard attention),此时常用RL来做;
    • 这个加权可以作用在原图上,也可以作用在特征图上;
    • 这个加权可以在时间维度、空间维度、mapping维度以及feature维度。

    以seq2seq举例,传统的模型decoder输入的上下文c是一成不变的,但这显然不够合理,如果这是一个翻译模型,原文为“我喜欢食物”,当我翻译到"i like"时,我翻译下一个"food"时应该更关注的是"食物",而不是别的。

经过attention对上下文c进行加权,decoder每步输入的c都不一样,更加关注翻译当前单词所需的上下文:

 

效果展示:

 

任务聚焦、解耦(通过attention mask)

多任务模型,可以通过Attention对feature进行权重再分配,聚焦各自关键特征。在图像分割论文Fully Convolutional Network with Task Partitioning for Inshore Ship Detection in Optical Remote Sensing Images中:

针对靠岸舰船,本文通过任务解耦的方法来处理。因为高层特征表达能力强,分类更准,但定位不准;底层低位准,但分类不准。为了应对这一问题,本文利用一个深层网络得到一个粗糙的分割结果图(船头/船尾、船身、海洋和陆地分别是一类)即Attention Map;利用一个浅层网络得到船头/船尾预测图,位置比较准,但是有很多虚景。**训练中,使用Attention Map对浅层网络的loss进行引导,只反传在粗的船头/船尾位置上的loss,其他地方的loss不反传。**相当于,深层的网络能得到一个船头/船尾的大概位置,然后浅层网络只需要关注这些大概位置,然后预测出精细的位置,图像中的其他部分(如船身、海洋和陆地)都不关注,从而降低了学习的难度。 

 

你可能感兴趣的:(study)