Tensor——属性统计

statistics

    • norm范数
      • norm-p
    • max/min/mean/prod/sum
    • argmax/argmin
    • dim/keepdim
    • top-k/k-th
      • topk
      • kthvalue
    • compare
      • a>0
      • a不等0:
      • 相等问题

norm范数

Tensor——属性统计_第1张图片

norm-p

一范式:

>>> import torch
>>> a = torch.full([8],1)
>>> b = a.view(2,4)
>>> c = a.view(2,2,2)
>>> a,b,c,a.norm(1),b.norm(1),c.norm(1)
(tensor([1., 1., 1., 1., 1., 1., 1., 1.]), 
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]]), 
tensor([[[1., 1.],
         [1., 1.]],
        [[1., 1.],
         [1., 1.]]]), 
tensor(8.), tensor(8.), tensor(8.))

二范式:

>>> import torch
>>> a = torch.full([8],1)
>>> b = a.view(2,4)
>>> c = a.view(2,2,2)
>>> a.norm(2),b.norm(2),c.norm(2)
(tensor(2.8284), tensor(2.8284), tensor(2.8284))

多个参数:
(偷懒是真的爽……尽可能多的打印,能看懂就行吧 )

>>> import torch
>>> a = torch.full([8],1)
>>> b = a.view(2,4)
>>> c = a.view(2,2,2)
>>> b.norm(1,dim=1),b.norm(2,dim=1),c.norm(1,dim=1),c.norm(2,dim=1)
(tensor([4., 4.]),
 tensor([2., 2.]), 
 tensor([[2., 2.],
        [2., 2.]]), 
 tensor([[1.4142, 1.4142],
        [1.4142, 1.4142]]))

max/min/mean/prod/sum

先生成一个arange,依次求出其最大值、最小值、平均值、累乘、累加:

>>> import torch
>>> a = torch.arange(8).view(2,4).float()
>>> a.max(),a.min(),a.mean(),a.prod(),a.sum()
(tensor(7.), tensor(0.), tensor(3.5000), tensor(0.), tensor(28.))

argmax/argmin

这个函数生成的是最大值、最小值所在的索引:

>>> import torch
>>> a = torch.arange(8).view(2,4).float()
>>> a,a.argmax(),a.argmin()
(tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]]), tensor(7), tensor(0))

但奇怪的是,结果返回的最大值索引是7,而不是[1,3];最小值的索引也是0,而不是[0,0]

这是由于min\max这些函数不带参数时,会将tensor打平,所以数据被理解为一维的。

如果不想被打平,就要指定维度:

>>> import torch
>>> a = torch.rand(4,10)
>>> a.argmax(),a.argmax(dim=1)
(tensor(36), tensor([6, 7, 6, 6]))

dim/keepdim

>>> import torch
>>> a = torch.rand(4,10)
>>> a,a.max(dim=1)
(tensor([[4.0516e-01, 3.5340e-01, 8.1751e-01, 6.8875e-02, 7.6161e-01, 6.1794e-01,
         4.1648e-01, 2.5585e-01, 4.3074e-02, 4.0688e-01],
        [1.2020e-01, 6.8003e-02, 7.6678e-01, 1.3043e-01, 5.9450e-01, 8.7791e-01,
         9.2720e-01, 7.0388e-01, 9.2272e-01, 6.2815e-01],
        [9.7613e-01, 6.7368e-01, 2.4311e-01, 7.1937e-01, 7.4137e-01, 6.8246e-01,
         7.7629e-04, 8.0122e-02, 3.5398e-01, 4.2544e-01],
        [8.3330e-01, 8.3967e-01, 9.8525e-01, 5.4272e-01, 9.4928e-01, 6.0541e-01,
         2.9491e-01, 5.4326e-01, 9.6286e-01, 4.5083e-02]]), torch.return_types.max(
values=tensor([0.8175, 0.9272, 0.9761, 0.9853]),
indices=tensor([2, 6, 0, 2])))
import torch

a = torch.rand(4,10)
print(a.max(dim=1))
print(a.argmax(dim=1))
print(a.max(dim=1,keepdim=True))
print(a.argmax(dim=1,keepdim=True))

输出结果:

torch.return_types.max(
values=tensor([0.7900, 0.9615, 0.9218, 0.9187]),
indices=tensor([3, 9, 9, 9]))
tensor([3, 9, 9, 9])
torch.return_types.max(
values=tensor([[0.7900],
        [0.9615],
        [0.9218],
        [0.9187]]),
indices=tensor([[3],
        [9],
        [9],
        [9]]))
tensor([[3],
        [9],
        [9],
        [9]])

top-k/k-th

topk

最大、最小:

>>> import torch
>>> a = torch.rand(4,10)
>>> a.topk(3,dim=1),a.topk(3,dim=1,largest=False)
(torch.return_types.topk(
values=tensor([[0.9983, 0.9524, 0.9346],
        [0.9966, 0.9656, 0.8841],
        [0.9230, 0.9063, 0.8953],
        [0.9644, 0.7742, 0.7558]]),
indices=tensor([[1, 9, 6],
        [5, 9, 6],
        [4, 5, 2],
        [2, 5, 4]])),
  torch.return_types.topk(
values=tensor([[0.1418, 0.3036, 0.3955],
        [0.0758, 0.2322, 0.3211],
        [0.1127, 0.1471, 0.1641],
        [0.0244, 0.2251, 0.3262]]),
indices=tensor([[0, 8, 7],
        [1, 7, 8],
        [8, 9, 1],
        [9, 1, 3]])))

kthvalue

kthvalue只能求最小值

import torch

a = torch.rand(4,10)
print(a.topk(3,dim=1))
print(a.topk(3,dim=1,largest=False))
print(a.kthvalue(8,dim=1))
print(a.kthvalue(3))
print(a.kthvalue(3,dim=1))

Tensor——属性统计_第2张图片

compare

a>0

找出a>0的数,两种方法:

import torch

a = torch.rand(4,10)
print(a>0)
print(torch.gt(a,0))

Tensor——属性统计_第3张图片

a不等0:

import torch

a = torch.rand(4,10)
print(a!=0)

在这里插入图片描述

相等问题

a和b比较:

import torch

a = torch.ones(2,3)
b = torch.randn(2,3)
print(torch.eq(a,b))

在这里插入图片描述
a和a比较:
注意eq和equal是不一样的

import torch

a = torch.ones(2,3)
print(torch.eq(a,a))
print(torch.equal(a,a))

在这里插入图片描述

你可能感兴趣的:(PyTorch)