数据集:refuge数据集
训练轮数:10
| Architecture | dice coefficient | mean IoU |
| --- | --- | --- |
| unet | 0.946 | 52.6 |
| sk-unet | 0.989 | 66.1 |
| cbam-unet | 0.988 | 65.8 |
(1)在 UNet 最后的输出卷积前添加 SK 模块
训练结果:
[epoch: 9]
train_loss: 0.0710
lr: 0.009592
dice coefficient: 0.989
global correct: 97.8
average row correct: ['47.3', '99.1']
IoU: ['34.4', '97.8']
mean IoU: 66.1
模型改动:
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
# sk-unet
'''------------- SK模块-----------------------------'''
class SKConv(nn.Module):
    """Selective-kernel convolution: several conv branches fused by attention.

    Every branch applies a convolution with a different kernel size to the
    same input.  The branch outputs are summed, globally average-pooled,
    squeezed through a fully connected layer, and turned into one softmax
    weight per branch and channel.  The output is the attention-weighted sum
    of the branch outputs.
    """

    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        """Build the branches and the attention layers.

        Args:
            features: number of input channels (also the output channels).
            WH: nominal spatial size of the input feature map.  Kept only for
                backward compatibility; pooling is adaptive, so the module
                accepts any spatial size.
            M: number of branches.
            G: number of convolution groups.
            r: reduction ratio used to compute d = features / r.
            stride: convolution stride, 1 by default.
            L: lower bound for d, 32 by default.
        """
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        # Branch i uses kernel size 3 + 2*i with padding 1 + i, so every
        # branch has a different receptive field but the same output size.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride,
                          padding=1 + i, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False),
            )
            for i in range(M)
        ])
        # Global average pooling.  Adaptive pooling always yields a 1x1 map,
        # whereas a fixed AvgPool2d(WH / stride) window only does so when the
        # input has exactly the size assumed at construction time.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(features, d)
        # One fully connected layer per branch maps z back to `features`.
        self.fcs = nn.ModuleList([nn.Linear(d, features) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the attention-weighted fusion of all branch outputs.

        Args:
            x: batched input feature map, expected as (B, features, H, W).

        Returns:
            Tensor with the same layout as a single branch output.
        """
        # Split: stack the branch outputs on a new axis 1 -> (B, M, C, H, W).
        feas = torch.stack([conv(x) for conv in self.convs], dim=1)

        # Fuse: sum the branches, then pool to one descriptor per channel.
        fea_U = feas.sum(dim=1)
        # flatten(1) drops only the two trailing 1x1 axes.  A bare squeeze()
        # would also drop the batch axis when B == 1 and break the broadcast
        # further down.
        fea_s = self.gap(fea_U).flatten(1)
        fea_z = self.fc(fea_s)

        # Select: per-branch logits stacked on axis 1 -> (B, M, C).
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)
        # Softmax over the branch axis, so the weights of each channel sum to 1.
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)

        # Weight every branch by its attention and sum over the branch axis.
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v
# Convolution block; in U-Net, convolutions are generally used in pairs.
class DoubleConv(nn.Sequential):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; falls back
            to ``out_channels`` when None.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        hidden = out_channels if mid_channels is None else mid_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        super(DoubleConv, self).__init__(*stages)
# Downsampling stage
class Down(nn.Sequential):
    """Max-pool with a 2x2 window and stride 2, then a DoubleConv.

    Args:
        in_channels: channels entering the stage.
        out_channels: channels leaving the stage.
    """

    def __init__(self, in_channels, out_channels):
        # 1. max pooling, window 2 and stride 2
        pool = nn.MaxPool2d(2, stride=2)
        # 2. the pair of convolutions
        convs = DoubleConv(in_channels, out_channels)
        super(Down, self).__init__(pool, convs)
# Upsampling stage
class Up(nn.Module):
    """Upsample, pad to the skip connection's size, concatenate, then DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to the convs.
        out_channels: channels produced by the stage.
        bilinear: use bilinear interpolation when truthy, otherwise a
            transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        half = in_channels // 2
        if not bilinear:
            # Transposed convolution: doubles the size and halves the channels.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear interpolation with an upsampling factor of 2.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1`` and merge it with ``x2``.

        Both tensors are indexed as [N, C, H, W].
        """
        upsampled = self.up(x1)
        # Height / width differences between the skip tensor and the
        # upsampled tensor.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        top = gap_h // 2
        left = gap_w // 2
        # 1. pad the difference; F.pad order is left, right, top, bottom
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # 2. concatenate along the channel axis, skip tensor first
        merged = torch.cat([x2, upsampled], dim=1)
        # 3. the pair of convolutions
        return self.conv(merged)
# Final 1x1 output convolution
class OutConv(nn.Sequential):
    """A single 1x1 convolution mapping ``in_channels`` to ``num_classes``."""

    def __init__(self, in_channels, num_classes):
        projection = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        super(OutConv, self).__init__(projection)
class UNet(nn.Module):
    """U-Net with an SKConv block inserted before the final 1x1 output conv.

    Args:
        in_channels: channels of the input image.
        num_classes: number of output channels of the final convolution.
        bilinear: whether the Up stages use bilinear interpolation.
        base_c: channel count of the first convolution block.
    """

    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        # 2 with bilinear upsampling, 1 with transposed convolutions.
        factor = 2 if bilinear else 1
        c1, c2, c4, c8, c16 = [base_c * mult for mult in (1, 2, 4, 8, 16)]

        self.in_conv = DoubleConv(in_channels, c1)
        # Contracting path: (input channels, output channels).
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        # The deepest stage outputs c16 // factor channels.
        self.down4 = Down(c8, c16 // factor)
        # Expanding path: (input channels, output channels).
        self.up1 = Up(c16, c8 // factor, bilinear)
        self.up2 = Up(c8, c4 // factor, bilinear)
        self.up3 = Up(c4, c2 // factor, bilinear)
        self.up4 = Up(c2, c1, bilinear)
        # Final 1x1 output convolution.
        self.out_conv = OutConv(c1, num_classes)
        # SK block; 480 is passed as SKConv's WH (spatial size) argument.
        self.sk = SKConv(c1, 480, 2, 1, 2)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the network and return the logits under the key ``"out"``."""
        # 1. initial pair of convolutions
        skips = [self.in_conv(x)]
        # 2. contracting path, keeping every intermediate map
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        # 3. expanding path, consuming the stored maps deepest-first
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())
        # SK block
        x = self.sk(x)
        # 4. final 1x1 output convolution
        return {"out": self.out_conv(x)}
(2)在 UNet 最后的输出卷积后添加 CBAM 模块
训练结果:
[epoch: 9] train_loss: 0.2040 lr: 0.000000 dice coefficient: 0.988 global correct: 97.7 average row correct: ['48.6', '98.9'] IoU: ['33.8', '97.7'] mean IoU: 65.8
模型改动:
import os
import time
import datetime
import torch
from src import UNet
from train_utils import train_one_epoch, evaluate, create_lr_scheduler
from my_dataset import DriveDataset
import transforms as T
class SegmentationPresetTrain:
    """Training-time transform pipeline built from the project's ``T`` module.

    The pipeline is: random resize between 0.5x and 1.2x of ``base_size``,
    optional horizontal / vertical flips, random crop to ``crop_size``,
    tensor conversion and normalisation.

    Args:
        base_size: reference size the resize bounds are derived from.
        crop_size: size passed to ``T.RandomCrop``.
        hflip_prob: horizontal-flip probability; the flip is skipped when <= 0.
        vflip_prob: vertical-flip probability; the flip is skipped when <= 0.
        mean: per-channel mean passed to ``T.Normalize``.
        std: per-channel std passed to ``T.Normalize``.
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        lower = int(0.5 * base_size)
        upper = int(1.2 * base_size)

        pipeline = [T.RandomResize(lower, upper)]
        if hflip_prob > 0:
            pipeline.append(T.RandomHorizontalFlip(hflip_prob))
        if vflip_prob > 0:
            pipeline.append(T.RandomVerticalFlip(vflip_prob))
        pipeline.append(T.RandomCrop(crop_size))
        pipeline.append(T.ToTensor())
        pipeline.append(T.Normalize(mean=mean, std=std))

        self.transforms = T.Compose(pipeline)

    def __call__(self, img, target):
        """Apply the composed transforms to ``img`` and ``target`` together."""
        return self.transforms(img, target)
class SegmentationPresetEval:
    """Evaluation-time transforms: tensor conversion followed by normalisation.

    Args:
        mean: per-channel mean passed to ``T.Normalize``.
        std: per-channel std passed to ``T.Normalize``.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        to_tensor = T.ToTensor()
        normalize = T.Normalize(mean=mean, std=std)
        self.transforms = T.Compose([to_tensor, normalize])

    def __call__(self, img, target):
        """Apply the composed transforms to ``img`` and ``target`` together."""
        return self.transforms(img, target)
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Return the training or evaluation transform preset.

    Args:
        train: truthy selects the training preset (base size 565, crop 480),
            falsy selects the evaluation preset.
        mean: per-channel mean forwarded to the preset.
        std: per-channel std forwarded to the preset.
    """
    if not train:
        return SegmentationPresetEval(mean=mean, std=std)

    base_size, crop_size = 565, 480
    return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
# Build the model from the given arguments.
def create_model(num_classes):
    """Return a UNet for 3-channel input with ``base_c=32``.

    Args:
        num_classes: number of output classes passed through to UNet.
    """
    return UNet(in_channels=3, num_classes=num_classes, base_c=32)
def main(args):
    """Train and validate the model, logging metrics and saving checkpoints.

    Args:
        args: namespace providing ``device``, ``batch_size``, ``num_classes``,
            ``data_path``, ``lr``, ``momentum``, ``weight_decay``, ``amp``,
            ``epochs``, ``resume``, ``start_epoch``, ``print_freq`` and
            ``save_best``.

    Side effects:
        Appends per-epoch metrics to a timestamped ``results*.txt`` file in
        the working directory and writes checkpoints under ``save_weights/``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    # segmentation num_classes + background
    num_classes = args.num_classes + 1

    # using compute_mean_std.py
    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)

    # File that collects training and validation information.
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))

    val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so that min() never compares None with an int.
    num_workers = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=num_classes)
    model.to(device)

    params_to_optimize = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(
        params_to_optimize,
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # Learning-rate schedule; it is updated once per step, not once per epoch.
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)

    if args.resume:
        # NOTE(review): torch.load unpickles the file, so only resume from
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    # Make sure the checkpoint directory exists even when main() is called
    # without going through the script entry point.
    os.makedirs("save_weights", exist_ok=True)

    # NOTE(review): best_dice restarts from 0 after a resume, so the first
    # resumed epoch overwrites best_model.pth regardless of earlier scores.
    best_dice = 0.
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch, num_classes,
                                        lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)

        confmat, dice = evaluate(model, val_loader, device=device, num_classes=num_classes)
        val_info = str(confmat)
        print(val_info)
        print(f"dice coefficient: {dice:.3f}")
        # write into txt
        with open(results_file, "a") as f:
            # Record train_loss, lr and the validation metrics of this epoch.
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n" \
                         f"dice coefficient: {dice:.3f}\n"
            f.write(train_info + val_info + "\n\n")

        if args.save_best is True:
            if best_dice < dice:
                best_dice = dice
            else:
                # Not an improvement: skip saving for this epoch.
                continue

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()

        if args.save_best is True:
            torch.save(save_file, "save_weights/best_model.pth")
        else:
            torch.save(save_file, "save_weights/model_{}.pth".format(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``data_path``, ``num_classes``, ``device``,
        ``batch_size``, ``epochs``, ``lr``, ``momentum``, ``weight_decay``,
        ``print_freq``, ``resume``, ``start_epoch``, ``save_best`` and ``amp``.
    """
    import argparse

    def str2bool(value):
        """Convert a command-line string to a bool.

        ``type=bool`` would turn every non-empty string, including "False",
        into True, so the text is interpreted explicitly instead.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1", "on"):
            return True
        if text in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser(description="pytorch unet training")

    parser.add_argument("--data-path", default="./", help="DRIVE2 root")
    # exclude background
    parser.add_argument("--num-classes", default=1, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=10, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=1, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--save-best', default=True, type=str2bool,
                        help='only save best dice weights')
    # Mixed precision training parameters
    parser.add_argument("--amp", default=False, type=str2bool,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()

    return args
if __name__ == '__main__':
    # Script entry point: parse the options, make sure the checkpoint
    # directory exists, then start training.
    args = parse_args()

    weights_dir = "./save_weights"
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)

    main(args)