作用:将张量按维度dim进行拼接
代码示例:
import torch
import numpy as np
# ************example1***********
t = torch.ones((2,3))
t_0 = torch.cat([t,t],dim=0)
t_1 = torch.cat([t,t],dim=1)
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape,t_1,t_1.shape))
运行:
可以看到在第一个维度拼接的时候,生成一个4*3的全1张量,第二个维度的时候生成一个2*6的全1张量
作用:在新创建的维度dim上进行拼接
和torch.cat()不同,用torch.stack()会创建一个新维度
代码示例:
t = torch.ones((2,3))
t_stack = torch.stack([t,t],dim=2)
print("t_stack:{}\nshape:{}".format(t_stack,t_stack.shape))
结果发现新创建了一个第三维度
若在第一个维度创建,则原来的维度往后移动一个位子,比如:
t = torch.ones((2,3))
t_stack = torch.stack([t,t],dim=0)
print("t_stack:{}\nshape:{}".format(t_stack,t_stack.shape))
运行结果:生成2*2*3的张量
作用:将张量按维度dim进行平均切分
返回值:张量列表
注意:如果不能整除,最后一份张量小于其他张量
代码示例:
a = torch.ones((2,3))
list_of_tensors = torch.chunk(a, dim=1, chunks=3)
for idx,t in enumerate(list_of_tensors):
print("第{}个张量:{},shape is {}".format(idx+1,t, t.shape))
运行:
作用:将张量按维度dim进行切分
返回值:张量列表
代码:split_size_or_sections为int时
t = torch.ones((2,5))
list_of_tensors = torch.split(t, 2,dim=1)
for idx,t in enumerate(list_of_tensors):
print("第{}个张量:{},shape is {}".format(idx+1,t, t.shape))
结果:
当split_size_or_sections为list时,list的数字加起来要等于指定维度的长度,否则会报错
改成
list_of_tensors = torch.split(t, [1,2,1,1],dim=1)
结果:
作用:在维度dim上,按index索引数据
返回值:依据index索引数据拼接的张量
代码:注意idx的数据类型是torch.long(如果是其他类型会报错)
t = torch.randint(0,9,size=(3,3))
idx = torch.tensor([0,2],dtype=torch.long)
t_select = torch.index_select(t,dim=0,index=idx)
print("t:\n{}\nt_select:\n{}".format(t,t_select))
运行
作用:按mask中的True进行索引
返回值:一维张量
代码实现
t = torch.randint(0,9,size=(3,3))
mask = t.ge(5)
t_select = torch.masked_select(t,mask) #ge表示大于等于
print("t:\n{}\nmask:\n{}\nt_select:\n{}".format(t,mask,t_select))
运行结果
作用:变换张量形状
注意:当张量在内存中是连续的时候,新张量与input共享数据内存
代码:
t = torch.randperm(8)
t_reshape = torch.reshape(t,(2,4))
print("t:\n{}\nt_reshape:\n{}".format(t,t_reshape))
运行
注意reshape前后的张量是共享内存的,这里可以用代码示例:
t[1] = 1119
print("t:\n{}\nt_reshape:\n{}".format(t,t_reshape))
print("t.data 内存地址:{}".format(id(t.data)))
print("t_reshape.data 内存地址:{}".format(id(t_reshape.data)))
运行可以看到t和t_reshape的第二个元素都变成了1119,而且数据的内存地址也是相同的
作用:交换张量的两个维度
代码示例:
t = torch.rand((2,2,4))
t_transpose = torch.transpose(t,dim0=1,dim1=2)
print("t:{} \nt_transpose:{}".format(t,t_transpose))
运行
作用:2维张量转置,对矩阵而言,等价torch.transpose(input, 0, 1)
代码示例:
t = torch.rand((2,5))
t_t = torch.t(t)
print("t:{} \nt_transpose:{}".format(t,t_t))
运行:
作用:压缩长度为1的维度
用代码理解一下:
t = torch.rand((1,1,3,4))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t,dim=0)
t_1 = torch.squeeze(t,dim=1)
t_2 = torch.squeeze(t,dim=2)
t_3 = torch.squeeze(t,dim=3)
print(t.shape,
t_sq.shape,
t_0.shape,
t_1.shape,
t_2.shape,
t_3.shape)
结果如下:
作用:依据dim扩展维度
在t_sq = torch.squeeze(t)后面加上t_unsq = torch.unsqueeze(t_sq,dim=1)
t_sq = torch.squeeze(t)
t_unsq = torch.unsqueeze(t_sq,dim=1)
结果