Pytorch 基于 Detectron2 从零实现 Unet

Detectron2解读全部文章链接:

  1. Facebook计算机视觉开源框架Detectron2学习笔记 — 从demo到训练自己的模型

  2. Detectron2 “快速开始” Detection Tutorial Colab Notebook 详细解读

  3. Detectron2 官方文档详细解读 (上)

  4. Detectron2 官方文档详细解读(下)

  5. Detectron2 代码解读(1)如何构建模型

  6. Pytorch 基于 Detectron2 从零实现 Unet

Detectron2 的基本流程已经搞得差不多了,正好最近在尝试用 Unet 做一些简单的语义分割,闲着没事就准备把 Unet 搬到 Detectron2 上面来。比起复杂的目标检测模型,语义分割模型实现起来还是比较简单的,因为它无需生成先验框以及实现复杂的 loss 计算,直接卷积出热力图再接 softmax 然后送进 CrossEntropy 就行了… (指 Unet)。总的来说实现 Unet 的代码量不会很大,因此正好拿来练手…

首先确定代码思路,有哪些参数是我们希望可以自由调节的。从我的任务来看,我的数据集比较简单,我不是很需要非常庞大的网络,同时原论文 16x 的下采样倍率也有点多了,我希望变成 8x。 因此我需要用户可以自由调节 1. 网络的卷积核数量,把原本的 64-128-256-512-1024 变成 16-32-48-64。2. 同时下采样变成 8 倍,也就是减少一层跨层连接。这两点通过一个参数就可以调节,只需要加入到 cfg 里面就可以了。

根据前面的文章对于 Detectron2 的了解,从零开始实现模型的流程大概是这样的:

  1. 添加我们自己的 cfg 配置文件,确定我们需要哪些条目以及添加哪些条目。

  2. 实现我们自己的数据集注册,我们需要把我们的数据集(可以是任何形式的)以一个函数读取(比如叫 dataset_function),然后把这个函数注册进 Detectron2。

  3. 实现我们自己的 Dataloader,实际上使用 detectron2 的默认 build_detection_train_loader 就可以了,只是在数据增强这方面我们需要自定义一下 DatasetMapper,这个 Mapper 把你的 dataset_function 读取上来的数据转换成模型可以接受的输入格式。(不明白请看我前面的文章)

  4. 之后我们发现 Detectron2 有一个内置的 meta_arch 叫 semantic_seg 我们可以直接把 Unet 套上去。这个 semantic_seg 在阅读代码后我们发现,它包含一个 backbone 和一个 classifier。这个 meta_arch 非常合适,我们决定直接用这个。

  5. 之后我们就需要自己实现一个 Backbone 和一个 classifier,对于 Unet 来说,Backbone 就是那个 U 型网络,classifier 就是最后一个卷积层(没错,就一层。。)和 loss,这个 Backbone 和 classifier 的返回格式必须严格按照 semantic_seg 中流程每一步定义的返回格式来,不然很麻烦。

  6. Loss 直接用 CrossEntropyLoss 就行了。

  7. 基本功能到这里已经实现了,后续看到什么坑填什么坑就好了。

接下来按照步骤来开始。

1. 设置 cfg 文件

我先解释一下,detecton2 的 cfg 机制是这样的:首先文件会读取所有默认的 cfg 条目,其次读取你的 config (yaml) 文件,从 yaml 文件中覆盖对应的条目。如果你需要加入自己的条目,你需要从代码里添加。

前面提到了,我们希望自定义卷积核数量和下采样倍率,这里我们使用一个 list 作为配置文件就可以定义了。比如原版结构我们编码成 [64, 128, 256, 512, 1024],那么我自己的新结构就可以编码成 [16, 32, 48, 64 ]。在 detectron2/projects 下创建 unet 文件夹,在里面新建 train_net.py:

from detectron2.engine import default_argument_parser, default_setup, launch
from detectron2.config import get_cfg

# 我的代码中这个函数放在了其他文件中,我这个实际上只加了一条,如果需要加很多# 条,最好放在单独的文件,比较清晰。
def add_unet_config(cfg):

    cfg.MODEL.BACKBONE.UNET_CHANNELS = [64, 128, 256, 512, 1024]
    
   
