Transformer实战-系列教程14:DETR 源码解读1

Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读

1、项目配置

主要环境

install PyTorch 1.5+
pip install pycocotools 
pip install cython
pip install scipy

需要下载coco数据集,这个数据集比较大,训练集8w图像,验证集4w图像。数据包括coco-data,打开后三个文件夹,annotations、train2014、val2014。

作者表示,一个epoch需要28分钟,300个epoch花费了6天,在一台有8个V100的机器上。但是我们学习这个算法只需要跑一些小的demo,不需要这么大的配置资源

项目主要是两大模块,一个是数据的处理,一个是网络的构建

运行项目:main.py
配置参数:

--coco_path B:\CV\Transformer\DETR\coco-2014

只需要知道coco文件夹的位置

2、CocoDetection类

2.1 CocoDetection类

class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target
  1. CocoDetection类,继承torchvision.datasets.CocoDetection
  2. 构造函数,图像数据路径、标注数据路径、数据增强参数、是否要mask
  3. 初始化类
  4. 数据增强参数
  5. ConvertCocoPolysToMask是一个自定义的类,实例化这个类的对象,并且传入return_masks参数
  6. 定义一个迭代序列的函数,传入索引
  7. 获取索引为idx的图像(img)和标签(target)
  8. 获取当前图像的ID
  9. 把标签制作成包含图像id和原始标注数据的字典
  10. 调用 prepare 对象将标注数据转化为掩码
  11. 是否需要进行数据增强
  12. 调用make_coco_transforms函数进行数据增强,make_coco_transforms函数使用_transforms作为变量传入当前类、
  13. 返回处理后的图像和标签

这是一个coco数据集的读取类,而且还是继承torchvision.datasets.CocoDetection

2.1 make_coco_transforms()函数

make_coco_transforms()函数是进行数据增强的函数

def make_coco_transforms(image_set):
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([ T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600), T.RandomResize(scales, max_size=1333), ])
            ),
            normalize,
        ])
    if image_set == 'val':
        return T.Compose([ T.RandomResize([800], max_size=1333), normalize,
        ])
    raise ValueError(f'unknown {image_set}')

数据增强比较简单,把数据转换为Tensor格式,进行归一化操作,随机水平翻转、从两个不同的缩放和裁剪策略中随机选择一个等

你可能感兴趣的:(Transformer实战,transformer,pytorch,深度学习,计算机视觉,DETR,物体检测)