示例:
x = torch.rand(1,2,3)
print(x)
print("x.shape:", x.shape)
y = x.expand(2, 2,3)
print(y)
print("y.shape:", y.shape)
输出:
tensor([[[0.8020, 0.2300, 0.8632],
[0.9463, 0.0172, 0.9756]]])
x.shape: torch.Size([1, 2, 3])
tensor([[[0.8020, 0.2300, 0.8632],
[0.9463, 0.0172, 0.9756]],
[[0.8020, 0.2300, 0.8632],
[0.9463, 0.0172, 0.9756]]])
y.shape: torch.Size([2, 2, 3])
注意
expand只能对被操作的tensor中维数为1的维度进行扩展,例如上面的示例中只能对第维度扩展,如果改成x.expand(1, 4, 3)
会报错,线面介绍的expand_as和expand几乎一样, expand_as出入的参数是一个tensor
示例:
x = torch.rand(1,2,3)
z = torch.rand(3, 2,3)
print(x)
print("x.shape:", x.shape)
y = x.expand_as(z)
print(y)
print("y.shape:", y.shape)
输出:
tensor([[[0.5696, 0.3571, 0.2565],
[0.9646, 0.5754, 0.7819]]])
x.shape: torch.Size([1, 2, 3])
tensor([[[0.5696, 0.3571, 0.2565],
[0.9646, 0.5754, 0.7819]],
[[0.5696, 0.3571, 0.2565],
[0.9646, 0.5754, 0.7819]],
[[0.5696, 0.3571, 0.2565],
[0.9646, 0.5754, 0.7819]]])
y.shape: torch.Size([3, 2, 3])