def setup(args):  
    cfg = get_cfg()  # 加载默认的 cfg
    add_unet_config(cfg)  # 添加我们自己的条目
    cfg.merge_from_file(args.config_file)   # 从文件中覆盖默认 cfg
    cfg.merge_from_list(args.opts)   # 从命令行中覆盖默认 cfg
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)
    
    
if __name__ == "__main__":
    
    # 使用内置的命令行
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    # 运行
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

之后我们创建 configs 文件夹,新建一个 Base-UNet-DS16-Semantic.yaml,这是我们自己网络的配置文件,里面我加了这些条目:

MODEL:
  META_ARCHITECTURE: "SemanticSegmentor"
  BACKBONE:
    NAME: "build_unet_backbone"
    FREEZE_AT: 0
    UNET_CHANNELS: [16, 32, 48, 64, 70]
  SEM_SEG_HEAD:
    NAME: "UnetSemSegHead"
    NUM_CLASSES: 3
    IGNORE_VALUE: 255
DATASETS:
  TRAIN: ("MyDataset_Train",)
DATALOADER:
  NUM_WORKERS: 8
INPUT:
  MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792)
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MAX_SIZE_TRAIN: 4096
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (512, 512)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.01
  MAX_ITER: 90000
OUTPUT_DIR: './'

这样的话 cfg 的设置就完成了,读取上来 cfg 之后,我们直接访问 cfg 就可以访问这些参数了。

2. 数据集注册

注册 Dataset

如果你看完了我前面的文章,那么注册数据集的过程你应该比较清楚。如果你的数据集不是标准格式的 (COCO),那么你需要自己写一个 function,这个 function 需要读取你的数据,并且返回 list[dict] ,每个dict包含一张图片的信息,具体dict的格式规范我前面的文章有说。对于语义分割模型来说,只需要有:

  • file_name (string):图片文件的绝对路径
  • height(int):图片的高
  • width(int):图片的宽
  • image_id(string或者int):该图片独特的一个id
  • sem_seg_file_name(string):对应的gt图的绝对路径,这个图片具有和原图相同的大小,并且从0开始,每种像素值对应一类。

因此我的 function 是这么写的,其中我把每个训练数据的名字放在了 train.txt 里面,这个没有固定写法,你自己数据集是怎么摆的你就怎么写,你就让这个函数返回这个 list[dict] 就可以了

def load_train_data():

    dataset = []   # initialize dataset

    with open('data/train.txt', 'r') as f:

        for file in f.readlines():

            # Detectron2 standard format
            image = {
                'file_name': img_path + file.replace("\n",""),
                'height': 2048,
                'width': 2448,
                'image_id': file.split('.')[0],
                'sem_seg_file_name': segmentation_path + file.split('.')[0] + '.png'
            }

            dataset.append(image)
    
    return dataset

随后我们使用如下代码把 dataset 注册进 Detectron2:

DatasetCatalog.register("MyDataset_Train", load_train_data)

之后测试:

if __name__ == '__main__':
    
	import random
    dataset_dict = DatasetCatalog.get("MyDataset_Train")
    for d in random.sample(dataset_dict, 3):
        print(d)

输出:

{'file_name': './data/dataset/imgs/000796.jpg', 'height': 2048, 'width': 2448, 'image_id': '000796', 'sem_seg_file_name': './data/dataset/segmentations/000796.png'}
{'file_name': './data/dataset/imgs/000111.jpg', 'height': 2048, 'width': 2448, 'image_id': '000111', 'sem_seg_file_name': './data/dataset/segmentations/000111.png'}
{'file_name': './data/dataset/imgs/000762.jpg', 'height': 2048, 'width': 2448, 'image_id': '000762', 'sem_seg_file_name': './data/dataset/segmentations/000762.png'}

成功!数据集已经被注册进去了,之后使用 DatasetCatalog.get(“MyDataset_Train”) 就可以访问这个数据集了。

现在这个 Dataset 仅仅返回了包含图片信息的一个 Dict,然而它并非 tensor,不能直接送入模型进行处理,因此我们需要一个 Dataloader。这个 Dataloader 需要有一个 Mapper,这个 Mapper 负责把这个 Dict 变成模型可以输入的 Tensor,以及进行数据增强操作。

