【Torch API】pytorch中的expand()和expand_as()函数

pytorch中的expand()和expand_as()函数

1.expand()函数:

(1)函数功能: 

expand()函数的功能是用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量。

              扩展张量不会分配新的内存,只是在存在的张量上创建一个新的视图(关于张量的视图可以参考博文:由浅入深地分析张量),而且原始tensor和处理后的tensor是不共享内存的。

              expand()函数括号中的输入参数为指定经过维度尺寸扩展后的张量的size。

 (2)应用举例:

1)
import torch
a = torch.tensor([1, 2, 3])
c = a.expand(2, 3)
print(a)
print(c)
 
# 输出信息:
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]]
 
 
 
2)
import torch
a = torch.tensor([1, 2, 3])
c = a.expand(3, 3)
print(a)
print(c)
 
# 输出信息:
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
 
 
3)
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 3)
print(a)
print(c)
 
# 输出信息:
torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
 
 
4)
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 4)
print(a)
print(c)
 
# 输出信息:
torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])

(3)注意事项:

             expand()函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他size()的维度会报错。

2.expand_as()函数:

(1)函数功能:

 expand_as()函数与expand()函数类似,功能都是用来扩展张量中某维数据的尺寸,区别是它括号内的输入参数是另一个张量,作用是将输入tensor的维度扩展为与指定tensor相同的size。

(2)应用举例:

1)
import torch
a = torch.tensor([[2], [3], [4]])
print(a)
b = torch.tensor([[2, 2], [3, 3], [5, 5]])
print(b.size())
c = a.expand_as(b)
print(c)
print(c.size())
 
# 输出信息:
tensor([[2],
        [3],
        [4]])
torch.Size([3, 2])
tensor([[2, 2],
        [3, 3],
        [4, 4]])
torch.Size([3, 2])
 
 
2)
import torch
a = torch.tensor([1, 2, 3])
print(a)
b = torch.tensor([[2, 2, 2], [3, 3, 3]])
print(b.size())
c = a.expand_as(b)
print(c)
print(c.size())
 
# 输出信息:
tensor([1, 2, 3])
torch.Size([2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]])
torch.Size([2, 3])

你可能感兴趣的:(基础知识,pytorch,深度学习,人工智能)