(1)torch.unsqueeze:
这个函数主要是对数据维度进行扩充。第二个参数为0数据为行方向扩,为1列方向扩,其余数值错误。unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度。
ps:总会出现squeeze()和unsqueeze()操作,就是升维和降维的操作,但是为什么需要进行这样的操作呢?unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]就很方便了。
例子:
import torch
a = torch.tensor([1,2,3,4])
b = torch.unsqueeze(a,0)
c = torch.unsqueeze(a,1)
b: tensor([[1, 2, 3, 4]])
b.size(): torch.Size([1, 4])
c.size(): torch.Size([4, 1])
c: tensor([[1],
[2],
[3],
[4]])
(2)torch.repeat()
PyTorch中的repeat()函数可以对张量进行重复扩充。
当参数只有两个时:(列的重复倍数,行的重复倍数)。1表示不重复
当参数有三个时:(通道数的重复倍数,列的重复倍数,行的重复倍数)。
例如:
import torch
a= torch.arange(30).reshape(5,6)
print(a)
print('b:',a.repeat(2,2))
print('c:',a.repeat(2,1,1))
a:a:tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29]])
b: tensor([[ 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 24, 25, 26, 27, 28, 29],
[ 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 24, 25, 26, 27, 28, 29]])
c: tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29]],
[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29]]])