创建 Dataloader

我们知道 Dataloader 默认情况下通过 build_detection_train_loader 函数就可以创建,其中需要传入一个 Mapper 对象。这个 Mapper 对象我们可以直接使用 Detectron2 的 DatasetMapper,但是数据增强这里我们要有自己的操作,因此:

def get_train_aug(cfg):
        augs = [
            T.RandomCrop(
                cfg.INPUT.CROP.TYPE,
                cfg.INPUT.CROP.SIZE,
            ),
            T.RandomFlip()
        ]
        return augs

def build_train_loader(cfg):
        
        mapper = DatasetMapper(cfg, is_train=True, augmentations=augs(cfg))
        return build_detection_train_loader(cfg, mapper=mapper)

这里我们使用了最基础的两种数据增强 - RandomCrop 随机裁剪和 RandomFlip 随机翻转。这两种数据增强的用法可以查看 Detectron2 的官方文档,这里不再赘述。注意,我们这里是根据 cfg 创建的数据增强。另外这两个方法实际写在 Trainer 类下。

这里我们创建了一个 DatasetMapper 类的对象,数据增强是我们自己定义的,随后把这个 DatasetMapper 传入 build_detection_train_loader(),这个函数返回一个 DataLoader 对象,包含了训练数据,我们只需要使用:

train_loader = build_train_loader(cfg)
for data in train_loader:
    ...

就可以访问里面的数据了。Dataloader 到这里就创建完毕了。测试:

train_loader = build_train_loader(cfg)
for i in train_loader:
    print(i)
    break

输出了一个 batch,这里取其中一个图片的信息:

{'file_name': './data/dataset/imgs/000327.jpg', 'height': 2048, 'width': 2448, 'image_id': '000327', 'image': tensor([[[250, 250, 250,  ..., 200, 202, 204],
         [252, 252, 250,  ..., 201, 204, 204],
         [245, 255, 255,  ..., 211, 196, 193],
         ...,
         [249, 252, 253,  ..., 240, 238, 235],
         [246, 253, 252,  ..., 242, 240, 242],
         [243, 248, 246,  ..., 242, 239, 242]],

        [[249, 249, 249,  ..., 168, 171, 176],
         [248, 248, 248,  ..., 176, 175, 174],
         [249, 251, 246,  ..., 175, 175, 184],
         ...,
         [239, 240, 244,  ..., 225, 224, 223],
         [238, 245, 243,  ..., 227, 224, 224],
         [236, 244, 238,  ..., 225, 220, 222]],

        [[251, 251, 253,  ..., 192, 200, 206],
         [253, 254, 254,  ..., 190, 201, 215],
         [250, 253, 250,  ..., 199, 208, 224],
         ...,
         [252, 252, 254,  ..., 246, 248, 241],
         [248, 252, 253,  ..., 248, 248, 247],
         [243, 250, 249,  ..., 246, 245, 245]]], dtype=torch.uint8), 'sem_seg': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])}

可以看到,这个格式已经可以被模型接受了,其中 ‘image’ 和 ‘sem_seg’ 被读取成了数据增强之后的 tensor 格式。实际上送入模型的也是这个 dict。模型会从这个 dict 中提取数据,这也是为什么推荐你使用 Detectron2 的传统格式,如果你是完全从头自己写的模型,而没用 meta_arch 抽象结构的话,这里你模型获得数据之后还需要自己去解析这个 dict,相当麻烦。所以尽量还是 Detectron2 一用用全套… Dataloader 和 模型结构全用他的,省了很多事。

3. 开始写模型

上面说了,如果不用 Detectron2 的 meta_arch 抽象模型结构的话,会很麻烦,因为你的模型还需要自己解析这个 dict (你当然可以让 dataloader 直接输出你的数据 tensor 而不是这个 dict,但是这样的话你为什么要用 Detectron2… )。因为是语义分割模型,我们在 detectron2/detectron2/modeling/meta_arch/semantic_seg.py 里面发现了我们想用的这个抽象结构:

这个文件里注册了一个叫 SemanticSegmentor 的类,我们来看一下这个类干了什么,只看重点:

