为了演示不同比较运算符的作用,先初始化两个tensor a和b。
# 创建a,b tensor并用均值为20,方差为10的高斯分布采样赋值
a=torch.empty(size=(2,2)).normal_(20,10).floor_() # floor 向下取整
b=torch.empty(size=(2,2)).normal_(20,10).floor_()
a[0,0],b[0,0]=10,10 # 强制修改[0,0]位置元素相同
a:
tensor([[10., 24.],
[ 9., 22.]])
b:
tensor([[10., 25.],
[22., 26.]])
equal方法只有当a和b形状元素值都相等才返回True。
print('equal:',torch.equal(a,b))
这里除了[0,0]位置其他元素都不相同,因此返回False.
equal: False
# a==b
print('eq:',torch.eq(a,b))
eq: tensor([[ True, False],
[False, False]])
# a > b , gt=Greater Than
print('a>b:',torch.gt(a,b))
每个a的元素都小于等于在b对应位置的元素,因此都返回False。
a>b: tensor([[False, False],
[False, False]])
# a>=b , ge=Greater than or Equal to
print('a>=b:',torch.ge(a,b))
a>=b: tensor([[ True, False],
[False, False]])
# a
print('a,torch.lt(a,b))
a<b: tensor([[False, True],
[ True, True]])
# a<=b , le=Less than or Equal to
print('a<=b:',torch.le(a,b))
a<=b: tensor([[True, True],
[True, True]])
# a!=b , ne=Not Equal
print('ne:',torch.ne(a,b))
ne: tensor([[False, True],
[ True, True]])
创建用于排序的张量。
print('\n'+'-'*8+'排序'+'-'*8+'\n')
a=torch.empty(size=(2,5)).normal_(20,10).floor_()
print('a:',a)
print('b:',b)
a: tensor([[28., 17., 18., 16., 19.],
[44., 30., 10., 9., 29.]])
排序操作,返回排序后的张量,以及排序后的原元素的索引。此处沿着维度1进行排序。
# torch.sort
print('-'*8+'sort'+'-'*8)
a_sort,a_idx=torch.sort(a,dim=1)
print('a_sort:\n',a_sort)
print('idx_sort:\n',a_idx)
--------sort--------
a_sort:
tensor([[ 3., 5., 8., 12., 22.],
[-3., 7., 19., 33., 33.]])
idx_sort:
tensor([[3, 2, 0, 4, 1],
[4, 1, 2, 0, 3]])
topk可以获取对应维度上,topk大或者topk小的元素。还是使用上上面的a作为例子。
# torch.topk 前k大个元素
print('-'*8+'topk'+'-'*8)
topk_res=torch.topk(a,dim=1,k=2,largest=True)
print(topk_res)
可以看到,维度1上最大的两个元素为(22,12)和(33,33)。
a: tensor([[ 8., 22., 5., 3., 12.],
[33., 7., 19., 33., -3.]])
--------topk--------
torch.return_types.topk(
values=tensor([[22., 12.],
[33., 33.]]),
indices=tensor([[1, 4],
[0, 3]]))
kthvalue可以获取在指定维度上第k小的元素,只返回一个元素,且只能是第k小的(大的也不行)。
# torch.kthvalue
# get the k-th smallest values 第k小的元素
print('-'*8+'kthvalue'+'-'*8)
print(torch.kthvalue(a,k=3,dim=1))
a: tensor([[ 8., 22., 5., 3., 12.],
[33., 7., 19., 33., -3.]])
torch.return_types.kthvalue(
values=tensor([ 8., 19.]),
indices=tensor([0, 2]))
isinf用于判断元素是否是无界的,这里故意除以0来制造无界的元素。
# torch.isinf 是否无界
print(torch.isinf(a/0))
tensor([[True, True, True, True, True],
[True, True, True, True, True]])
isfinite用于判断是否有界。
# 是否有界
print(torch.isfinite(a/0))
tensor([[False, False, False, False, False],
[False, False, False, False, False]])
# 是否是nan
a[0,0]=np.NAN
print(torch.isnan(a))
tensor([[ True, False, False, False, False],
[False, False, False, False, False]])