[Code Walkthrough] Semantic Segmentation with CMU-Net

Table of Contents

  • 1. main.py
    • 1.1. Argument declaration
    • 1.2. Configuration and model setup
      • 1.2.1. Parse the configuration, print it, and select the loss function
      • 1.2.2. Model declaration; choosing parameters, optimizer, and LR scheduler
    • 1.3. Dataset
      • 1.3.1. Reading the dataset
      • 1.3.2. Data augmentation
      • 1.3.3. Creating the dataset objects
      • 1.3.4. DataLoader()
    • 1.4. Training epochs
      • 1.4.1. solver.py - the training script
  • 2. The model: CMUNet.py
    • 2.1. The Multi-Scale Attention Gate module
    • 2.2. The ConvMixerBlock module


Summary: CMU-Net is a multi-scale semantic segmentation model; the original paper evaluates it on ultrasound images. Here the IDRiD dataset is used instead, to segment hard exudates.

1. main.py

1.1. Argument declaration

Nothing much to explain here; just get familiar with the individual arguments.

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default='CMUnet',
                        help='model name')
    parser.add_argument('--epochs', default=300, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')
    # model
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=256, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=256, type=int,
                        help='image height')
    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES)
    # dataset
    parser.add_argument('--dataset', default='IDRiD',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')
    # optimizer
    parser.add_argument('--optimizer', default='Adam',
                        choices=['Adam', 'SGD'],
                        help='optimizer: ' +
                             ' | '.join(['Adam', 'SGD']) +
                             ' (default: Adam)')
    parser.add_argument('--lr', '--learning_rate', default=0.0001, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')
    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2 / 3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')
    parser.add_argument('--num_workers', default=4, type=int)
    config = parser.parse_args()
    return config
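
The parser uses a str2bool helper (for --deep_supervision and --nesterov) that is not shown above. A typical implementation, sketched here as an assumption about what the repo's helper looks like:

```python
import argparse

def str2bool(v):
    # Map common command-line spellings of booleans to real bools.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
```

A helper like this is needed because argparse's type=bool would treat any non-empty string, including 'False', as True.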

1.2. Configuration and model setup

From here on we are inside the main function, main().

1.2.1. Parse the configuration, print it, and select the loss function

def main():
    config = vars(parse_args())  # parse the arguments
    os.makedirs('checkpoint/%s' % config['name'], exist_ok=True)  # create the root directory for results

    # print all configuration parameters
    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    # save the configuration to checkpoint/<name>/config.yml
    with open('checkpoint/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # select the loss function used later for training
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True
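
The default BCEDiceLoss combines binary cross-entropy with a Dice term; the repo's losses module implements it in torch. As a quick illustration of the Dice coefficient it is built on, here is a minimal numpy sketch (the smooth constant is a hypothetical choice):

```python
import numpy as np

def dice_coefficient(pred, target, smooth=1e-5):
    # pred and target: binary masks of the same shape.
    intersection = (pred * target).sum()
    return (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)

pred = np.array([[1, 1, 0],
                 [0, 1, 0]], dtype=float)
target = np.array([[1, 0, 0],
                   [0, 1, 0]], dtype=float)
dice = dice_coefficient(pred, target)  # 2 overlapping pixels, 3 + 2 total
```

A BCE+Dice loss then typically minimizes bce + (1 - dice), so perfect overlap drives the Dice term to zero.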

1.2.2. Model declaration; choosing parameters, optimizer, and LR scheduler

    # model declaration
    model = CMUNet(img_ch=3, output_ch=1, l=7, k=7)
    model = model.cuda()

    # keep only the parameters that require gradients
    params = filter(lambda p: p.requires_grad, model.parameters())

    # choose the optimizer: Adam or SGD
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'], nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    # choose the LR scheduler
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], verbose=True, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError
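
The default CosineAnnealingLR follows eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2. A small pure-Python sketch with this run's defaults (lr=1e-4, min_lr=1e-5, epochs=300) shows the schedule's endpoints:

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-4, min_lr=1e-5, t_max=300):
    # eta_min + (eta_max - eta_min) * (1 + cos(pi * epoch / T_max)) / 2
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * epoch / t_max)) / 2

start = cosine_annealing_lr(0)    # equals base_lr
end = cosine_annealing_lr(300)    # equals min_lr
```

The learning rate starts at base_lr, decays smoothly, and reaches min_lr exactly at the final epoch.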

1.3. Dataset

1.3.1. Reading the dataset

    # collect every training/validation image path, e.g. 'inputs/IDRiD_512/train/images/IDRiD_54_0.png'
    train_img_ids = glob(os.path.join('inputs', config['dataset'], 'train', 'images', '*' + config['img_ext']))
    val_img_ids = glob(os.path.join('inputs', config['dataset'], 'test', 'images', '*' + config['img_ext']))

    # keep only each image's base name, e.g. 'IDRiD_54_0'
    train_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_img_ids]
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in val_img_ids]
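
The two list comprehensions reduce each path to a bare id. A self-contained illustration with hypothetical paths:

```python
import os

# Hypothetical paths in the layout described above.
paths = ['inputs/IDRiD_512/train/images/IDRiD_54_0.png',
         'inputs/IDRiD_512/train/images/IDRiD_55_1.png']

# basename() drops the directories, splitext() drops the extension.
img_ids = [os.path.splitext(os.path.basename(p))[0] for p in paths]
```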

1.3.2. Data augmentation

The training and validation sets use different augmentation pipelines: the training set generally gets more augmentations, while the validation set is mainly just resized and normalized.

    # augmentation pipeline for the training set
    train_transform = Compose([
        RandomRotate90(),             # random 90-degree rotation
        # transforms.Flip(),
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    # augmentation pipeline for the validation set
    val_transform = Compose([
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

1.3.3. Creating the dataset objects

The Dataset() class is used to create the training-set and validation-set objects.

    # create the training-set object
    train_dataset = Dataset(
        img_ids=train_img_ids,        # image base names, e.g. 'IDRiD_54_0'
        img_dir=os.path.join('inputs', config['dataset'], 'train', 'images'),  # image folder: 'inputs/IDRiD_512/train/images'
        mask_dir=os.path.join('inputs', config['dataset'], 'train', 'masks'),  # mask folder: 'inputs/IDRiD_512/train/masks'
        img_ext=config['img_ext'],    # image extension, e.g. .png
        mask_ext=config['mask_ext'],  # mask extension, e.g. .png
        num_classes=config['num_classes'],   # number of classes: 1
        transform=train_transform)    # augmentation (defined above)

    # create the validation-set object, same as above
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'test', 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'test', 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

Next, let's look at the implementation of the Dataset() class:

The Dataset() class defines the magic method __getitem__, which is called whenever a sample is read during an epoch, i.e. inside the loop "for input, target, _ in train_loader:".

import os
import cv2
import numpy as np
import torch
import torch.utils.data

class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        self.img_ids = img_ids     # list of image base names
        self.img_dir = img_dir     # path to the image folder
        self.mask_dir = mask_dir   # path to the mask folder
        self.img_ext = img_ext     # image file extension
        self.mask_ext = mask_ext   # mask file extension
        self.num_classes = num_classes  # number of classes: 1
        self.transform = transform      # augmentation pipeline

    # __len__() is called later by torch.utils.data.DataLoader()
    def __len__(self):
        return len(self.img_ids)

    # called when a sample is read during an epoch, e.g. for input, target, _ in train_loader:
    def __getitem__(self, idx):
        img_id = self.img_ids[idx]   # image base name
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))  # read the image

        # read the masks, one per class
        mask = []
        for i in range(self.num_classes):
            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
                        # cv2.IMREAD_GRAYSCALE + [..., None]: read as grayscale and add a channel dimension
        mask = np.dstack(mask)   # stack the per-class masks along the channel axis

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        img = img.astype('float32') / 255
        img = img.transpose(2, 0, 1)   # HWC -> CHW
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)

        return img, mask, {'img_id': img_id}
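
__getitem__ builds one (H, W, 1) grayscale array per class and stacks them with np.dstack along the channel axis. A small illustration with hypothetical 4x4 masks:

```python
import numpy as np

# Two hypothetical single-channel masks of shape (H, W, 1), as built above.
m0 = np.zeros((4, 4, 1), dtype=np.uint8)
m1 = np.ones((4, 4, 1), dtype=np.uint8)

# dstack concatenates along the third (channel) axis: (4, 4, 1) x 2 -> (4, 4, 2)
mask = np.dstack([m0, m1])
```

With num_classes=1 the result is simply a (H, W, 1) array, which the later transpose(2, 0, 1) turns into channel-first layout for PyTorch.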

1.3.4. DataLoader()

The dataset objects from the previous step are passed to DataLoader():

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)  # drop the last incomplete batch when the dataset size is not divisible by the batch size

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)  # keep the last incomplete batch
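
The effect of drop_last is simple arithmetic: with drop_last=True the trailing partial batch is discarded, so a few samples are skipped each epoch. A hypothetical helper to count the resulting batches:

```python
def num_batches(n_samples, batch_size, drop_last):
    # With drop_last=True, DataLoader discards the trailing partial batch.
    full, remainder = divmod(n_samples, batch_size)
    return full if (drop_last or remainder == 0) else full + 1

train_batches = num_batches(413, 8, drop_last=True)   # 51 full batches, 5 samples skipped
val_batches = num_batches(413, 8, drop_last=False)    # 52 batches, the last one has 5 samples
```

Dropping the partial batch on the training side keeps every batch the same size, which matters for layers like BatchNorm that behave poorly on tiny batches.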

1.4. Training epochs

Before training, the logging is set up:

    # an ordered dict that records the key training and validation metrics
    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
        ('val_dice', []),
    ])

The epoch loop:

    # initialization
    best_iou = 0
    trigger = 0   # epochs since the last best validation IoU; drives early stopping

    # iterate over epochs
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))  # show the current epoch

        # the train and val steps are wrapped in solver.py (keeps main() readable)
        train_log = solver.train(train_loader, model, criterion, optimizer)
        val_log = solver.validate(val_loader, model, criterion)

        # step the scheduler once at the end of every epoch
        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        # print the metrics for this epoch
        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f - val_dice %.4f - val_SE %.4f - val_PC %.4f - val_F1 %.4f - val_SP %.4f - val_ACC %.4f'
            % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou'], val_log['dice'], val_log['SE'],
               val_log['PC'], val_log['F1'], val_log['SP'], val_log['ACC']))

        # append this epoch's results to the log
        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['val_dice'].append(val_log['dice'])

        # write the log to a csv file after every epoch
        pd.DataFrame(log).to_csv('checkpoint/%s/log.csv' % config['name'], index=False)

        # trigger drives the early-stopping mechanism
        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'checkpoint/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0    # reset on a new best, so early stopping is not triggered

        # early stopping
        # if early_stopping >= 0 and trigger has been incremented past it, stop training
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        # empty the cache on the current CUDA device
        torch.cuda.empty_cache()
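
The trigger bookkeeping can be replayed in isolation: trigger is incremented every epoch and reset to 0 whenever validation IoU improves, and training stops once it reaches the early_stopping threshold. A minimal sketch:

```python
def early_stop_epoch(val_ious, early_stopping):
    # Replays the loop's bookkeeping: trigger += 1 each epoch, reset to 0 on
    # a new best IoU; stop once trigger reaches the threshold (if >= 0).
    best_iou, trigger = 0.0, 0
    for epoch, iou in enumerate(val_ious):
        trigger += 1
        if iou > best_iou:
            best_iou = iou
            trigger = 0            # a new best model was saved
        if early_stopping >= 0 and trigger >= early_stopping:
            return epoch           # early stopping fires here
    return len(val_ious) - 1       # ran to the end without stopping
```

With the default early_stopping=-1 the condition never fires and all epochs run.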

Note: scheduler.step() and scheduler.step(val_log['loss']) adjust the learning rate in different ways.
scheduler.step(): used by schedules that take no argument. It updates the optimizer's learning rate according to a predefined rule (such as cosine annealing), on a fixed timetable that is independent of how training is actually performing.
scheduler.step(val_log['loss']): used by schedules such as ReduceLROnPlateau that adjust the learning rate based on validation performance. At the end of each epoch the validation loss is passed to scheduler.step(), and the scheduler decides whether to lower the learning rate, for example when the validation loss has stopped improving.
The difference is that the second form adapts the learning rate dynamically to the model's convergence, while the first follows a predefined schedule.
Choose the form that matches the scheduling strategy in use.

1.4.1. solver.py - the training script

from collections import OrderedDict
import torch
from src.metrics import iou_score
from src.utils import AverageMeter


def train(train_loader, model, criterion, optimizer):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}
    model.train()

    for input, target, _ in train_loader:
        input = input.cuda()
        target = target.cuda()
        output = model(input)
        loss = criterion(output, target)
        iou, dice, _, _, _, _, _ = iou_score(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)
                        ])


def validate(val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter(),
                  'dice': AverageMeter(),
                  'SE': AverageMeter(),
                  'PC': AverageMeter(),
                  'F1': AverageMeter(),
                  'SP': AverageMeter(),
                  'ACC': AverageMeter()
                  }

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for input, target, _ in val_loader:
            input = input.cuda()
            target = target.cuda()
            output = model(input)
            loss = criterion(output, target)
            iou, dice, SE, PC, F1, SP, ACC = iou_score(output, target)
            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))
            avg_meters['dice'].update(dice, input.size(0))
            avg_meters['SE'].update(SE, input.size(0))
            avg_meters['PC'].update(PC, input.size(0))
            avg_meters['F1'].update(F1, input.size(0))
            avg_meters['SP'].update(SP, input.size(0))
            avg_meters['ACC'].update(ACC, input.size(0))

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg),
                        ('SE', avg_meters['SE'].avg),
                        ('PC', avg_meters['PC'].avg),
                        ('F1', avg_meters['F1'].avg),
                        ('SP', avg_meters['SP'].avg),
                        ('ACC', avg_meters['ACC'].avg)
                        ])

2. The model: CMUNet.py

class CMUNet(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, l=7, k=7):
        """
        Args:
            img_ch : input channel.
            output_ch: output channel.
            l: number of ConvMixer layers
            k: kernel size of ConvMixer
        """
        super(CMUNet, self).__init__()

        # Encoder
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
        self.ConvMixer = ConvMixerBlock(dim=1024, depth=l, k=k)
        # Decoder
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=512 * 2, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=256 * 2, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=128 * 2, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=64 * 2, ch_out=64)
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        # Skip-connection
        self.msag4 = MSAG(512)
        self.msag3 = MSAG(256)
        self.msag2 = MSAG(128)
        self.msag1 = MSAG(64)

    def forward(self, x):
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)
        x5 = self.ConvMixer(x5)

        x4 = self.msag4(x4)
        x3 = self.msag3(x3)
        x2 = self.msag2(x2)
        x1 = self.msag1(x1)

        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        return d1
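
With 256x256 inputs, the four max-pools halve the spatial size before each of Conv2..Conv5, so the ConvMixer block operates on 1024-channel 16x16 features. A quick sanity check of the encoder stage shapes:

```python
def cmunet_encoder_shapes(h=256, w=256):
    # (channels, height, width) after Conv1..Conv5; every stage after the
    # first is preceded by a 2x2 max-pool that halves the spatial size.
    channels = [64, 128, 256, 512, 1024]
    return [(c, h >> i, w >> i) for i, c in enumerate(channels)]

shapes = cmunet_encoder_shapes()  # x1 .. x5
```

The decoder then mirrors this: each up_conv doubles the spatial size and halves the channels, and each Up_conv takes 2x channels because of the concatenated skip connection.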

CMUNet model diagram:

(Figure 1: CMUNet architecture)

2.1. The Multi-Scale Attention Gate module

import torch.nn as nn
import torch

class MSAG(nn.Module):
    """
    Multi-scale attention gate
    """
    def __init__(self, channel):
        super(MSAG, self).__init__()
        self.channel = channel
        self.pointwiseConv = nn.Sequential(
            nn.Conv2d(self.channel, self.channel, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(self.channel),
        )
        self.ordinaryConv = nn.Sequential(
            nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(self.channel),
        )
        self.dilationConv = nn.Sequential(
            nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=2, stride=1, dilation=2, bias=True),
            nn.BatchNorm2d(self.channel),
        )
        self.voteConv = nn.Sequential(
            nn.Conv2d(self.channel * 3, self.channel, kernel_size=(1, 1)),
            nn.BatchNorm2d(self.channel),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.pointwiseConv(x)
        x2 = self.ordinaryConv(x)
        x3 = self.dilationConv(x)
        _x = self.relu(torch.cat((x1, x2, x3), dim=1))
        _x = self.voteConv(_x)
        x = x + x * _x
        return x
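
The last two lines of forward are the key design choice: the output is x + x * _x, i.e. x * (1 + attention), so the sigmoid gate can only amplify features on top of an identity skip, never suppress them to zero. A numpy sketch with hypothetical gate values:

```python
import numpy as np

def msag_gate(x, attention):
    # Matches `x = x + x * _x` in forward: equivalent to x * (1 + attention).
    return x + x * attention

x = np.array([2.0, 4.0])
attention = np.array([0.5, 0.0])   # hypothetical sigmoid outputs in [0, 1]
out = msag_gate(x, attention)      # features are amplified, never zeroed
```

Even where the gate outputs 0, the original feature passes through unchanged, which keeps gradients flowing through the skip connection.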

Multi-Scale Attention Gate module diagram:

(Figure 2: Multi-Scale Attention Gate)

2.2. The ConvMixerBlock module

class ConvMixerBlock(nn.Module):
    def __init__(self, dim=1024, depth=7, k=7):
        super(ConvMixerBlock, self).__init__()
        self.block = nn.Sequential(
            *[nn.Sequential(
                Residual(nn.Sequential(
                    # depthwise convolution
                    nn.Conv2d(dim, dim, kernel_size=(k, k), groups=dim, padding=(k // 2, k // 2)),
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )),
                nn.Conv2d(dim, dim, kernel_size=(1, 1)),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            ) for i in range(depth)]
        )

    def forward(self, x):
        x = self.block(x)
        return x
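
ConvMixerBlock wraps the depthwise convolution in a Residual helper that is not shown in this excerpt. In the ConvMixer paper's reference code it simply adds the input back to the wrapped function's output; a minimal stand-alone sketch (the repo's actual version subclasses nn.Module, but the forward logic is the same):

```python
class Residual:
    # Minimal stand-in for the Residual wrapper used above: apply the
    # wrapped function and add the original input back (a skip connection).
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x) + x
```

So each ConvMixer layer is: a residual depthwise (spatial-mixing) convolution followed by a pointwise (channel-mixing) 1x1 convolution, each with GELU and BatchNorm.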

ConvMixerBlock module diagram:

(Figure 3: ConvMixerBlock)
