【小技巧】关于pytorch中拥有batch的矩阵的相加

应用场景

将两个batch_size = 100的矩阵相加时,如果维度不完全相同,例如 100 × 512 100 \times 512 100×512 100 × 196 × 512 100 \times 196 \times 512 100×196×512大小的矩阵相加时,因为pytorch不能够直接进行broadcasting,所以如果直接a + b相加的话会直接报错(如果没有batch_size的话则可以正常进行broadcasting,例如 1 × 512 1 \times 512 1×512 1 × 196 × 512 1 \times 196 \times 512 1×196×512的矩阵相加则不会报错)。在PyTorch中,通常可以直接用for循环来直接手动进行相加来对Tensor进行赋值,但是如果是在模型中的forward函数里面这么干的话,可能会直接破坏掉graph的创建,导致后面模型无法进行反向传播。

这时候就应该使用PyTorch提供的view来对矩阵进行变形(其实这个函数就相当于是numpy中的reshape,只是名字不同而已)。还是上面的例子,这次要先将 100 × 512 100 \times 512 100×512的矩阵转换成 100 × 1 × 512 100 \times 1 \times 512 100×1×512,然后再将它和 100 × 512 × 196 100 \times 512 \times 196 100×512×196的矩阵进行相加:

# 假设t_q和t_v是要相加的两个矩阵
t_q = torch.ones([100, 512])
t_v = torch.ones([100, 196, 512])

t_q = t_q.view(-1, 1, 512)
# 上面的-1是指新的维度,通常用来指示batch_size放在哪
print(t_q.shape)	# torch.Size([100, 1, 512])
t_com = t_q + t_v
print(t_com)

这样PyTorch就可以正常地进行broadcasting了。上面的view操作其实是告诉PyTorch,我们要将100个 1 × 512 1 \times 512 1×512的矩阵和100个 196 × 512 196 \times 512 196×512的矩阵相加,因为PyTorch会默认将size中第一个数字(即idx=0位置的数字)当作batch_size,所以如果是直接将 100 × 512 100 \times 512 100×512矩阵和 100 × 196 × 512 100 \times 196 \times 512 100×196×512矩阵相加的话PyTorch会认为你正在将 1 × 100 × 512 1 \times 100 \times 512 1×100×512的矩阵与 100 × 196 × 512 100 \times 196 \times 512 100×196×512的矩阵进行相加,因此会报错。

备注

在构建模型的时候,一定不能在计算过程中直接对Tensor进行赋值,否则会导致autograph机制失效,导致模型无法通过后向传播更新参数。

你可能感兴趣的:(pytorch,debug)