pytorch读取coco数据集

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/a362682954/article/details/87915680
YOLOV3是工业上可以用的兼顾速度和准确率的一个深度学习目标检测模型,本系列文章将详细解释该模型的构成和实现,本文代码借鉴:https://github.com/eriklindernoren/PyTorch-YOLOv3

YOLOv3: An Incremental Improvement:https://pjreddie.com/media/files/papers/YOLOv3.pdf
原理在该篇博客就写的很详细了,这里就不赘述了:https://blog.csdn.net/leviopku/article/details/82660381

https://www.jianshu.com/p/d13ae1055302

github地址:https://github.com/18150167970/pytorch-yolov3-modifiy

1.文件组织架构
├── checkpoints/  #模型
├── data/  #数据
│   ├── get_coco_dataset.sh
│   ├── coco.names
├── utils/  #使用的函数
│   ├── __init__.py
│   ├── datasets.py
│   └── utils.py
├── config/  #配置文件
├── output/  #输出预测
├── weights/ #模型权重
├── README.md 
├── models.py #模型
├── train.py  #训练
├── test.py   #测试
├── detect.py #快速使用模型
└── requirements.txt  #环境
2.下载数据集
get_coco_dataset.sh 文件: 下载数据集并且制作训练集绝对路径文本

#!/bin/bash
# CREDIT: https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh
 
# Clone COCO API
git clone https://github.com/pdollar/coco
cd coco
 
mkdir images
cd images
 
# Download Images
wget -c https://pjreddie.com/media/files/train2014.zip
wget -c https://pjreddie.com/media/files/val2014.zip
 
# Unzip
unzip -q train2014.zip
unzip -q val2014.zip
 
cd ..
 
# Download COCO Metadata
wget -c https://pjreddie.com/media/files/instances_train-val2014.zip
wget -c https://pjreddie.com/media/files/coco/5k.part
wget -c https://pjreddie.com/media/files/coco/trainvalno5k.part
wget -c https://pjreddie.com/media/files/coco/labels.tgz
tar xzf labels.tgz
unzip -q instances_train-val2014.zip
 