@META_ARCH_REGISTRY.register()
class SemanticSegmentor(nn.Module):
    
    
    @configurable
	def __init__(
            self,
            *,
            backbone: Backbone,
            sem_seg_head: nn.Module,
            pixel_mean: Tuple[float],
            pixel_std: Tuple[float]
        ):
        
            super().__init__()
            # SemanticSegmentor 主要包含一个 backbone (Backbone对象) 和一个 sem_seg_head(nn.Module对象)
            self.backbone = backbone
            self.sem_seg_head = sem_seg_head
            self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
            self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

	# 看一下它是怎么 forward的
    def forward(self, batched_inputs):
			
            # 解析传进来的 dict,把传进来的 list[dict] 里面的 "image" 变成 tensor
            images = [x["image"].to(self.device) for x in batched_inputs]
            images = [(x - self.pixel_mean) / self.pixel_std for x in images]
            images = ImageList.from_tensors(images, self.backbone.size_divisibility)
			# 送入 backbone 得到输出
            features = self.backbone(images.tensor)
			# 如果输入的 dict 里面有 "sem_seg" 这一项 (8-bit 的类别图),存入 targets
            if "sem_seg" in batched_inputs[0]:
                targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
                targets = ImageList.from_tensors(
                    targets, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
                ).tensor

            else:
                targets = None
            # 根据 backbone 的输出 features,和 ground_truth 图 targets,送入 sem_seg_head 模块算出 预测pred 和 loss
            results, losses = self.sem_seg_head(features, targets)
			
            # 下面略 ....

显然,这里的 backbone 和 sem_seg_head 需要我们自己填进去,同时:

backbone 需要返回 features,任意形式

sem_seg_head 接受 features 和 targets(N, H, W)算出最终输出和loss

那么对于 Unet 来说

backbone:是 Detectron2 的一个 backbone 对象,这个对象必须包含一个 out_shape 成员,用于创建 sem_seg_head。输出的格式不限制,只要你的 sem_seg_head 能接受就好。在 Unet 里就是那一系列 U 形结构,最终输出一个 (N, K, H, W)的 tensor,K 是最最后的通道数量。

