[Adding/Removing Dimensions] Understanding squeeze and unsqueeze in numpy and torch

Table of Contents

  • 1 Why add or remove dimensions
  • 2 The squeeze function in numpy
  • 3 The squeeze function in torch
  • 4 The unsqueeze function in torch

1 Why add or remove dimensions

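A neural network's conv2d layer conventionally takes a four-dimensional input (batch, channel, height, width). Pre-processing and post-processing steps therefore often have to add or remove dimensions so that the shapes match.
squeeze removes a dimension and unsqueeze adds one. In both cases the dimension being added or removed must have size 1 (a singleton dimension).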
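A minimal sketch of that pre/post-processing round trip follows. It assumes a hypothetical single 3×32×32 image and an arbitrary 8-channel `nn.Conv2d`; the sizes are illustrative only. Newer PyTorch releases also accept unbatched 3-D input for `nn.Conv2d`. Adding the batch dimension explicitly, as shown here, is the conventional pattern.

```python
import torch
import torch.nn as nn

# Hypothetical input: one RGB image of 32x32 pixels with no batch dimension
image = torch.rand(3, 32, 32)                 # (channel, height, width)

# Arbitrary example layer: 3 input channels -> 8 output channels, same spatial size
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)

# Pre-processing: insert a size-1 batch dimension at position 0
batch = image.unsqueeze(0)                    # (1, 3, 32, 32)
features = conv(batch)                        # (1, 8, 32, 32)

# Post-processing: remove the size-1 batch dimension again
result = features.squeeze(0)                  # (8, 32, 32)

print(batch.shape, features.shape, result.shape)
```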

numpy has squeeze but no unsqueeze. To add a dimension in numpy, use np.expand_dims(x, axis).
torch tensors provide both functions.
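Here is a short sketch of the numpy counterparts to unsqueeze, assuming an arbitrary 3×5 array. `np.expand_dims` inserts a length-1 axis at the given position. Indexing with `None` (alias `np.newaxis`) produces the same result.

```python
import numpy as np

x = np.zeros((3, 5))

# numpy has no unsqueeze; np.expand_dims inserts a length-1 axis instead
a = np.expand_dims(x, axis=0)    # (1, 3, 5)
b = np.expand_dims(x, axis=-1)   # (3, 5, 1)

# Indexing with np.newaxis / None does the same thing
c = x[np.newaxis, :, :]          # (1, 3, 5)
d = x[:, :, None]                # (3, 5, 1)

print(a.shape, b.shape, c.shape, d.shape)
```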

2 The squeeze function in numpy

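Explanation:
Removes single-dimensional entries from the shape of an array. In other words, it drops the axes whose length is 1, which reduces the number of dimensions.

Usage:

arr_1 = np.squeeze(arr, axis=None)

arr is the input array.
axis may be None (the default), an int, or a tuple of ints. None removes every axis of length 1. An int such as 0 removes only that axis, which must have length 1; otherwise numpy raises a ValueError.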
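The following sketch shows each form of `axis` on an arbitrary array of shape (1, 3, 1, 5). It includes the error case, where the selected axis does not have length 1.

```python
import numpy as np

arr = np.zeros((1, 3, 1, 5))

print(np.squeeze(arr).shape)               # (3, 5): every length-1 axis removed
print(np.squeeze(arr, axis=0).shape)       # (3, 1, 5): only axis 0 removed
print(np.squeeze(arr, axis=2).shape)       # (1, 3, 5): only axis 2 removed
print(np.squeeze(arr, axis=(0, 2)).shape)  # (3, 5): a tuple removes several axes

# Axis 1 has length 3, so numpy refuses to squeeze it
raised = False
try:
    np.squeeze(arr, axis=1)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```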

Example:

import numpy as np

arr = np.array([[[1,2,3],[4,5,6]]])   # shape (1, 2, 3)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

arr_1 = np.squeeze(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = np.squeeze(arr, axis=None)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')

Output:

<class 'numpy.ndarray'>
[[[1 2 3]
  [4 5 6]]]
(1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
 [4 5 6]]
(2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
 [4 5 6]]
(2, 3)
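In the example above the array has only one length-1 axis, so `axis=0` and `axis=None` give the same result. The sketch below adds one more pair of brackets so that the array has two leading length-1 axes. With that array the two calls behave differently. The name `arr4` is illustrative.

```python
import numpy as np

# Four pairs of brackets -> shape (1, 1, 2, 3): two leading length-1 axes
arr4 = np.array([[[[1, 2, 3], [4, 5, 6]]]])

one_axis = np.squeeze(arr4, axis=0)   # removes only the first axis -> (1, 2, 3)
all_axes = np.squeeze(arr4)           # removes every length-1 axis -> (2, 3)

print(arr4.shape, one_axis.shape, all_axes.shape)
```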

3 The squeeze function in torch

Example:

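import torch

# torch.Tensor(1, 3, 1, 5) allocates an uninitialized tensor of shape (1, 3, 1, 5),
# so the printed values are arbitrary leftovers from memory
arr = torch.Tensor(1, 3, 1, 5)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

# The argument selects the dimension to squeeze; only a dimension of size 1 is removed
arr_1 = arr.squeeze(0)        # dim 0 (the first dimension) has size 1, so it is removed
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = arr.squeeze(1)        # dim 1 has size 3: no error, the tensor is returned unchanged
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
print("==========================")

arr_3 = arr.squeeze(2)        # dim 2 (the third dimension) has size 1, so it is removed
print(type(arr_3), arr_3, arr_3.shape, sep='\n')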

Output (the tensor is uninitialized, so the values differ from run to run):

<class 'torch.Tensor'>
tensor([[[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04]],

         [[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24]],

         [[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]]])
torch.Size([1, 3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04]],

        [[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24]],

        [[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]])
torch.Size([3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04]],

         [[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24]],

         [[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]]])
torch.Size([1, 3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04],
         [1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24],
         [4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]])
torch.Size([1, 3, 5])
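Below is a minimal sketch of a few more `squeeze` cases. It uses `torch.zeros` instead of `torch.Tensor(...)` so the contents are deterministic. The sketch shows the no-argument form, a negative dim, and the fact that torch silently ignores a dim whose size is not 1 (numpy raises an error in that case). It also shows a common pitfall: calling `squeeze()` with no argument removes a batch dimension of size 1 as well. The shapes and the `single` name are illustrative assumptions.

```python
import torch

x = torch.zeros(1, 3, 1, 5)

print(x.squeeze().shape)     # no argument: every size-1 dim removed -> (3, 5)
print(x.squeeze(-2).shape)   # negative dim counts from the end; -2 is dim 2 -> (1, 3, 5)
print(x.squeeze(1).shape)    # dim 1 has size 3: no error, shape unchanged -> (1, 3, 1, 5)

# Pitfall: a batch of one sample with one channel
single = torch.zeros(1, 1, 5)
print(single.squeeze().shape)    # (5,): the batch dim is lost as well
print(single.squeeze(1).shape)   # (1, 5): only the channel dim is removed
```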

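4 The unsqueeze function in torch

Explanation:
unsqueeze(dim) inserts a new dimension of size 1 at position dim of the tensor's shape. For a tensor with n dimensions, dim may range from -(n+1) to n, and negative values count from the end.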
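The sketch below enumerates every valid `dim` for an arbitrary 2-D tensor. It shows that each negative value lands in the same place as its positive counterpart (a negative dim `d` is treated as `d + n + 1`). It then tries a `dim` that is out of range.

```python
import torch

x = torch.zeros(3, 5)   # 2-D tensor, so valid dims for unsqueeze are -3 .. 2

shapes = {dim: tuple(x.unsqueeze(dim).shape) for dim in range(-3, 3)}
for dim, shape in shapes.items():
    print(dim, shape)
# -3 (1, 3, 5)
# -2 (3, 1, 5)
# -1 (3, 5, 1)
# 0 (1, 3, 5)
# 1 (3, 1, 5)
# 2 (3, 5, 1)

# dim = 3 is outside the valid range
raised = False
try:
    x.unsqueeze(3)
except IndexError as err:
    raised = True
    print("IndexError:", err)
```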

Example:

import torch

# Uninitialized 3x5 tensor; the printed values are arbitrary
arr = torch.Tensor(3, 5)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

# arr is 2-D; adding one dimension makes it 3-D, and dim = 0, 1, or 2 picks where the new dimension goes
arr_1 = arr.unsqueeze(0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = arr.unsqueeze(1)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
print("==========================")

arr_3 = arr.unsqueeze(2)        # dim = 3 or larger raises an IndexError (-1 to -3 are also valid)
print(type(arr_3), arr_3, arr_3.shape, sep='\n')

Output (the tensor is uninitialized, so the values differ from run to run):

<class 'torch.Tensor'>
tensor([[3.2483e+33, 1.9690e-19, 6.8589e+22, 1.3340e+31, 1.1708e-19],
        [7.2128e+22, 9.2216e+29, 7.5546e+31, 1.6932e+22, 3.0728e+32],
        [2.9514e+29, 2.8940e+12, 7.5338e+28, 1.8037e+28, 3.4740e-12]])
torch.Size([3, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e+33, 1.9690e-19, 6.8589e+22, 1.3340e+31, 1.1708e-19],
         [7.2128e+22, 9.2216e+29, 7.5546e+31, 1.6932e+22, 3.0728e+32],
         [2.9514e+29, 2.8940e+12, 7.5338e+28, 1.8037e+28, 3.4740e-12]]])
torch.Size([1, 3, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e+33, 1.9690e-19, 6.8589e+22, 1.3340e+31, 1.1708e-19]],

        [[7.2128e+22, 9.2216e+29, 7.5546e+31, 1.6932e+22, 3.0728e+32]],

        [[2.9514e+29, 2.8940e+12, 7.5338e+28, 1.8037e+28, 3.4740e-12]]])
torch.Size([3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e+33],
         [1.9690e-19],
         [6.8589e+22],
         [1.3340e+31],
         [1.1708e-19]],

        [[7.2128e+22],
         [9.2216e+29],
         [7.5546e+31],
         [1.6932e+22],
         [3.0728e+32]],

        [[2.9514e+29],
         [2.8940e+12],
         [7.5338e+28],
         [1.8037e+28],
         [3.4740e-12]]])
torch.Size([3, 5, 1])
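A brief sketch of two related facts, assuming a small deterministic tensor built with `torch.arange`. First, indexing with `None` inserts a size-1 dimension just as `unsqueeze` does, and `squeeze(d)` undoes `unsqueeze(d)`. Second, `unsqueeze` returns a view that shares memory with the original tensor, so writing through the result also changes the original.

```python
import torch

x = torch.arange(15.0).reshape(3, 5)

# Indexing with None inserts a size-1 dim, the same as unsqueeze
same_0 = torch.equal(x[None], x.unsqueeze(0))        # (1, 3, 5)
same_1 = torch.equal(x[:, None], x.unsqueeze(1))     # (3, 1, 5)
same_2 = torch.equal(x[..., None], x.unsqueeze(2))   # (3, 5, 1)

# squeeze(d) undoes unsqueeze(d)
round_trip = torch.equal(x.unsqueeze(1).squeeze(1), x)

# unsqueeze returns a view: writing through y also changes x
y = x.unsqueeze(0)
y[0, 0, 0] = 100.0
print(x[0, 0])
```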
