补充一些常用的张量操作。
torch.cat([a,b],dim = c) 用于合并张量,但是要保证张量的数据维度可以合并不会出错(即在要合并的轴 数据维度可以不一样,但是其他的轴数据要保持维度相同)。
a,b :指要合并的数据,c 表示要合并的所在轴。
代码如下:
import torch
a = torch.rand(4,3,28,28)
b = torch.rand(6,3,28,28)
#除0轴数据维度可以不一样(4和6),1,2,3轴数据维度都相同故可以合并
c = torch.cat([a,b],dim = 0)
#除1轴数据维度可以不同也可以相同(3,3);1轴据维度不同(4,6);
# 2,3轴数据维度都相同故不可以合并,打印出错
# d = torch.cat([a,b],dim = 1)
print(c.shape)
# print(d.shape)
torch.stack([a,b],dim = c) 会在dim指定轴之前增加新的维度。但在指定轴的数据维度比==.cat()== 要求更严格同样必须一致。
代码如下:
import torch
a = torch.rand(4,3,28,28)
b = torch.rand(6,3,28,28)
#除0轴数据维度必须保持相同(4和6 不同 会出错),1,2,3轴数据维度都相同否则会无法合并报错
c = torch.stack([a,b],dim = 0)
print(c.shape)
输出结果:
import torch
a = torch.rand(6,3,28,28)
b = torch.rand(6,3,28,28)
#除0轴数据维度必须保持相同(6和6),1,2,3轴数据维度都相同否则会无法合并报错
c = torch.stack([a,b],dim = 0)
print(c.shape)
输出结果:
.split(a,dim = b) 将张量按照a的指定长度
在b轴上进行拆分。
.chunk(a,dim = b) 将张量拆分成a个数量
在b轴上进行拆分。
代码如下:
import torch
a = torch.rand(6,3,28,28)
print(a.shape)
#split以长度进行拆分,3是指长度,每n个进行一个拆分,要有接受数据要保持对应
b,c,d =a.split(2,dim=0)
print(b.shape,c.shape,d.shape)
b,c = a.split(3,dim=0)
print(b.shape,c.shape)
print("********************************************************************")
#chunk,是拆分成指定的n个
b,c = a.chunk(2,dim = 0)
print(b.shape,c.shape)
b,c,d =a.chunk(3,dim=0)
print(b.shape,c.shape,d.shape)
输出结果:
加法
:若相加数据的维度不同,符合广播机制的会广播后再相加。
+号 可以使用加号进行相加。
torch.add() 也可以调用add方式相加。
代码如下:
import torch
a = torch.rand(4,5)
b = torch.rand(5)
print(a)
print(b)
#因为b的维度不够所有且符合广播机制,torch会将b广播成与a相同然后相加
#加法具有两种实现形式,一种是+号,另一种是调用add方式
print(a+b)
print(torch.add(a,b))
减法
-号 可以使用重载运算符减号进行相减。
torch.sub() 也可以调用sub(减法:subtraction)方式相减。
代码如下:
import torch
a = torch.rand(4,5)
b = torch.rand(5)
print(a)
print(b)
print(a-b)
print(torch.sub(a,b))
乘法
:乘法分为元素相乘(即对应位置的元素想乘)和矩阵乘法。
元素相乘:
*号 :可以使用重载运算符星号进行对应元素相乘。
torch.mul() :也可以调用mul(乘法:multiply)方法相乘。
矩阵乘法:需满足矩阵的运算规则,如A的列数(4行5列),等于C的行数(5行8列)得到新的维度(4行8列)
.mm(a,c)号 :仅适用于2D张量矩阵(不推荐)。
@ :重载运算符符号号进行矩阵相乘。
torch.matmul() :也可以调用.matmul()方法进行矩阵相乘。(在3D、4D等多维张量矩阵乘法中,只计算最后两个轴。如(1,2,3,4)@(1,2,4,5)=(1,2,3,5))
代码如下:
import torch
a = torch.rand(4,5)
b = torch.rand(5)
c = torch.rand(5,8)
print(a)
print(b)
#各对应元素相乘
print(a*b)
print(torch.mul(a,b))
#矩阵乘法:torch.mm(仅适用于2D矩阵相乘,不推荐);@符号重载的矩阵乘号;.matmul()函数等三种方法
#矩阵乘法要满足矩阵的运算规则:即A的列数(4行5列),等于的行数C(5行8列)得到新的维度(4行8列)
d = a@c
e = torch.matmul(a,c)
print(d,d.shape)
print(e,e.shape)
除法
:
/号 :可以使用重载运算符 / 号进行对应元素相除。
torch.div() :也可以调用div(除法:divide)方法相除。
代码如下:
import torch
a = torch.rand(4,5)
b = torch.rand(5)
print(a)
print(b)
#各对应元素相除
print(a/b)
print(torch.div(a,b))
.pow(a) :计算x的a次方。也可以使用两个星号来代替 **()。
.sqrt() :开平方根。同样可以使用 **(0.5)
.rsqrt() :开平方根后的倒数。
代码如下:
import torch
a = torch.full((3,3),4)
#平方
b = a.pow(2)
#3次方
c = a**(3)
#开平方根
d = a.sqrt()
e = a.pow(0.5)
#平方根的倒数
f = a.rsqrt()
print(a)
print(b)
print(c)
print(d)
print(e)
print(f)
输出结果:
.exp(a) :计算以e的a次方。
.log(a) :计算以e为底log(a)。
.rsqrt() :计算以10为底log(a)。
代码如下:
import torch
a = torch.full((3,3),2)
#对a每个数均进行e的指定次方
b = torch.exp(a)
# log默认以2为底
c = torch.log(b)
#log 函数以10为底
d = torch.log10(a)
print(a)
print(b)
print(c)
print(d)
.floor() :向下取整。
.ceil(a) :向上取整。
.trunc() :截取整数部分。
.frac() :截取小数部分。
.round() :对小数部分进行四舍五入。
代码如下:
import torch
a = torch.tensor(5.64)
#向下取整
print(a.floor())
#向上取整
print(a.ceil())
#截取整数部分
print(a.trunc())
#截取小数部分
print(a.frac())
#对小鼠部分进行四舍五入
print(a.round())
.clamp(a,b) :将数据裁剪到 a 到 b 之间。(常用于对梯度裁剪,防止梯度爆炸的情况出现)
代码如下:
import torch
a = torch.rand(3,3)*20
print(a)
print(a.min())
print(a.median())
print(a.max())
#将数据裁剪到 5-15之间,小于5的以5代替,大于15的以15代替
b = a.clamp(5,15)
print(b)
1范数 :所有数据绝对值之和。
2范数 :所有数据的平方和开根号。
p范数 :所有数据的p次方和开p根号。
1范数和2范数,未指定轴分析 :即对所有的数的绝对值求和,以及开根号。
代码如下:
a = torch.full([8],1.)
b = torch.full((2,4),1.)
c = torch.full((2,3,4),1.)
print(a)
print(b)
print(c)
# 1范数:所有数据的绝对值之和 , 2范数平方和开根号
print(a.norm(1),b.norm(1),c.norm(1))
print(a.norm(2),b.norm(2),c.norm(2))
1范数和2范数,以c为例在指定轴分析 :C为3D张量分别再0,1,2三个指定轴求1范数,分析数据的计算。
0轴
:c的形状为(2,3,4),在0轴分析,如果未设置轴保留=truse,则1 范数形状应为(3,4)
1轴
:c的形状为(2,3,4),在1轴分析,如果未设置轴保留=truse,则1 范数形状应为(2,4),可以理解对在垂直方向相加。
2轴
:c的形状为(2,3,4),在2轴分析,如果未设置轴保留=truse,则1 范数形状应为(2,3),可以理解对在水平方向相加。
代码如下:
import torch
c = torch.full((2,3,4),1.)
print(c)
#在指定的轴求范数
#0轴
print(c.norm(1,dim=0))
print(c.norm(2,dim=0))
#1轴
print(c.norm(1,dim=1))
print(c.norm(2,dim=1))
#2轴
print(c.norm(1,dim=2))
print(c.norm(2,dim=2))
.min() :获取张量数据中的最小值。
.max() :获取张量数据中的最大值。
pytorch采用的是将整个张量打平和1D张量,根据最大值和最小值获取位置索引。
.argmin() :获取打平后张量数据中的最小值索引。
.argmax() :获取打平后张量数据中的最大值索引。
代码如下:
import torch
a = torch.arange(24).view(2,3,4).float()
print(a)
#打印张量的最大值和最小值
print(a.min(),a.max())
#打印张量最大值,最小值对应的索引,无参数指定时默认flatten
print(a.argmin(),a.argmax())
#若不想打平,则需要指定轴
#0轴可以理解未垂直方向取索引
print(a.argmin(dim=0))
print(a.argmax(dim=0))
#1轴可以理解为水平方向取索引
print(a.argmin(dim=1))
print(a.argmax(dim=1))
#1轴可以理解为水平方向取索引
print(a.argmin(dim=2))
print(a.argmax(dim=2))
0轴
:a的形状为(2,3,4),在0轴索引分析,是对应位置索引,索引值形状未(3,4)。
1轴
:a的形状为(2,3,4),在1轴索引分析,可以理解为竖向(垂直)取索引,索引值形状未(2,4)。
2轴
:a的形状为(2,3,4),在2轴索引分析,可以理解为横向(水平)取索引,索引值形状未(2,3)。
keepdim :对指定的轴取索引时,如果保持轴数不变需要使用 keepdim 保持。
代码如下:
import torch
a = torch.randn(4,10)
print(a)
#打印在1轴最大值及对应索引
print(a.max(dim=1))
#打印索引
print(a.argmax(dim =1))
print("***********************************")
#打印在1轴最大值及对应索引,保留轴
print(a.max(dim=1,keepdim = True))
#打印索引
print(a.argmax(dim =1,keepdim = True))
在分类问题中,由于各种原因,可能会出现,分类的某一问题概率值并不高,为了更准确的分类,我们会需要保留的大的前K个概率,进一步判断药分类的类别。
.topk(a) :保留前a个概率值。
.kthvalue(a) :需要注意的是保留第a个小的,并且只能设置为小。
代码如下:
import torch
a = torch.randn(2,8)
print(a)
#largest 默认为True取最大的前k个,False取最小的前k个
#取最大的前3个及对应的索引号
print(a.topk(3,dim=1))
#取最小的前3个及对应的索引号
print(a.topk(3,dim=1,largest = False))
#取第k个小的值,只能取小
print(a.kthvalue(1,dim=1))
大于:可以直接用重载运算符 > 或者==.gt()== 大于(great)比较。
小于:可以直接用重载运算符 < 或者==.lt()== 小于比较。
等于:可以直接用重载运算符 == 或者 .eq() 等于(equal)。
不等于:可以直接用重载运算符 != 或者 .not_equal() 。
代码如下:
import torch
a = torch.arange(9).view(3,3)
print(a)
#大于
print(a>5)
print(torch.gt(a,5))
#小于
print(a<5)
print(torch.lt(a,5))
#等于
print(a == 5)
print(torch.eq(a,5))
#不等于
print(a != 5)
print(torch.not_equal(a,5))
.where(condition , x, y ) :如果满足条件,会将x中对应元素赋值给输出,不满足则将y对应数值赋给输出。
代码如下:
import torch
condation = torch.randn(3,3)
print(condation)
x = torch.full((3,3),0.)
print(x)
y = torch.full((3,3),1.)
print(y)
#where 用法
print(torch.where(condation>0.5,x,y))
输出结果:
.gather(input , dim, index,out =None) :将数据索引映射到所需要的位置。
代码如下:
import torch
#数据
data = torch.randn(3,6)
print(input)
#索引 在输入的数据中在1轴上去前2个最大值及索引
indexz_data = data.topk(2,dim=1)
print(indexz_data)
idx = indexz_data.indices
#将数据索引映射到另一个位置 [50 - 56]
label = torch.arange(6) + 50
print(label)
#使用gathe进行对应查找
print(torch.gather(label.expand(3,6),dim=1,index = idx))
本节,承上对Pytorch中常用的一些方法进行补充和解释,敬请小伙伴们批评指正,学习讨论,觉得有价值,劳驾动动食指,点个赞哈。