pytorch “step must be greater than zero“问题解决

在探究yolov5源代码中 process match部分,matches代码段的时候调试遇到了一下问题。

ValueError                                Traceback (most recent call last)
<ipython-input-40-9452df23b94c> in <module>
     15 print(matches)
     16 matchesn = matches.numpy()
---> 17 matches2 = matches[:, 2].argsort()[::-1]
     18 matchesn2 = matchesn[:, 2].argsort()[::-1]
     19 print(matches2)

ValueError: step must be greater than zero

经对比分析,这是由于pytorch和numpy中,argsort()函数均有定义,但是对tensor形式变量和numpy形式变量的具体使用方法不同。当变量a为一维numpy数组的时候,a.argsort()[::-1]能够正确用数组切片的形式([start : end : step]),实现一维数组a全部元素的倒序排列。当变量a为一维tensor的时候,上述方法失效并报错“step must be greater than zero”。

yolov5 在此处是将tensor转换为了numpy之后进行倒序实现,从tensor直接实现需要以下代码。

import torch
a = torch.tensor([1,5,2,6,7])
a,a_index = a.sort(descending = True)
b,b_index = a.sort()
a[a_index] = a[b_index]
print(a)

有关上述陈述的对比实验如下所示:

import torch
import numpy as np

iou = torch.tensor([[0.16512, 0.04280,  0.7912, 0.06599,  0.0755,  0.4665],
        [0.014043,  0.3173,  0.4420,  1.2253, 0.206817,  0.5997],
        [ 0.4398,  0.1185,  1.2385,  0.2133,  0.7412, 0.06974],
        [ 0.7442,  0.9128,  1.0040,  2.0243,  1.0281,  1.3334],
        [ 1.0045,  0.7125, 0.03617,  0.0962,  0.7367,  0.6041]])

iou_thres = 0.2
x = torch.where(iou > iou_thres )
x_stack = torch.stack(x,1)
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1)

print(matches)
matchesn = matches.numpy()
matches2 = matches[:, 2].argsort()
matchesn2 = matchesn[:, 2].argsort()[::-1]
print(matches2)
print("***********")
print(matchesn2)
print(matchesn[:, 2].argsort())

tensor([[0.0000, 2.0000, 0.7912],
        [0.0000, 5.0000, 0.4665],
        [1.0000, 1.0000, 0.3173],
        [1.0000, 2.0000, 0.4420],
        [1.0000, 3.0000, 1.2253],
        [1.0000, 4.0000, 0.2068],
        [1.0000, 5.0000, 0.5997],
        [2.0000, 0.0000, 0.4398],
        [2.0000, 2.0000, 1.2385],
        [2.0000, 3.0000, 0.2133],
        [2.0000, 4.0000, 0.7412],
        [3.0000, 0.0000, 0.7442],
        [3.0000, 1.0000, 0.9128],
        [3.0000, 2.0000, 1.0040],
        [3.0000, 3.0000, 2.0243],
        [3.0000, 4.0000, 1.0281],
        [3.0000, 5.0000, 1.3334],
        [4.0000, 0.0000, 1.0045],
        [4.0000, 1.0000, 0.7125],
        [4.0000, 4.0000, 0.7367],
        [4.0000, 5.0000, 0.6041]])
tensor([ 5,  9,  2,  7,  3,  1,  6, 20, 18, 19, 10, 11,  0, 12, 13, 17, 15,  4,
         8, 16, 14])
***********
[14 16  8  4 15 17 13 12  0 11 10 19 18 20  6  1  3  7  2  9  5]
[ 5  9  2  7  3  1  6 20 18 19 10 11  0 12 13 17 15  4  8 16 14]

你可能感兴趣的:(YOLOv5源代码导读,debug,numpy,python,pytorch,debug,深度学习)