pytorch 压缩拉平通道的方法

pytorch有两种方法可以压缩拉平通道,例如将 N*C*W*H 转化为 N*C*WH 。

1.view():元素总数不变改变形状

'''
view()是根据元素总数来改变tensor形状的,即变形后的tensor元素总数不变
x.size[0]是x的第一个维度batch_size,-1代表自动计算该维度(其他所有维度合并)
'''

x = x.view(x.size[0],-1)

2.flatten():将指定维度合并为一个维度

#tensor拉平发生的位置
#flatten的两种方式
#将第一维之后的维度合并
x = torch.flatten(x,1)
x = x.flatten(1)

#也可以指定中间维度合并
#其中start_dim为起始维度,end_dim为终止维度。flatten的功能为将start_dim到end_dim的维度合并为一个维度。
flatten(input,start_dim=0,end_dim=-1)

你可能感兴趣的:(python基础学习,pytorch,行人重识别,人工智能,图像处理)