a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
print(a)
b = torch.tensor([[10,20,30],[40,50,60],[70,80,90]])
# 加法
print(a+b)
print(torch.add(a,b))
# 减法
print(torch.all(torch.eq(a-b,torch.sub(a,b))))
# 乘法
print(torch.all(torch.eq(a*b,torch.mul(a,b))))
# 除法
print(torch.all(torch.eq(a/b,torch.div(a,b))))
执行结果:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor(True)
tensor(True)
tensor(True)
aa.rsqrt() 表示先对aa开平方,然后对开平方的结果求倒数
pow(aa,0.5) 表示开平方运算
exp(n) 表示:e的n次方
log(a) 表示:ln(a)
log2() 、 log10()
In[18]: a = torch.exp(torch.ones(2,2))
In[19]: a
Out[19]:
tensor([[2.7183, 2.7183],
[2.7183, 2.7183]])
In[20]: torch.log(a)
Out[20]:
tensor([[1., 1.],
[1., 1.]])
In[22]: torch.log2(a)
Out[22]:
tensor([[1.4427, 1.4427],
[1.4427, 1.4427]])
In[23]: torch.log10(a)
Out[23]:
tensor([[0.4343, 0.4343],
[0.4343, 0.4343]])
【1】取整、四舍五入、裁剪
floor、ceil 向下取整、向上取整
round 4舍5入
trunc、frac 裁剪
In[24]: a = torch.tensor(3.14)
In[25]: a.floor(),a.ceil(),a.trunc(),a.frac()
Out[25]: (tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))
In[26]: a = torch.tensor(3.499)
In[27]: a.round()
Out[27]: tensor(3.)
In[28]: a = torch.tensor(3.5)
In[29]: a.round()
Out[29]: tensor(4.)
【2】clamp
torch.clamp(input, min, max, out=None) → Tensor
将输入input张量每个元素的夹紧到区间 [min,max][min,max],
并返回结果到一个新张量。
操作定义如下:
| min, if x_i < min
y_i = | x_i, if min <= x_i <= max
| max, if x_i > max
(1) gradient clipping 梯度裁剪
(2) (min) 小于min的都变为某某值
(3) (min, max) 不在这个区间的都变为某某值
(4) 梯度爆炸:一般来说,当梯度达到100左右的时候,就已经很大了,正常在10左右,通过打印梯度的模来查看 w.grad.norm(2)
(5) 对于w的限制叫做weight clipping,对于weight gradient clipping称为 gradient clipping。
In[30]: grad = torch.rand(2,3)*15
In[31]: grad.max()
Out[31]: tensor(10.6977)
In[32]: grad
Out[32]:
tensor([[ 6.7738, 10.6977, 4.4314],
[ 7.8088, 4.8236, 3.6213]])
In[33]: grad.clamp(10) # 小于10的都变为10
Out[33]:
tensor([[10.0000, 10.6977, 10.0000],
[10.0000, 10.0000, 10.0000]])
In[34]: grad.clamp(0,10) # 不在(0,10)区间的都变为10
Out[34]:
tensor([[ 6.7738, 10.0000, 4.4314],
[ 7.8088, 4.8236, 3.6213]])