torch.unsqueeze()/torch.repeat()

(1)torch.unsqueeze:

这个函数主要是对数据维度进行扩充。第二个参数为0数据为行方向扩,为1列方向扩,其余数值错误。unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度。

ps:总会出现squeeze()和unsqueeze()操作,就是升维和降维的操作,但是为什么需要进行这样的操作呢?unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]就很方便了。


例子:

import torch
a = torch.tensor([1,2,3,4])
b = torch.unsqueeze(a,0)

c = torch.unsqueeze(a,1)


b: tensor([[1, 2, 3, 4]])

b.size(): torch.Size([1, 4])

c.size(): torch.Size([4, 1])
c: tensor([[1],
[2],
[3],
[4]])



(2)torch.repeat()

PyTorch中的repeat()函数可以对张量进行重复扩充。

当参数只有两个时:(列的重复倍数,行的重复倍数)。1表示不重复

当参数有三个时:(通道数的重复倍数,列的重复倍数,行的重复倍数)。

例如:

import torch
a= torch.arange(30).reshape(5,6)
print(a)
print('b:',a.repeat(2,2))
print('c:',a.repeat(2,1,1))


a:a:tensor([[ 0,  1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10, 11],
                  [12, 13, 14, 15, 16, 17],
                  [18, 19, 20, 21, 22, 23],
                  [24, 25, 26, 27, 28, 29]])
b: tensor([[ 0,  1,  2,  3,  4,  5,  0,  1,  2,  3,  4,  5],
                [ 6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11],
                [12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17],
                [18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23],
                [24, 25, 26, 27, 28, 29, 24, 25, 26, 27, 28, 29],
                [ 0,  1,  2,  3,  4,  5,  0,  1,  2,  3,  4,  5],
                [ 6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11],
                [12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17],
                [18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23],
                [24, 25, 26, 27, 28, 29, 24, 25, 26, 27, 28, 29]])
c: tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23],
         [24, 25, 26, 27, 28, 29]],
 
        [[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23],
         [24, 25, 26, 27, 28, 29]]])
 

你可能感兴趣的:(深度学习,python,人工智能)