1.expand()函数:
(1)函数功能:
expand()函数的功能是用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量。
扩展张量不会分配新的内存,只是在存在的张量上创建一个新的视图(关于张量的视图可以参考博文:由浅入深地分析张量),而且原始tensor和处理后的tensor是不共享内存的。
expand()函数括号中的输入参数为指定经过维度尺寸扩展后的张量的size。
(2)应用举例:
1)
import torch
a = torch.tensor([1, 2, 3])
c = a.expand(2, 3)
print(a)
print(c)
# 输出信息:
tensor([1, 2, 3])
tensor([[1, 2, 3],
[1, 2, 3]]
2)
import torch
a = torch.tensor([1, 2, 3])
c = a.expand(3, 3)
print(a)
print(c)
# 输出信息:
tensor([1, 2, 3])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
3)
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 3)
print(a)
print(c)
# 输出信息:
torch.Size([3, 1])
tensor([[1],
[2],
[3]])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
4)
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 4)
print(a)
print(c)
# 输出信息:
torch.Size([3, 1])
tensor([[1],
[2],
[3]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
(3)注意事项:
expand()函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他size()的维度会报错。
2.expand_as()函数:
(1)函数功能:
expand_as()函数与expand()函数类似,功能都是用来扩展张量中某维数据的尺寸,区别是它括号内的输入参数是另一个张量,作用是将输入tensor的维度扩展为与指定tensor相同的size。
(2)应用举例:
1)
import torch
a = torch.tensor([[2], [3], [4]])
print(a)
b = torch.tensor([[2, 2], [3, 3], [5, 5]])
print(b.size())
c = a.expand_as(b)
print(c)
print(c.size())
# 输出信息:
tensor([[2],
[3],
[4]])
torch.Size([3, 2])
tensor([[2, 2],
[3, 3],
[4, 4]])
torch.Size([3, 2])
2)
import torch
a = torch.tensor([1, 2, 3])
print(a)
b = torch.tensor([[2, 2, 2], [3, 3, 3]])
print(b.size())
c = a.expand_as(b)
print(c)
print(c.size())
# 输出信息:
tensor([1, 2, 3])
torch.Size([2, 3])
tensor([[1, 2, 3],
[1, 2, 3]])
torch.Size([2, 3])