目录
一、sequeeze()函数
二、unsequeeze()函数
三、size()函数
四、torch.cat函数
在Pytorch做深度学习过程中,CNN的卷积和池化过程中会用到torch.cat、squeeze()、unsqueeze()和size()函数,下面分别做讲解:
x.squeeze(dim)
用途:进行维度压缩,去掉tensor中维数为1的维度
参数设置:如果设置dim=a,就是去掉指定维度中维数为1的
示例:
import torch
x = torch.tensor([[[1],[2]],[[3],[4]]])
print('x:',x)
print(x.shape)
x1 = x.squeeze()
print('x1:',x1)
print(x1.shape)
x2 = x.squeeze(2)
print('x2:',x2)
print(x2.shape)
输出:
x: tensor([[[1],
[2]],
[[3],
[4]]])
torch.Size([2, 2, 1])
x1: tensor([[1, 2],
[3, 4]])
torch.Size([2, 2])
x2: tensor([[1, 2],
[3, 4]])
torch.Size([2, 2, 1])
可以看出:
(1) x.squeeze(),shape 由(2,2,1)变为(2,2),说明维度为1时被去掉。
(2) x.squeeze(2),shape仍然为(2,2,1),这是因为只有维度为1时才会去掉。
x.unsqueeze(dim=a)
用途:进行维度扩充,在指定位置加上维数为1的维度
参数设置:如果设置dim=a,就是在维度为a的位置进行扩充
示例:
import torch
x = torch.tensor([1,2,3,4])
print(x)
print(x.shape)
x1 = x.unsqueeze(0)
print(x1)
print(x1.shape)
x2 = x.unsqueeze(1)
print(x2)
print(x2.shape)
y = torch.tensor([[1,2,3,4],[9,8,7,6]])
print(y)
print(y.shape)
y1 = y.unsqueeze(0)
print(y1.shape)
print(y1)
print(y1.shape)
y2 = y.unsqueeze(1)
print(y2)
print(y2.shape)
输出:
x: tensor([1, 2, 3, 4])
torch.Size([4])
x1: tensor([[1, 2, 3, 4]])
torch.Size([1, 4])
x2: tensor([[1],
[2],
[3],
[4]])
torch.Size([4, 1])
y: tensor([[1, 2, 3, 4],
[9, 8, 7, 6]])
torch.Size([2, 4])
y1: tensor([[[1, 2, 3, 4],
[9, 8, 7, 6]]])
torch.Size([1, 2, 4])
y2: tensor([[[1, 2, 3, 4]],
[[9, 8, 7, 6]]])
torch.Size([2, 1, 4])
可以看出:
(1) x.unsqueeze(0) ,shape 由(4)变为(1,4),证明在第一个位置增加一个维度。
(2) x.unsqueeze(1) ,shape 由(4)变为(4,1),证明在第二个位置增加一个维度。
介绍
size()函数主要是用来统计矩阵元素个数,或矩阵某一维上的元素个数的函数。
参数
numpy.size(a, axis=None)
a:输入的矩阵
axis:int型的可选参数,指定返回哪一维的元素个数。当没有指定时,返回整个矩阵的元素个数。
示例:
a = np.array([[1,2,3],[4,5,6]])
print(a.shape)
print(np.size(a,0))
print(np.size(a,1))
print(np.size(a))
输出:
(2,3)
2
3
6
示例:
b = tensor([[[1, 2, 3, 4]],
[[9, 8, 7, 6]]])
print(b.shape)
print(b.size(0))
print(b.size(1))
print(b.size(2))
输出:
torch.Size([2, 1, 4])
2
1
4
可以看出:
axis的值没有设定,返回矩阵的元素个数
axis = 0,返回该二维矩阵的行数
axis = 1,返回该二维矩阵的列数
torch.cat(inputs, dimension=0) → Tensor,在给定维度上对输入的张量序列seq 进行连接操作。
torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数面例子更好的理解。
参数:
inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
dimension (int, optional) – 沿着此维连接张量序列。
示例:
x = torch.randn(2, 3)
print(x)
print(x.shape)
x1 = torch.cat((x,x,x,),0)
print(x1)
print(x1.shape)
y = torch.cat((x,x,x,),1)
print(y1)
print(y1.shape)
输出:
tensor([[-1.1883, 0.5793, -0.2716],
[-0.8177, 0.0659, 0.8393]])
torch.Size([2, 3])
tensor([[-1.1883, 0.5793, -0.2716],
[-0.8177, 0.0659, 0.8393],
[-1.1883, 0.5793, -0.2716],
[-0.8177, 0.0659, 0.8393],
[-1.1883, 0.5793, -0.2716],
[-0.8177, 0.0659, 0.8393]])
torch.Size([6, 3])
tensor([[-1.1883, 0.5793, -0.2716, -1.1883, 0.5793, -0.2716, -1.1883, 0.5793,
-0.2716],
[-0.8177, 0.0659, 0.8393, -0.8177, 0.0659, 0.8393, -0.8177, 0.0659,
0.8393]])
torch.Size([2, 9])