# Set Up Image Lists
paste <(awk "{print \"$PWD\"}" <5k.part) 5k.part | tr -d '\t' > 5k.txt
paste <(awk "{print \"$PWD\"}" trainvalno5k.txt
3.配置文件
config.py   可以先不看这个,这个是后面需要的路径名和一些超参数,这不是我们关注的重点,但是需要这个.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
from pprint import pprint
 
 
class Config:
    epochs = 20
    batch_size = 1
    imge_folder = 'data/samples'
    classes = 80
 
    # 配置文件地址
    model_config_path = 'config/yolov3.cfg'
    data_config_path = 'config/coco.data'
    weight_path = 'weights/yolov3.weights'
    class_path = 'data/coco.names'
 
    # 超参数
    conf_threshold = 0.8
    nms_threshold = 0.4
    img_size = 416
    checkpoint_interval = 1
    use_cuda = True
    momentum = 0.9
    decay = 0.0005
    learning_rate = 0.001
    burn_in = 1000
 
    checkpoint_dir = 'checkpoints'
    train = 'data/coco/trainvalno5k.txt'
    valid = 'data/coco/5k.txt'
    names = 'data/coco.names'
    backup = 'backup/'
    eval = 'coco'
    # 判断终端输入是否正确
 
    def _parse(self, kwargs):
        state_dict = self._state_dict()
        for k, v in kwargs.items():
            if k not in state_dict:
                raise ValueError('UnKnown Option: "--%s"' % k)
            setattr(self, k, v)
 
        print('======user config========')
        pprint(self._state_dict())
        print('==========end============')
 
    # 终端输入替换默认配置
    def _state_dict(self):
        return {k: getattr(self, k) for k, _ in Config.__dict__.items()
                if not k.startswith('_')}
 
 
opt = Config()
 

4.读数据
在主函数中加pytorch数据加载函数

traindata = Datasets(train_path)
dataloader = torch.utils.data.DataLoader(
     traindata, batch_size=opt.batch_size, shuffle=False)
其中数据集Datasets函数为

#!/usr/bin/ebv pyhton
# -*- coding:utf-8 -*-
 
from __future__ import division
 
import os
import numpy as np
import torch
import sys
 
from torch.utils.data import Dataset
from skimage.transform import resize
import cv2
 
 
class Datasets(Dataset):
    def __init__(self, list_path, img_size=416):
        with open(list_path, 'r') as file:
            # readline() 读一行, readlines()读全部并返回list,这里返回的是图像绝对地址
            self.img_files = file.readlines()
        self.label_files = [path.replace('images', 'labels').replace(
            '.png', '.txt').replace('.jpg', '.txt') for path in self.img_files]
        self.img_shape = (img_size, img_size)
        self.max_objects = 50  # 最大物体数量
 
    def __getitem__(self, index):
        img_path = self.img_files[index % len(self.img_files)].rstrip()
        # Python中有三个去除头尾字符、空白符的函数,它们依次为:
        # strip: 用来去除头尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
        # lstrip:用来去除开头字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
        # rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
        img = np.array(cv2.imread(img_path))
 
        # 把不是彩色图像的用下一张图像替换
        while len(img.shape) != 3:
            index += 1
            img_path = self.imge_files[index % len(self.img_files)].rstrip()
            img = np.array(cv2.imread(img_path))
 
        h, w, _ = img.shape
 
        # 填充图片至正方形
        dim_diff = np.abs(h - w)
        pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
        pad = ((pad1, pad2), (0, 0), (0, 0)) if h <= w else (
            (0, 0), (pad1, pad2), (0, 0))
        # np.pad函数见 https://blog.csdn.net/qq_36332685/article/details/78803622
        # 这里pad两位指的是,第几轴,头尾增加pad1,pad2位数值
        input_img = np.pad(img, pad, 'constant', constant_values=128) / 255.
 
        # 这里注意的是,图片填充和resize(),标签也需要做相应操作,不然对不上
        padded_h, padded_w, _ = input_img.shape
        # cv2.resize()输出默认是3通道
        input_img = cv2.resize(input_img, (416, 416))
 
        input_img = np.transpose(input_img, (2, 0, 1))
        input_img = torch.from_numpy(input_img).float()
 
        # 制作标签
        label_path = self.label_files[index % len(self.img_files)].rstrip()
        lables = None
        if os.path.exists(label_path):
            # 五位标签,(类别,x,y,w,h) x,y为矩阵中心点
            labels = np.loadtxt(label_path).reshape(-1, 5)
            x1 = w * (labels[:, 1] - labels[:, 3] / 2)
            y1 = h * (labels[:, 2] - labels[:, 4] / 2)
            x2 = w * (labels[:, 1] + labels[:, 3] / 2)
            y2 = h * (labels[:, 2] + labels[:, 4] / 2)
            # 边界填充
            x1 += pad[1][0]
            y1 += pad[0][0]
            x2 += pad[1][0]
            y2 += pad[0][0]
            # resize
            labels[:, 1] = ((x1 + x2) / 2) / padded_w
            labels[:, 2] = ((y1 + y2) / 2) / padded_h
            labels[:, 3] *= w / padded_w
            labels[:, 4] *= h / padded_h
 
        # 初始化标签结果
        filled_labels = np.zeros((self.max_objects, 5))
        # 存储标签,如果没有就为零,超过50就舍弃
        if labels is not None:
            filled_labels[range(len(labels))[:self.max_objects]
                          ] = labels[:self.max_objects]
        filled_labels = torch.from_numpy(filled_labels)
        return img_path, input_img, filled_labels
 
    def __len__(self):
        return len(self.img_files)
 
————————————————
版权声明:本文为CSDN博主「冰菓(笑)」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/a362682954/article/details/87915680

你可能感兴趣的:(pytorch,数据集)