前文简单的介绍了tensor的索引、切片等操作。本文主要介绍数据的拼接与分割以及数学运算。
[In] a = torch.rand(4,32,8)
[In] b = torch.rand(5,32,8)
torch.cat() #需要合并的维度值可以不同,其他维度必须完全相同
[In] torch.cat([a,b],dim=0).shape
[Out] torch.Size([9,32,8])
torch.stack() #创建一个新维度,两组数据的其他维度需要相同
[In] a = torch.rand(4,3,16,32)
[In] b = torch.rand(4,3,16,32)
[In] torch.stack([a,b],dim=2).shape
[Out] torch.Size([4,3,2,16,32])
torch.split() #与cat相反,用于分割维度
[In] a = torch.rand(2,32,8)
[In] c,d = a.split([1,1],sim=0)
[In] c.shape , d.shape
[Out] torch.Size([1,32,8]),torch.Size([1,32,8])
#加法
torch.add() or +
#数据维度相同,或者某一维度为1
[In] a = torch.rand(3,4)
[In] b = torch.rand(4)
[In] a+b
[In] torch.add(a,b)
#注意:torch.sum()与torch.add()的区别,前者用于压缩数据的某一维度
[In] torch.sum(a,dim=0) #dim不赋值则对数据中所有元素求和
#减法
torch.sub() or -
#乘法
点乘:* or torch.mul()
矩阵乘法: @ = torch.matmul() or (torch.mm()只能二维矩阵)
#除法
除法:/ or torch.div() #注意:torch.div的输入应为浮点型
#乘方
torch.pow() or **2
#开方
torch.sqrt()
#开方取倒数
torch.rsqrt()
#指数e
torch.exp()
#对数 ln
torch.log()
torch.floor() #向下取整
torch.ceil() #向上取整
torch.trunc() #取整数部分
torch.frac() #取小数部分
torch.any() #数据中任一元素为True,则返回True
torch.all() #数据中所有元素为True,则返回True
torch.max() #取某一维度最大值或者整个数据中的最大值以及其索引
torch.maximum() #比较两个相同形状张量的元素大小,取较大值,输出与输入相同形状
torch.min()
torch.minimum()
整个数据的运算有很多函数,有时候用到时忘记咋用了,或者长啥样,为了防止到处搜,这次直接放到我的博客中记录下来。