https://github.com/amdegroot/ssd.pytorch
链接:https://pan.baidu.com/s/1EqrkQOR0Vx4JGaqJn0um8w
提取码:rd8c
我是用的yolov3-pytorch的环境,不过一定要注意:pytorch版本必须小于1.3,我用的是pytorch1.2.0
自己电脑配置:win10+2G 显卡
参考这位博主对应自己的数据集进行修改,我的数据集是一分类。
如果和我一样是一分类,要在VOC0712.py
中修改,不然会出现KeyError: 'person'
VOC_CLASSES = ( # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
#把上边的修改成下面的代码,自己数据集的类别,同时要加上 []
VOC_CLASSES = [( 'person')] # always index 0 改为自己的数据集类别
eval.py我出现这个错误
im_detect: 40/44 0.243s im_detect: 41/44 0.235s im_detect: 42/44 0.202s im_detect: 43/44 0.225s im_detect: 44/44 0.222s Evaluating detections Writing person VOC results file Traceback (most recent call last): File "F:/mx_matting/ssd.pytorch-master/eval.py", line 441, in <module> thresh=args.confidence_threshold) File "F:/mx_matting/ssd.pytorch-master/eval.py", line 416, in test_net evaluate_detections(all_boxes, output_dir, dataset) File "F:/mx_matting/ssd.pytorch-master/eval.py", line 421, in evaluate_detections do_python_eval(output_dir) File "F:/mx_matting/ssd.pytorch-master/eval.py", line 178, in do_python_eval ovthresh=0.5, use_07_metric=use_07_metric) File "F:/mx_matting/ssd.pytorch-master/eval.py", line 265, in voc_eval with open(imagesetfile, 'r') as f: FileNotFoundError: [Errno 2] No such file or directory: 'test.txt' VOC07 metric? Yes
定位到我的最后一个报错,即eval.py的265行,把imagesetfile 改为test.txt文件的在电脑中的绝对路径,路径带单引号
将
with open(imagesetfile, 'r') as f:
改入如下形式
with open('F:/mx_matting/ssd.pytorch-master/data/VOCdevkit/VOC2007/ImageSets/Main/test.txt', 'r') as f:
在train.py
大概200行左右添加下面的语句,创建一个txt文件,将路径写入
#将loss值保存在loss.txt文件中
doc = open('F:/mx_matting/ssd.pytorch-master/loss.txt', 'a') # 打开创建好的1.txt文件,参数a表示:打开写入,如果存在,则附加到文件的末尾
print(iteration,loss.data, file=doc)
完整代码如下:
if iteration % 10 == 0:
print('timer: %.4f sec.' % (t1 - t0))
print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data), end=' ') #data[0]改为data
#print(loss.data)
#将loss值保存在loss.txt文件中
doc = open('F:/mx_matting/ssd.pytorch-master/loss.txt', 'a') # 打开创建好的1.txt文件,参数a表示:打开写入,如果存在,则附加到文件的末尾
print(iteration,loss.data, file=doc)
根据我的这个博客,利用loss.txt
将loss曲线画出来
demo.py
# 使用SSD(Pytorch)进行目标检测
import os
import sys
import torch
from torch.autograd import Variable
import numpy as np
import cv2
from ssd import build_ssd
from data import VOC_CLASSES as labels
from matplotlib import pyplot as plt
# 定位到ssd.pytorch这个路径
module_path = os.path.abspath(os.path.join('F:/mx_matting/ssd.pytorch-master/demo'))
if module_path not in sys.path:
sys.path.append(module_path)
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# 构建架构,指定输入图像的大小(300),和要评分的对象类别的数量(X+1类)
net = build_ssd('test', 300, 2) # 【改1】这里改一下,如果有5类,就改成6
# 将预训练的权重加载到数据集上
net.load_weights('../weights/ssd300_VOC_9000.pth') # 【改2】这里改成你自己的模型文件
# 加载多张图像
# 【改3】改成你自己的文件夹
imgs = '../my_img/'
img_list = os.listdir(imgs)
for img in img_list:
# 对输入图像进行预处理
current_img = imgs + img
image = cv2.imread(current_img)
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
x = cv2.resize(image, (300, 300)).astype(np.float32)
x -= (104.0, 117.0, 123.0)
x = x.astype(np.float32)
x = x[:, :, ::-1].copy()
x = torch.from_numpy(x).permute(2, 0, 1)
# 把图片设为变量
xx = Variable(x.unsqueeze(0))
if torch.cuda.is_available():
xx = xx.cuda()
y = net(xx)
# 解析 查看结果
top_k = 10
plt.figure(figsize=(6, 6))
colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()
currentAxis = plt.gca()
detections = y.data
scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
for i in range(detections.size(1)):
j = 0
while detections[0, i, j, 0] >= 0.6:
score = detections[0, i, j, 0]
label_name = labels[i-1]
display_txt = '%s: %.2f'%(label_name, score)
print(display_txt)
pt = (detections[0,i,j,1:]*scale).cpu().numpy()
coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1
color = colors[i]
currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))
currentAxis.text(pt[0], pt[1], display_txt, bbox={
'facecolor':color, 'alpha':0.5})
j += 1
plt.imshow(rgb_image)
plt.show()