在探究yolov5源代码中 process match部分,matches代码段的时候调试遇到了一下问题。
ValueError Traceback (most recent call last)
<ipython-input-40-9452df23b94c> in <module>
15 print(matches)
16 matchesn = matches.numpy()
---> 17 matches2 = matches[:, 2].argsort()[::-1]
18 matchesn2 = matchesn[:, 2].argsort()[::-1]
19 print(matches2)
ValueError: step must be greater than zero
经对比分析,这是由于pytorch和numpy中,argsort()函数均有定义,但是对tensor形式变量和numpy形式变量的具体使用方法不同。当变量a为一维numpy数组的时候,a.argsort()[::-1]能够正确用数组切片的形式([start : end : step]),实现一维数组a全部元素的倒序排列。当变量a为一维tensor的时候,上述方法失效并报错“step must be greater than zero”。
yolov5 在此处是将tensor转换为了numpy之后进行倒序实现,从tensor直接实现需要以下代码。
import torch
a = torch.tensor([1,5,2,6,7])
a,a_index = a.sort(descending = True)
b,b_index = a.sort()
a[a_index] = a[b_index]
print(a)
有关上述陈述的对比实验如下所示:
import torch
import numpy as np
iou = torch.tensor([[0.16512, 0.04280, 0.7912, 0.06599, 0.0755, 0.4665],
[0.014043, 0.3173, 0.4420, 1.2253, 0.206817, 0.5997],
[ 0.4398, 0.1185, 1.2385, 0.2133, 0.7412, 0.06974],
[ 0.7442, 0.9128, 1.0040, 2.0243, 1.0281, 1.3334],
[ 1.0045, 0.7125, 0.03617, 0.0962, 0.7367, 0.6041]])
iou_thres = 0.2
x = torch.where(iou > iou_thres )
x_stack = torch.stack(x,1)
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1)
print(matches)
matchesn = matches.numpy()
matches2 = matches[:, 2].argsort()
matchesn2 = matchesn[:, 2].argsort()[::-1]
print(matches2)
print("***********")
print(matchesn2)
print(matchesn[:, 2].argsort())
tensor([[0.0000, 2.0000, 0.7912],
[0.0000, 5.0000, 0.4665],
[1.0000, 1.0000, 0.3173],
[1.0000, 2.0000, 0.4420],
[1.0000, 3.0000, 1.2253],
[1.0000, 4.0000, 0.2068],
[1.0000, 5.0000, 0.5997],
[2.0000, 0.0000, 0.4398],
[2.0000, 2.0000, 1.2385],
[2.0000, 3.0000, 0.2133],
[2.0000, 4.0000, 0.7412],
[3.0000, 0.0000, 0.7442],
[3.0000, 1.0000, 0.9128],
[3.0000, 2.0000, 1.0040],
[3.0000, 3.0000, 2.0243],
[3.0000, 4.0000, 1.0281],
[3.0000, 5.0000, 1.3334],
[4.0000, 0.0000, 1.0045],
[4.0000, 1.0000, 0.7125],
[4.0000, 4.0000, 0.7367],
[4.0000, 5.0000, 0.6041]])
tensor([ 5, 9, 2, 7, 3, 1, 6, 20, 18, 19, 10, 11, 0, 12, 13, 17, 15, 4,
8, 16, 14])
***********
[14 16 8 4 15 17 13 12 0 11 10 19 18 20 6 1 3 7 2 9 5]
[ 5 9 2 7 3 1 6 20 18 19 10 11 0 12 13 17 15 4 8 16 14]