【PyTorch基础】——expand()和expand_as()

1、expand()函数

功能:
扩展张量中某维数据的尺寸,返回输入张量在某维扩展为更大尺寸后的张量,且原始tensor和扩展后tensor不共享内存。

参数:
括号中输入参数为指定经过维度尺寸扩展后的张量的size。

注意事项:
expand()函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他size()的维度会报错。(如下:示例三)

# 示例一
import torch
a = torch.tensor([1, 2, 3])   # a.size:一行三列
b = a.expand(2, 3) # 将 a 扩展为:两行三列
print(a)
print(a.size)
print(b)
print(b.size)

# 输出:
tensor([1, 2, 3])
torch.Size([3])
tensor([[1, 2, 3],
        [1, 2, 3]])
torch.Size([2, 3])


# 示例二
import torch
a = torch.tensor([1, 2, 3])  # a.size:一行三列
b = a.expand(3, 3)  # 将 a 扩展为:三行三列
print(a)
print(b)
# 输出:
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

# 示例三
import torch
a = torch.tensor([[1], [2], [3]])  # a.size:三行一列
b = a.expand(3, 3)  # 将 a 扩展为:三行三列
print(a.size())
print(a)
print(b.size())
print(b)
# 输出:
torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
torch.Size([3, 3])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

2、expand_as()函数

功能:
与expand()类似,用来扩展张量中某维数据的尺寸。

参数:
括号内输入参数是另一个张量(可以理解为复制操作,仿照括号内张量的size进行扩展,将输入tensor的维度扩展为与括号内指定tensor相同的size)

# 示例一
import torch
a = torch.tensor([1, 2, 3])  # a.size:一行三列
b = torch.tensor([[1, 1, 1], [2, 2, 2]])  # b.size:两行三列  
c = a.expand_as(b)  # 将a的size按照b的size扩展,赋值为c,c的size和b的size是相同的
print(a)
print(a.size)

print(b)
print(b.size())

print(c)
print(c.size())
 
# 输出:
tensor([1, 2, 3])
torch.Size([3])     # a的size此时是1行3列

tensor([[1, 1, 1],
        [2, 2, 2]])
torch.Size([2, 3])  # b的size此时是2行3列

tensor([[1, 2, 3],
        [1, 2, 3]])
torch.Size([2, 3])  # c,由a扩展而来,按照b的size扩展得到的c,c的size和b的相同,都是2行3列

你可能感兴趣的:(【PyTorch学习记录】,pytorch,深度学习,python)