Pytorch:flatten()函数,压缩tensor的维度

前言

有时候会面对需要把数据进行维度转换的情况,
比如本来512*N*W*H(BNWH)的维度需要转换为512*(N*W*H)的一个output和(N*W*H)*512的一个output,然后将两者进行矩阵乘法。
(NHW)*512 X 512*(NWH) = (NHW)*(NHW)
然后再和初始的512*N*W*H进行矩阵乘法,结果仍旧是512*N*W*H,常用在一些non-local conv block中。

代码

import torch
import numpy as np

input = torch.randn(2,3,4,4)
# 将从第二个维度开始进行压缩
# 可以根据自己需要选择从哪里开始压缩
out = input.flatten(start_dim=1,end_dim=3) 
out.shape

得到:

torch.Size([2, 48])  # 3*4*4=48

你可能感兴趣的:(PyTorch,人工智能,深度学习,算法,python,pytorch)