功能:
扩展张量中某维数据的尺寸
,返回输入张量在某维扩展为更大尺寸后的张量,且原始tensor和扩展后tensor不共享内存。
参数:
括号中输入参数为指定经过维度尺寸扩展后的张量的size。
注意事项:
expand()函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他size()的维度会报错。(如下:示例三)
# 示例一
import torch
a = torch.tensor([1, 2, 3]) # a.size:一行三列
b = a.expand(2, 3) # 将 a 扩展为:两行三列
print(a)
print(a.size)
print(b)
print(b.size)
# 输出:
tensor([1, 2, 3])
torch.Size([3])
tensor([[1, 2, 3],
[1, 2, 3]])
torch.Size([2, 3])
# 示例二
import torch
a = torch.tensor([1, 2, 3]) # a.size:一行三列
b = a.expand(3, 3) # 将 a 扩展为:三行三列
print(a)
print(b)
# 输出:
tensor([1, 2, 3])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
# 示例三
import torch
a = torch.tensor([[1], [2], [3]]) # a.size:三行一列
b = a.expand(3, 3) # 将 a 扩展为:三行三列
print(a.size())
print(a)
print(b.size())
print(b)
# 输出:
torch.Size([3, 1])
tensor([[1],
[2],
[3]])
torch.Size([3, 3])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
功能:
与expand()类似,用来扩展张量中某维数据的尺寸。
参数:
括号内输入参数是另一个张量(可以理解为复制操作,仿照括号内张量的size进行扩展,将输入tensor的维度扩展为与括号内指定tensor相同的size)
# 示例一
import torch
a = torch.tensor([1, 2, 3]) # a.size:一行三列
b = torch.tensor([[1, 1, 1], [2, 2, 2]]) # b.size:两行三列
c = a.expand_as(b) # 将a的size按照b的size扩展,赋值为c,c的size和b的size是相同的
print(a)
print(a.size)
print(b)
print(b.size())
print(c)
print(c.size())
# 输出:
tensor([1, 2, 3])
torch.Size([3]) # a的size此时是1行3列
tensor([[1, 1, 1],
[2, 2, 2]])
torch.Size([2, 3]) # b的size此时是2行3列
tensor([[1, 2, 3],
[1, 2, 3]])
torch.Size([2, 3]) # c,由a扩展而来,按照b的size扩展得到的c,c的size和b的相同,都是2行3列