Pytorch使用备忘

1. view()

g_x = self.g(x).view(batch_size, self.inter_channels, -1)

  view()是做维度调整, 就是按照参数表里的数字重新调整tensor的各维度上的大小, 最后那个-1的意思是把剩下的维度全都合并归到最后的一维
  比如示例中的g(x)算完之后是[128, 16, 7, 7] : (b, c, h, w), 示例中做的是把它变成[128, 16, 49], 就是把h和w两个维度合成一个

2. permute()

   permute()是对tensor的维度顺序进行调整, 作用跟transpose()很像, 但是permute()是针对tensor做的, 所以是tensor.permute(), 而transpose是numpy库里的, 所以是numpy.transpose(), 用法上差不多, 比如

#permute [b, c, hw] -> [b, hw, c]
g_x = g_x.permute(0, 2, 1)

#transpose [h, w, c] -> [c, h, w]
image_numpy = np.transpose(image_numpy, (1, 2, 0))

你可能感兴趣的:(Pytorch使用备忘)