pytorch学习经验(二) pytorch常用张量维度操作

  • 去掉大小为1的维度

x = torch.squeeze(x)

  • 添加大小为1的维度

x = torch.unsqueeze(x, 3) # 在第3个维度上扩展

  • 张量扩展,在指定维度上将原来的张量扩展到指定大小,比如原来x是31,输入size为[3, 4],可以将其扩大成34,4为原来1个元素的复制。

x = x.expand(*size)

  • 交换两个维度

x = torch.transpose(x, 1, 2) # 交换1和2维度

  • 交换多个维度,transpose只能对两个维度进行操作,permute没有限制

x = x.permute(1, 2, 3, 0) # 进行维度重组

  • 改变形状,下面这两条命令意思是一样的,但是view可能会出现一些诡异的报错,原因是当从多的维度变到少的维度时,如果张量不是在连续内存存放,则view无法变成合并维度,但reshape不受限制。

x = x.view(1, 2, -1)
x = x.reshape(1, 2, -1)

  • 张量拼接,第一个参数是一个tuple,每个tuple是一个张量,第二维度是dim,在指定dim上拼接

torch.cat(a_tuple, dim)

  • 张量拼接,与cat不同的在于,cat只能在原有的某一维度上进行连接,stack可以创建一个新的维度,将原有维度在这个维度上进行顺序排列。比如说,有2个44的张量,用cat就只能把它们变成一个84或48的张量,用stack可以变成24*4.

torch.stack(a_tuple, dim)

  • 张量拆分,在指定维度上将a变成chunk_num个大小相等的chunk,返回一个tuple。如果最后一个不够chunk_num,就返回剩下的。另一个相似的是split,除了第二个参数是chunk_size,其他都一样。

torch.chunk(a, chunk_num, dim)
torch.split(a, chunk_size, dim)

你可能感兴趣的:(pytorch学习经验(二) pytorch常用张量维度操作)