pytorch基操04-比较运算符

目录

    • 1 torch中的比较运算符
      • 1.1 torch.equal
      • 1.2 torch.equal
      • 1.3 torch.gt
      • 1.4 torch.ge
      • 1.5 torch.lt
      • 1.6 torch.le
      • 1.7 torch.ne
      • 1.8 torch.sort
      • 1.9 torch.topk
      • 1.10 torch.kthvalue
      • 1.11 torch.isinf
      • 1.12 torch.isfinite
      • 1.13 torch.nan

1 torch中的比较运算符

为了演示不同比较运算符的作用,先初始化两个tensor a和b。

# 创建a,b tensor并用均值为20,方差为10的高斯分布采样赋值
a=torch.empty(size=(2,2)).normal_(20,10).floor_() # floor 向下取整
b=torch.empty(size=(2,2)).normal_(20,10).floor_()
a[0,0],b[0,0]=10,10 # 强制修改[0,0]位置元素相同
a:
 tensor([[10., 24.],
        [ 9., 22.]])
b:
 tensor([[10., 25.],
        [22., 26.]])

1.1 torch.equal

equal方法只有当a和b形状元素值都相等才返回True。

print('equal:',torch.equal(a,b))

这里除了[0,0]位置其他元素都不相同,因此返回False.

equal: False

1.2 torch.equal

# a==b
print('eq:',torch.eq(a,b))
eq: tensor([[ True, False],
        [False, False]])

1.3 torch.gt

# a > b , gt=Greater Than
print('a>b:',torch.gt(a,b))

每个a的元素都小于等于在b对应位置的元素,因此都返回False。

a>b: tensor([[False, False],
        [False, False]])

1.4 torch.ge

# a>=b , ge=Greater than or Equal to
print('a>=b:',torch.ge(a,b))
a>=b: tensor([[ True, False],
        [False, False]])

1.5 torch.lt

# a
print('a,torch.lt(a,b))
a<b: tensor([[False,  True],
        [ True,  True]])

1.6 torch.le

# a<=b , le=Less than or Equal to
print('a<=b:',torch.le(a,b))
a<=b: tensor([[True, True],
        [True, True]])

1.7 torch.ne

# a!=b , ne=Not Equal
print('ne:',torch.ne(a,b))
ne: tensor([[False,  True],
        [ True,  True]])

1.8 torch.sort

创建用于排序的张量。

print('\n'+'-'*8+'排序'+'-'*8+'\n')
a=torch.empty(size=(2,5)).normal_(20,10).floor_()
print('a:',a)
print('b:',b)
a: tensor([[28., 17., 18., 16., 19.],
        [44., 30., 10.,  9., 29.]])

排序操作,返回排序后的张量,以及排序后的原元素的索引。此处沿着维度1进行排序。

# torch.sort
print('-'*8+'sort'+'-'*8)
a_sort,a_idx=torch.sort(a,dim=1)
print('a_sort:\n',a_sort)
print('idx_sort:\n',a_idx)
--------sort--------
a_sort:
 tensor([[ 3.,  5.,  8., 12., 22.],
        [-3.,  7., 19., 33., 33.]])
idx_sort:
 tensor([[3, 2, 0, 4, 1],
        [4, 1, 2, 0, 3]])

1.9 torch.topk

topk可以获取对应维度上,topk大或者topk小的元素。还是使用上上面的a作为例子。

# torch.topk 前k大个元素
print('-'*8+'topk'+'-'*8)
topk_res=torch.topk(a,dim=1,k=2,largest=True)
print(topk_res)

可以看到,维度1上最大的两个元素为(22,12)和(33,33)。

a: tensor([[ 8., 22.,  5.,  3., 12.],
        [33.,  7., 19., 33., -3.]])
--------topk--------
torch.return_types.topk(
values=tensor([[22., 12.],
        [33., 33.]]),
indices=tensor([[1, 4],
        [0, 3]]))

1.10 torch.kthvalue

kthvalue可以获取在指定维度上第k小的元素,只返回一个元素,且只能是第k小的(大的也不行)。

# torch.kthvalue
# get the k-th smallest values 第k小的元素
print('-'*8+'kthvalue'+'-'*8)
print(torch.kthvalue(a,k=3,dim=1))
a: tensor([[ 8., 22.,  5.,  3., 12.],
        [33.,  7., 19., 33., -3.]])

torch.return_types.kthvalue(
values=tensor([ 8., 19.]),
indices=tensor([0, 2]))

1.11 torch.isinf

isinf用于判断元素是否是无界的,这里故意除以0来制造无界的元素。

# torch.isinf 是否无界
print(torch.isinf(a/0))
tensor([[True, True, True, True, True],
        [True, True, True, True, True]])

1.12 torch.isfinite

isfinite用于判断是否有界。

# 是否有界
print(torch.isfinite(a/0))
tensor([[False, False, False, False, False],
        [False, False, False, False, False]])

1.13 torch.nan

# 是否是nan
a[0,0]=np.NAN
print(torch.isnan(a))
tensor([[ True, False, False, False, False],
        [False, False, False, False, False]])

你可能感兴趣的:(机器学习,pytorch,python,深度学习)