sem_seg_head:是一个 nn.Module 对象,必须包含 size_divisibility 和 ignore_value 成员,用不用得上也必须得有,你可以随便设一个,只要不用就好了。这个在 Unet 里就是单独一个 1x1 的卷积,输入(N, K, H, W),输出(N, C, H, W)。这个 sem_seg_head 在训练时需要返回: None, loss(dict) 这个 loss 是一个 dict,形式是 {“sem_seg_loss": loss}。推理时需要返回 softmax(output), {} 因为推理时我们不需要 loss。

下面就开工了!

Backbone 部分

首先定义 Unet 的下采样模块和上采样模块,具体查看论文吧,本文还是以 Detectron2 的教学为主:

下采样模块是 maxpool - conv3x3 - bn - relu - conv3x3 - bn - relu

上采样模块是 upconv2x2 - conv3x3 - bn - relu - conv3x3 - bn - relu

这样很快就能写好:

class UnetBlockDown(nn.Module):

    def __init__(self, in_channels, out_channels):

        super().__init__()
        # 原论文3x3卷积没有加 Padding,这样会导致图片尺寸不断缩小... 之后 concat 的时候各种 padding,非常麻烦
        # 实测这里 padding=1,保持图片尺寸的影响不大。
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.pool(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x
    
class UnetBlockUp(nn.Module):

    def __init__(self, in_channels, out_channels, concat_channel):

        super().__init__()
        self.up_conv = nn.ConvTranspose2d(in_channels, in_channels // 2, 
                                          kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels // 2 + concat_channel, 
                               out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
	
    # 这里是因为 UnetBlockUp 需要 concatenate 之后再卷积,先对上一层的输出图进行上采样
    # 之后 concat 上保存的特征图(尺寸不对的话需要 Pad ),再 conv3x3-bn-relu-conv3x3-bn-relu
    def forward(self, x_in, x_saved):
        
        x_in = self.up_conv(x_in)

        # 这里的 Padding 参考的 https://github.com/milesial/Pytorch-UNet
        diffY = x_saved.size()[2] - x_in.size()[2]
        diffX = x_saved.size()[3] - x_in.size()[3]
        x_in = F.pad(x_in, [diffX // 2, diffX - diffX // 2,
                          diffY // 2, diffY - diffY // 2])
        x = torch.cat([x_saved, x_in], dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        return x 

之后把它们组合成 unet 即可:

首先包含一个 stem 是去掉了 maxpool 的下采样模块,conv3x3 - bn - relu - conv 3x3 - bn - relu

之后就是下采样,保存特征图,之后上采样的时候把对应的特征图 concatenate 上去(合并)即可,非常简单。

import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.modeling import Backbone, BACKBONE_REGISTRY

class UNet(Backbone):
    # 注意,继承了 Backbone 类而不是 nn.Module。
    
    def __init__(self, channels=[64, 128, 256, 512, 1024]):

        super().__init__()
        
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels[0], channels[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU(inplace=True)
        )
        down_stage = []
        up_stage = []
        in_channel = channels[0]
        # 一个 for 循环根据 channels 创建下采样模块
        for out_channel in channels[1:]:
            down_stage.append(UnetBlockDown(in_channel, out_channel))
            in_channel = out_channel

        in_channel = channels[-1]
        # 再一个 for 循环根据 channels 和下采样模块的输出通道数创建上采样模块
        for channel in channels[::-1][1:]:
            out_channel = (channel + in_channel // 2) // 2
            up_stage.append(UnetBlockUp(in_channel, out_channel, channel))
            in_channel = out_channel

        self.out_shape = out_channel
        self.down_stage = nn.Sequential(*down_stage)
        self.up_stage = nn.Sequential(*up_stage)


    def forward(self, x):

        x = self.stem(x)
        saved_features = [x]
        for ind in range(len(self.down_stage) - 1):
            x = self.down_stage[ind](x)
            saved_features.append(x)

        x = self.down_stage[-1](x)
        for ind in range(len(self.up_stage)):
            x = self.up_stage[ind](x, saved_features[-ind-1])

        return x

    def output_shape(self):
        return self.out_shape

这里整个 Backbone 就创建结束了,我们需要提供一个接口给 Detectron2,让 Detectron2 根据 cfg 可以直接创建我们的模型:

@BACKBONE_REGISTRY.register()
def build_unet_backbone(cfg, input_shape):
    return UNet(
                channels=cfg.MODEL.BACKBONE.UNET_CHANNELS,
                activation=cfg.MODEL.BACKBONE.ACTIVATION,
                )

sem_seg_head 部分

这部分就更简单了,只包含一个 1x1 的卷积和一个算 loss 的 criterion 模块,我们这里用的交叉熵 nn.CrossEntropyLoss() 模块,非常方便。

import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

# 用 Registry 机制把这个 sem_seg_head 注册进去
@SEM_SEG_HEADS_REGISTRY.register()
class UnetSemSegHead(nn.Module):

    def __init__(self, cfg, in_channel):

        super().__init__()
        self.conv = nn.Conv2d(in_channel, cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 
                            kernel_size=3, padding=1)
        self.criterion = nn.CrossEntropyLoss()
        
        # 为了 Detectron2 的兼容性
        self.size_divisibility = 0
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE

    def forward(self, x, targets=None):
        x = self.conv(x)
        
        if self.training:
            loss = self.criterion(x, targets)
            return None, {"loss_sem_seg": loss}
        else:
            return F.softmax(x, dim=1), {} 

测试

这样的话整个网络就完成了,是不是非常简单…,我们在 cfg 里面设置

MODEL.BACKBONE.NAME 项为 “build_unet_backbone”

MODEL.META_ARCHITECTURE 项为 “SemanticSegmentor”

MODEL.SEM_SEG_HEAD.NAME 项为 “UnetSemSegHead”

之后训练时使用 DefaultTrainer 的类方法 build_model(cfg) 即可创建模型:

测试(在上面 train_net.py 的 main() 函数中):

def main(args):
    cfg = setup(args)
    model= DefaultTrainer.build_model(cfg)
    print(model)

不出意外,模型创建成功,到这里就结束了。

4. 训练

我这里提供一个简单的训练脚本,Detectron2 训练完之后会自动 Evaluate,这里因为没设置 Evaluator 所以会报错… 不过没关系,这样模型训练完已经保存下来了,只是跳过了 evaluate 环节。如果你需要 Evaluate 可以参考 Detectron2 自带的 SemSegEvaluator 类,同时,你如果需要单张推理,你只需要用 DefaultPredictor 类就行,具体的写法和下面我的 Trainer 类没什么本质区别。

我的 Trainer 类继承了 DefaultTrainer 类,所有 lr_scheduler,max_iter,save_iter,base_lr 什么的都在 config 下就可以设置了,完全无需自己写,下面是我 train_net.py 的完整代码,可以参考:

import torch
import os
import numpy as np

import detectron2.data.transforms as T
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.config import get_cfg
from detectron2.data import DatasetMapper, MetadataCatalog, build_detection_train_loader
# 这里你还需要把 unet 文件夹添加到系统路径下去... 导入什么的这里就不讲了,我的所有模型文件都定义在 projects/UNet/unet/ 下面
from detectron2.projects.unet import *
# 我的注册数据集的文件放在 projects/UNet/data/register_data.py 里面
from data import register_dataset

    
class Trainer(DefaultTrainer):
	
    # 继承 DefaultTrainer 需要继承这个类方法。
    @classmethod
    def build_train_loader(cls, cfg):
      
        mapper = DatasetMapper(cfg, is_train=True, augmentations=Trainer.get_train_aug(cfg))
        return build_detection_train_loader(cfg, mapper=mapper)

	# 我用的 data_augmentation 定义在这里了,我后来又加入了一些 random brightness/contrast 来适应我的数据集。
    # 根据不同的数据集自己设置就好了。
    @classmethod
    def get_train_aug(cls, cfg):
        augs = [
            T.Resize(
                cfg.INPUT.MIN_SIZE_TRAIN
            ),
            T.RandomCrop(
                cfg.INPUT.CROP.TYPE,
                cfg.INPUT.CROP.SIZE,
            ),
            T.RandomFlip()
        ]
        if cfg.INPUT.ENABLE_RANDOM_BRIGHTNESS is not None:
            (min_scale, max_scale) = cfg.INPUT.ENABLE_RANDOM_BRIGHTNESS
            augs.append(
                T.RandomBrightness(min_scale, max_scale)
            )
        if cfg.INPUT.ENABLE_RANDOM_CONTRAST is not None:
            (min_scale, max_scale) = cfg.INPUT.ENABLE_RANDOM_CONTRAST
            augs.append(
                T.RandomContrast(min_scale, max_scale)
            )
        return augs


def setup(args):
    cfg = get_cfg()
    add_unet_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)
    
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        return res
    
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":

    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

我的 config 文件放在 configs/Base-UNet-DS16-Semantic.yaml 里面,这是完整文件:

MODEL:
  META_ARCHITECTURE: "SemanticSegmentor"
  BACKBONE:
    NAME: "build_unet_backbone"
    FREEZE_AT: 0
    UNET_CHANNELS: [16, 32, 32, 64]
  SEM_SEG_HEAD:
    NAME: "UnetSemSegHead"
    NUM_CLASSES: 3
    IGNORE_VALUE: 255
DATASETS:
  TRAIN: ("MyDataset_Train",)
DATALOADER:
  NUM_WORKERS: 8
INPUT:
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (384, 384)
  ENABLE_RANDOM_BRIGHTNESS: (0.6, 1.4)
  ENABLE_RANDOM_CONTRAST: (0.7, 1.3)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.01
  MAX_ITER: 90000
OUTPUT_DIR: './output/'

最后的最后,在命令行输入,不出意外模型就跑起来了,每5000个iter自动保存一次(也可以在cfg设置),保存在 output/ 文件夹下。

python3 train_net.py --config-file configs/Base-UNet-DS16-Semantic.yaml 

随后如果需要继续训练,这样就 OK!

python3 train_net.py --config-file configs/Base-UNet-DS16-Semantic.yaml --resume MODEL.WEIGHTS 你的模型路径

你如果需要单张推理,自己写一个 DefaultPredictor 类,把 train_net.py 中的 Trainer 类换掉就好了,你如果理解的话改几行代码就可以了,这里不讲了~

你可能感兴趣的:(CV,计算机视觉,python,深度学习,神经网络)