This week, I read a paper on attention mechanism, which mentioned that the existing methods did not explicitly consider the impact of users’ current behavior on their next action. Therefore, this paper proposes a new short-term attention priority model as a remedy, which can capture the user’s general interest from the long-term memory of the session context, while taking into account the current interest in the short-term memory recently clicked by the user. Finally, the experiment results on the dataset prove the effectiveness of this new short-term attention priority model. I continue to learn about the content of attention. By using code to implement the attention mechanism of different attention scoring functions, I understood the attention mechanism from the perspective of code.
本周,我阅读了一篇关于注意力机制相关的论文,论文中提到了现有的方法都没有明确考虑用户当前行为对其下一步行动的影响。于是,论文中提出了一种新的短期注意力优先级模型作为补救措施,该模型能够从会话上下文的长期记忆中捕获用户的一般兴趣,同时考虑到用户最近点击的短期记忆中的当前兴趣。最后,论文通过在数据集上做实验,实验结果证明了这种新的短期注意力优先级模型的有效性。我继续学习了attention的相关内容,通过用代码去实现不同注意力打分函数的注意力机制,从代码的角度去理解attention机制。
文献链接:STAMP: Short-Term Attention/Memory Priority Model for Session-based Recommendation
Predicting users’ actions based on anonymous sessions is a challenging problem in web-based behavioral modeling research, mainly due to the uncertainty of user behavior and the limited information. Recent advances in recurrent neural networks have led to promising approaches to solving this problem, with long short-term memory model proving effective in capturing users’ general interests from previous clicks. However, none of the existing approaches explicitly take the effects of users’ current actions on their next moves into account. In this study, we argue that a long-term memory model may be insufficient for modeling long sessions that usually contain user interests drift caused by unintended clicks. A novel short-term attention/memory priority model is proposed as a remedy, which is capable of capturing users’ general interests from the long-term memory of a session context, whilst taking into account users’ current interests from the short-term memory of the last-clicks. The validity and efficacy of the proposed attention mechanism is extensively evaluated on three benchmark data sets from the RecSys Challenge 2015 and CIKM Cup 2016. The numerical results show that our model achieves state-of-the-art performance in all the tests.
问题:几乎所有基于rnn的SRS模型都只考虑将会话建模为一个项目序列,而没有明确考虑到用户兴趣随时间的迁移而产生的变化。
方案:论文中考虑通过在SRS模型中引入一个近期的动作优先机制,即短时注意/记忆优先(STAMP)模型来解决这个问题,该模型可以同时考虑用户的一般兴趣和当前兴趣。
在典型的SRS任务中,会话由一系列命名项组成,用户的兴趣隐藏在这些隐式的反馈中。为了进一步提高RNN模型的预测精度,必须同时具备学习这种隐式反馈的长期利益和短期利益的能力。
论文研究的主要贡献:
1)论文中引入了一个短期注意力/记忆优先级模型:一个包含跨会话项的统一嵌入空间;一个用于基于会话的推荐系统中下一次点击预测的新神经注意力模型。
2)论文针对STAMP模型的实现,提出了一种新的注意力机制,该机制根据会话上下文计算注意权重,并根据用户当前的兴趣进行增强。输出的注意力向量被解读为用户时间兴趣的合成表示,并且比其他基于神经注意力的解决方案更敏感于用户兴趣随时间的迁移。因此,它能够同时捕捉用户的长期兴趣(响应最初的目的)和短期注意(当前的兴趣)。
3)论文中模型分别在两个数据集上进行了评估,分别是来自RecSys 2015的Yoochoose数据集和来自CIKM Cup 2016的Diginetica数据集。实验结果表明,该方法达到了目前的水平,所提出的注意力机制发挥了重要作用。
每一个session由S = [s1, s2, … , sN]表示,St = {s1, s2, … , st}表示一个截断的序列,其中1 < t < N。
V = {v1, v2, . . . }是指所有的item,X = {x1, x2, . . . } 是item的embedding。
其中:yˆ = {yˆ1, yˆ2, … , yˆ|V |}表示输出的score向量,yˆi 对应于item vi的分数,topk用来预测。
从上图中可以看出,STMP模型以两个embeddings (ms和mt)作为输入,其中ms表示用户对当前会话的总体兴趣,被定义为产生过交互物品嵌入的平均表示:
其中:external memory是指嵌入当前会话St的以前的项序列,mt表示该会话中用户当前的兴趣。
论文中使用lastclick xt表示用户当前的兴趣,即mt =xt;由于xt是从会话的external memory中提取的,因此xt为用户兴趣的短期内存;然后利用两个MLP网络对一般兴趣ms和当前兴趣mt进行处理,实现特征提取。
使用一个简单的没有隐藏层的MLP进行特征抽象,对ms的操作定义为:
对于给定的候选项目xi∈V,得分函数定义为:
最后,使用交叉熵函数计算loss:
从上图中可以看出,两个模型之间唯一区别:在STMP模型中,通常是从外部存储器ms的平均值去计算用户兴趣的抽象特征向量hs,而在STAMP模型中,hs 是从基于注意力的用户的一般兴趣(实值向量ma)计算出来的。
注意力网络由两部分组成:
1)一个简单的前馈神经网络,负责为当前会话前缀St中的每个项目生成注意力权重。
2)一个注意力组合通常负责计算基于注意力的用户兴趣的函数ma。
用于注意力计算的FNN定义为:
在获得关于当前会话前缀St的注意力系数向量α=(α1, α2, …, αt ) 后,基于注意力的用户对当前会话前缀St的兴趣ma计算如下:
为了验证本文提出的短期注意力优先模型的有效性,论文中还提出了一种只有短期注意的网络模型,该模型对于next-item的预测只给予最终产生过交互的物品嵌入st,并且也只使用了一层的简单MLP进行特征提取:
最终的评分表示为:
P@K分数被广泛用于SRS领域预测准确性的度量,P@K表示在排名列表的前K位具有正确推荐项目的测试用例的比例。P@20可用于所有测试,定义为:
MRR@20:所需项的倒数秩的平均值,如果秩大于20,则倒数秩为零。
MRR是一个标准化的范围分数[0, 1],其值的增加反映了大多数在推荐列表的排名中会出现更高的排名,这表明相应的推荐系统性能更好。
论文所提出的注意力机制可以捕获部分重要项目,以对感兴趣的有用特征进行建模,并通过最新的实验结果证明了STAMP的有效性。
1)用户的下一步移动主要受会话前缀的最后一次行为的影响,论文中的模型可以通过时间兴趣展示来有效地利用这些信息。
2)论文提出的注意力机制可以有效地捕获会话的长期和短期兴趣信息,通过实验结果证明了在注意力机制的帮助下,论文中的模型在所有数据集上都达到了最先进的性能。
import torch
import torch.nn as nn
import torch.nn.functional as F
class add_attention(nn.Module):
def __init__(self, q_size, k_size, v_size, seq_len):
super(add_attention, self).__init__()
self.linear_v = nn.Linear(v_size, seq_len)
self.linear_W = nn.Linear(k_size, k_size)
self.linear_U = nn.Linear(q_size, q_size)
self.tanh = nn.Tanh()
def forward(self, query, key, value, dropout=None):
key = self.linear_W(key)
query = self.linear_U(query)
k_q = self.tanh(query + key)
alpha = self.linear_v(k_q)
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention = add_attention(100, 100, 100, 10)
q = k = v = torch.randn((8, 10, 100))
out, attn = attention(q, k, v)
print(out.shape)
print(attn.shape)
import torch
import torch.nn as nn
import torch.nn.functional as F
class dot_attention(nn.Module):
def __init__(self):
super(dot_attention, self).__init__()
def forward(self, query, key, value, dropout=None):
alpha = torch.bmm(query, key.transpose(-1, -2))
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention = dot_attention()
q = k = v = torch.randn((8, 10, 100))
out, attn = attention(q, k, v)
print(out.shape)
print(attn.shape)
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
class s_dot_attention(nn.Module):
def __init__(self):
super(s_dot_attention, self).__init__()
def forward(self, query, key, value, dropout=None):
d = k.size(-1)
alpha = torch.bmm(query, key.transpose(-1, -2)) / math.sqrt(d)
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention = s_dot_attention()
q = k = v = torch.randn((8, 10, 100))
out, attn = attention(q, k, v)
print(out.shape)
print(attn.shape)
import torch
import torch.nn as nn
import torch.nn.functional as F
class bilinear_attention(nn.Module):
def __init__(self, x_size):
super(bilinear_attention, self).__init__()
self.linear_W = nn.Linear(x_size, x_size)
def forward(self, query, key, value, dropout=None):
alpha = torch.bmm(query, self.linear_W(key).transpose(-1, -2))
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention = bilinear_attention(100)
q = k = v = torch.randn((8, 10, 100))
out, attn = attention(q, k, v)
print(out.shape)
print(attn.shape)
我认为引入Attention机制的原因主要有三点,一是参数少,模型复杂度跟之前学习过的CNN和RNN相比更小,参数也更少,对算力的要求也就更小;二是速度快,Attention解决了RNN不能并行计算的问题,并且attention机制每一步计算不依赖于上一步的计算结果,因此可以和CNN一样并行处理;三是效果好,Attention的作用是在比较长的文本中,依然能从中抓住重点,不丢失重要的信息。通过这段时间的学习,我依然还没有搞懂Attention机制的原理,下周我将继续学习Attention机制的相关内容,继续拓展一些内容,加深自己的理解。