atae_lstm:aspect-attention-aspect-embed
aspect = self.embed(aspect_indices) #transform aspect索引 to aspect向量
#从上图中可以看出,所有的aspect embedding的向量va对于不同的词向量wi都是一样的,即 只有一个va
aspect_pool = torch.div(torch.sum(aspect, dim=1), aspect_len.view(aspect_len.size(0), 1)) #view()的功能同unsqueeze()?
#torch.sum(aspect, dim=1) 应该是在维度1上加和,即列上加和
#torch.div 除法,有几个aspect就除以几,这一步就是为了求平均后的aspect vector
aspect = torch.unsqueeze(aspect_pool, dim=1).expand(-1, x_len_max, -1)
#增加一个维度便于输入NN
#expand:返回tensor的一个新视图,单个维度扩大为更大的尺寸,-1的意思大概是不改变维度?
这个代码是实现aspect-embedding的一部分
torch.div()是torch中的除法,
torch.div(a, b) ,a和b的尺寸是广播一致的,必须是类型一致的,就是如果a是FloatTensor那么b也必须是FloatTensor,可以使用tensor.to(torch.float64)进行转换。
torch.squeeze(tensor)
torch.squeeze()函数的作用是压缩一个tensor的维数为1的维度,使该tensor降维变成最紧凑的形式
In [4]: a
Out[4]:
tensor([[[0, 1, 2]],
[[3, 4, 5]],
[[6, 7, 8]]])
In [5]: a.size()
Out[5]: torch.Size([3, 1, 3])
In [6]: a.dim()
Out[6]: 3
In [7]: b = torch.squeeze(a)
In [8]: b
Out[8]:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
In [9]: b.size()
Out[9]: torch.Size([3, 3])
In [10]: b.dim()
Out[10]: 2
torch.unsqueeze(tensor, dim)
unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为**[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]**就很方便了。
In [17]: b
Out[17]:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
In [18]: b.size(), b.dim()
Out[18]: (torch.Size([3, 3]), 2)
In [20]: b_un = torch.unsqueeze(b, 0)
In [21]: b_un
Out[21]:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
In [22]: b_un.size(), b_un.dim()
Out[22]: (torch.Size([1, 3, 3]), 3)
https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#expandsizes
expand(*sizes)
返回tensor的一个新视图,单个维度扩大为更大的尺寸。 tensor也可以扩大为更高维,新增加的维度将附在前面。 扩大tensor不需要分配新内存,只是仅仅新建一个tensor的视图,其中通过将stride设为0,一维将会扩展位更高维。任何一个一维的在不分配新内存情况下可扩展为任意的数值。
参数: - sizes(torch.Size or int…)-需要扩展的大小
x = torch.Tensor([[1], [2], [3]])
print "x.size():",x.size()
y=x.expand( 3,4 )#使用expand()函数的时候,x自身不会改变,因此需要将结果重新赋值。
print "x.size():",x.size()
print "y.size():",y.size()
print x
print y