Pytorch函数expand()详解

Pytorch函数 .expand( )

其将单个维度扩大成更大维度,返回一个新的tensor,具体看下例:

import torch

a = torch.Tensor([[1], [2], [3],[4]])
# 未使用expand()函数前的a
print('a.size: ', a.size())
print('a: ', a)

b = a.expand(4, 2)
# 使用expand()函数后的输出
print('a.size: ', a.size())
print('a: ', a)
print('b.size: ', b.size())
print('b: ', b)

expand()函数使用前后a没有发生变化,输出都是:

a.size: torch.Size([4, 1])
a:
1
2
3
4
[torch.FloatTensor of size 4x1]

b 的输出为:

b.size: torch.Size([4, 2])
b:
1 1
2 2
3 3
4 4
[torch.FloatTensor of size 4x2]
由此得出结论,a通过expand()函数扩展某一维度后自身不会发生变化

a = torch.Tensor([[[[1,2], [2,3], [3,4],[4,5]]]])
b = a.expand(2, 1, 4, 2)
c = a.expand(1, 2, 4, 2)
# 使用expand()函数后的输出
print('a.size: ', a.size())

print('b.size: ', b.size())
print('b: ', b)

print('c.size: ', c.size())
print('c: ', c)

 b2 = b.expand(3, 1, 4, 2)  # b: torch.Size([2, 1, 4, 2])
 print('b2.size: ', b2.size())

输出:

a.size: torch.Size([1, 1, 4, 2])

b.size: torch.Size([2, 1, 4, 2])
b:
(0 ,0 ,.,.) =
1 2
2 3
3 4
4 5
(1 ,0 ,.,.) =
1 2
2 3
3 4
4 5
[torch.FloatTensor of size 2x1x4x2]

c.size: torch.Size([1, 2, 4, 2])
c:
(0 ,0 ,.,.) =
1 2
2 3
3 4
4 5
(0 ,1 ,.,.) =
1 2
2 3
3 4
4 5
[torch.FloatTensor of size 1x2x4x2]

b2输出:

Traceback (most recent call last):
File “”, line 1, in
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0. at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/TH/generic/THTensor.c:309

由此可见,只要是单维度均可进行扩展,但是若非单维度会报错

你可能感兴趣的:(pytorch,python,深度学习)