torch中如何找出矩阵中元素之大于某个阈值的所有元素的下标?torch.find() ?

近期用到了torch中要查找矩阵中元素大于某个阈值的函数,torch中的函数一般都为 torch.函数名  比如 torch.max(), 于是乎,搜了torch.find() ,各种搜索都搜不到, 几经周折终于搞定,用法如下:a=torch.Tensor(2,3):range(1,6); 想找到元素值大于3的元素下标,则下标索引 b=a[a:gt(3)]  ,,,


错了错了,返回的是元素的值,,,待更正~~~~~~~~~~~~~~~~~~~~~~~~


torch中如何找出矩阵中元素之大于某个阈值的所有元素的下标?torch.find() ?_第1张图片


正确做法如下:

id=torch.range(1,a:nElement())[a:gt(3)]

如下图所示:

torch中如何找出矩阵中元素之大于某个阈值的所有元素的下标?torch.find() ?_第2张图片


6个元素中大于3的元素下标索引为 3,4,6  torch中的顺序是先行后列,,完结



另外发现一个讲caffe框架的课程


讲得非常好,推荐给大家~


你可能感兴趣的:(torch,torch.find)