一范式:
>>> import torch
>>> a = torch.full([8],1)
>>> b = a.view(2,4)
>>> c = a.view(2,2,2)
>>> a,b,c,a.norm(1),b.norm(1),c.norm(1)
(tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]]),
tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]]),
tensor(8.), tensor(8.), tensor(8.))
二范式:
>>> import torch
>>> a = torch.full([8],1)
>>> b = a.view(2,4)
>>> c = a.view(2,2,2)
>>> a.norm(2),b.norm(2),c.norm(2)
(tensor(2.8284), tensor(2.8284), tensor(2.8284))
多个参数:
(偷懒是真的爽……尽可能多的打印,能看懂就行吧 )
>>> import torch
>>> a = torch.full([8],1)
>>> b = a.view(2,4)
>>> c = a.view(2,2,2)
>>> b.norm(1,dim=1),b.norm(2,dim=1),c.norm(1,dim=1),c.norm(2,dim=1)
(tensor([4., 4.]),
tensor([2., 2.]),
tensor([[2., 2.],
[2., 2.]]),
tensor([[1.4142, 1.4142],
[1.4142, 1.4142]]))
先生成一个arange,依次求出其最大值、最小值、平均值、累乘、累加:
>>> import torch
>>> a = torch.arange(8).view(2,4).float()
>>> a.max(),a.min(),a.mean(),a.prod(),a.sum()
(tensor(7.), tensor(0.), tensor(3.5000), tensor(0.), tensor(28.))
这个函数生成的是最大值、最小值所在的索引:
>>> import torch
>>> a = torch.arange(8).view(2,4).float()
>>> a,a.argmax(),a.argmin()
(tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]]), tensor(7), tensor(0))
但奇怪的是,结果返回的最大值索引是7,而不是[1,3];最小值的索引也是0,而不是[0,0]
这是由于min\max这些函数不带参数时,会将tensor打平,所以数据被理解为一维的。
如果不想被打平,就要指定维度:
>>> import torch
>>> a = torch.rand(4,10)
>>> a.argmax(),a.argmax(dim=1)
(tensor(36), tensor([6, 7, 6, 6]))
>>> import torch
>>> a = torch.rand(4,10)
>>> a,a.max(dim=1)
(tensor([[4.0516e-01, 3.5340e-01, 8.1751e-01, 6.8875e-02, 7.6161e-01, 6.1794e-01,
4.1648e-01, 2.5585e-01, 4.3074e-02, 4.0688e-01],
[1.2020e-01, 6.8003e-02, 7.6678e-01, 1.3043e-01, 5.9450e-01, 8.7791e-01,
9.2720e-01, 7.0388e-01, 9.2272e-01, 6.2815e-01],
[9.7613e-01, 6.7368e-01, 2.4311e-01, 7.1937e-01, 7.4137e-01, 6.8246e-01,
7.7629e-04, 8.0122e-02, 3.5398e-01, 4.2544e-01],
[8.3330e-01, 8.3967e-01, 9.8525e-01, 5.4272e-01, 9.4928e-01, 6.0541e-01,
2.9491e-01, 5.4326e-01, 9.6286e-01, 4.5083e-02]]), torch.return_types.max(
values=tensor([0.8175, 0.9272, 0.9761, 0.9853]),
indices=tensor([2, 6, 0, 2])))
import torch
a = torch.rand(4,10)
print(a.max(dim=1))
print(a.argmax(dim=1))
print(a.max(dim=1,keepdim=True))
print(a.argmax(dim=1,keepdim=True))
输出结果:
torch.return_types.max(
values=tensor([0.7900, 0.9615, 0.9218, 0.9187]),
indices=tensor([3, 9, 9, 9]))
tensor([3, 9, 9, 9])
torch.return_types.max(
values=tensor([[0.7900],
[0.9615],
[0.9218],
[0.9187]]),
indices=tensor([[3],
[9],
[9],
[9]]))
tensor([[3],
[9],
[9],
[9]])
最大、最小:
>>> import torch
>>> a = torch.rand(4,10)
>>> a.topk(3,dim=1),a.topk(3,dim=1,largest=False)
(torch.return_types.topk(
values=tensor([[0.9983, 0.9524, 0.9346],
[0.9966, 0.9656, 0.8841],
[0.9230, 0.9063, 0.8953],
[0.9644, 0.7742, 0.7558]]),
indices=tensor([[1, 9, 6],
[5, 9, 6],
[4, 5, 2],
[2, 5, 4]])),
torch.return_types.topk(
values=tensor([[0.1418, 0.3036, 0.3955],
[0.0758, 0.2322, 0.3211],
[0.1127, 0.1471, 0.1641],
[0.0244, 0.2251, 0.3262]]),
indices=tensor([[0, 8, 7],
[1, 7, 8],
[8, 9, 1],
[9, 1, 3]])))
kthvalue只能求最小值
import torch
a = torch.rand(4,10)
print(a.topk(3,dim=1))
print(a.topk(3,dim=1,largest=False))
print(a.kthvalue(8,dim=1))
print(a.kthvalue(3))
print(a.kthvalue(3,dim=1))
找出a>0的数,两种方法:
import torch
a = torch.rand(4,10)
print(a>0)
print(torch.gt(a,0))
import torch
a = torch.rand(4,10)
print(a!=0)
a和b比较:
import torch
a = torch.ones(2,3)
b = torch.randn(2,3)
print(torch.eq(a,b))
import torch
a = torch.ones(2,3)
print(torch.eq(a,a))
print(torch.equal(a,a))