Pytorch Merge操作

简述

Pytorch中没有内置Merge操作,需要手动实现。

下面会以多个四维的Tensor直接的Merge操作来展示。(len, *image_shape(占三维))

Pytorch Merge操作_第1张图片

比如上面,这样的图片。

每一行都是一个数据X_i(包含三张黑白图)。
所以,如果是直接用torch.cat([x0, x1, x2])的话,就是直接把这九张图按照顺序排列了下。

[x1[0],..., x1[n], x2[0],..., x2[n], x3[0],..., x3[n]]

图片对应的编号顺序:

123456789

但是如果有时候,需要把这个大图进行一个 转置 , 就需要用到了Merge的操作。

将(x1, x2, x3),三个变量merge操作。之后,就是

[x1[0], x2[0], x3[0], ..., x1[n], x2[n], x3[n]]

这样再输出的话,就是在大图上做了转置的效果。

因此需要做Merge。

Merge实现

先给范式:x1,x2,...xn的相互merge

torch.stack([x1,x2,..,xn]).transpose(1, 0).contiguous().view(len(x1+x2+..+xn), *x.shape[1:])

简单来说,就是

  1. 先用stack按照第一维度来进行叠加(是会扩充维度的)
  2. 之后,将扩充出来的维度和一开始的index维度(也就是x1[0], x1[2], ..., x1[index]),进行转置(注意的是,需要使用contiguous() 因为需要在物理层面上也要完成转置,之后才能view。这是pytorch的内部机制)
  3. 最后,再用view的方式将扩充好的维度压缩回去。

给个范例的输出:

Pytorch Merge操作_第2张图片

再给个实例上的代码区别的部分:

图一

	plt.imshow(np.transpose(
            vutils.make_grid(
                torch.cat([G_x.cpu().detach(), x.cpu().detach(), y.cpu().detach()]), nrow=3, padding=0,
                normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')

图二(转置后)

	plt.imshow(np.transpose(
            vutils.make_grid(
                torch.stack([G_x.cpu().detach(), x.cpu().detach(), y.cpu().detach()]).transpose(1, 0).contiguous().view(
                    BATCH_SIZE * 3, 1, 96, 96), nrow=3, padding=0,
                normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')

你可能感兴趣的:(Pytorch学习,Python)