本文介绍如何使用PyTorch实现最基本的语义分割模型FCN,重点讲解训练脚本和模型脚本这两个文件。
我安装的pytorch环境是1.4版本,使用的数据集是cityscapes数据集,下面看一下几个重要的代码块:
if __name__ == '__main__':
    # Script entry point: configure logging, parse the CLI options, build the
    # network, optionally restore weights, then hand everything to train_net().
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Pin the process to GPU 1 only when that GPU actually exists. Calling
    # set_device(1) unconditionally fails on CPU-only hosts and on machines
    # with a single GPU, which defeats the CPU fallback chosen above.
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)

    net = FCN(n_channels=3, n_classes=4, bilinear=True)
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        # map_location makes a checkpoint saved on another device loadable here.
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)

    train_net(net=net,
              epochs=args.epochs,
              batch_size=args.batchsize,
              lr=args.lr,
              device=device,
              img_scale=args.scale,
              # The CLI takes a percentage (0-100); train_net receives a fraction.
              val_percent=args.val / 100)
使用logging打印出代码运行的日志;
使用argparse模块获取需要用户输入的参数;
设置运行的设备,我此处使用了GPU,可通过set_device()设置GPU号;
net = FCN(n_channels=3, n_classes=4, bilinear=True)表示使用了FCN这个网络模型,n_channels表示输入的是3通道的图像,n_classes表示输出的类别是多少,bilinear=True表示在模型的上采样插值过程中使用了双线性插值;
args.load 表示是否要导入预训练的模型;
net.to(device=device) 将模型搬移到所选的设备(cuda或cpu)上执行运算;
train_net()用于启动训练,调用时传入epochs、batch_size、lr等超参数;
def get_args() -> argparse.Namespace:
    """Parse the command-line options of the FCN training script.

    Returns:
        The parsed namespace with the attributes ``epochs``, ``batchsize``,
        ``lr``, ``load``, ``scale`` and ``val``.
    """
    parser = argparse.ArgumentParser(description='Train the FCN ',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    # nargs='?' allows the flag to appear without a value. `const` makes that
    # case fall back to the default instead of silently producing None.
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1, const=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1, const=0.1,
                        help='Learning rate', dest='lr')
    # The default stays False (falsy) so `if args.load:` keeps working as before.
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=1,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    return parser.parse_args()
get_args()是用于定义并解析命令行参数的函数;
import setproctitle
# Give the running process the title "xxx" so it can be told apart from other jobs.
setproctitle.setproctitle("xxx")
可通过安装setproctitle这个工具来设置代码运行时候的任务名称,以免误删任务。
# Build the full dataset, then reserve `val_percent` of it for validation.
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_samples = len(dataset)
n_val = int(n_samples * val_percent)
n_train = n_samples - n_val
train, val = random_split(dataset, [n_train, n_val])

# Settings shared by both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train, shuffle=True, **loader_kwargs)
# Validation keeps a fixed order and drops a trailing incomplete batch.
val_loader = DataLoader(val, shuffle=False, drop_last=True, **loader_kwargs)
对数据进行训练集和验证集的划分,以及利用dataloader导入数据;
# RMSprop with momentum and a very small weight decay updates the parameters.
optimizer = optim.RMSprop(
    net.parameters(),
    lr=lr,
    weight_decay=1e-8,
    momentum=0.9,
)

# 'min' lowers the learning rate once the monitored value stops decreasing,
# 'max' once it stops increasing; the choice depends on the number of classes.
if net.n_classes > 1:
    plateau_mode = 'min'
else:
    plateau_mode = 'max'
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, plateau_mode, patience=2)
设置训练网络的优化器(RMSprop)以及学习率调度器(ReduceLROnPlateau);
# Training loss: multi-class cross-entropy.
criterion = nn.CrossEntropyLoss()
设置loss函数,在语义分割中,最常用的就是交叉熵损失函数;
# Clear the gradients left over from the previous optimisation step.
optimizer.zero_grad()
# Back-propagate the loss to populate the parameter gradients.
loss.backward()
# Clamp every gradient element to the range [-0.1, 0.1] before the update.
nn.utils.clip_grad_value_(net.parameters(), 0.1)
# Apply the parameter update using the (clipped) gradients.
optimizer.step()
梯度清零、loss反向传播、梯度裁剪以及参数更新的过程;
if save:
    # Create the checkpoint directory on first use. Unlike a blanket
    # `except OSError: pass`, real failures (e.g. permission denied) surface
    # here instead of as a confusing error from torch.save() below.
    if not os.path.isdir(dir_checkpoint):
        os.makedirs(dir_checkpoint, exist_ok=True)
        logging.info('Created checkpoint directory')
    # os.path.join works whether or not dir_checkpoint ends with a separator;
    # plain string concatenation only produced a valid path when it did.
    torch.save(net.state_dict(),
               os.path.join(dir_checkpoint, f'epoch{epoch + 1}.pth'))
    logging.info(f'Checkpoint {epoch + 1} saved !')
模型保存的过程。
class FCN(nn.Module):
    """Small VGG-style fully convolutional network for semantic segmentation.

    Six stages each halve the spatial resolution and end in a coarse
    class-score map. That map is up-sampled step by step; at every step it is
    summed with class scores predicted from the encoder stage that has the
    same resolution, and a final up-sampling restores the input size.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels (one score map per class).
        bilinear: stored on the instance for callers to inspect. The forward
            pass always up-samples with bilinear interpolation, whatever
            value this flag has.

    Note:
        Each spatial side of the input must be at least 64 pixels so that the
        six 2x poolings leave a non-empty feature map.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        width_1, width_2, width_3, width_4 = 16, 32, 64, 128

        # Encoder stages. Layer order inside every Sequential is unchanged,
        # so state-dict keys stay compatible with earlier checkpoints.
        self.conv1 = self._stage(n_channels, width_1, n_convs=2)
        self.conv2 = self._stage(width_1, width_2, n_convs=2)
        self.conv3 = self._stage(width_2, width_2, n_convs=3)
        self.conv4 = self._stage(width_2, width_3, n_convs=3)
        self.conv5 = self._stage(width_3, width_3, n_convs=3)

        # Last stage: one more 2x pooling followed by the classifier head
        # (7x7 conv, 1x1 conv, then a 1x1 conv producing the class scores).
        self.conv6 = nn.Sequential(
            nn.Conv2d(width_3, width_4, 3, padding=1, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False),
            nn.Conv2d(width_4, width_4, 7, padding=3, bias=False),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.Conv2d(width_4, width_4, 1, padding=0, bias=False),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.Conv2d(width_4, n_classes, 1, padding=0, bias=False),
        )

        # Score layers that turn each encoder feature map into class scores
        # before it is fused with the up-sampled prediction.
        self.conv7 = nn.Conv2d(width_3, n_classes, 3, padding=1, bias=False)
        self.conv8 = nn.Conv2d(width_3, n_classes, 3, padding=1, bias=False)
        self.conv9 = nn.Conv2d(width_2, n_classes, 3, padding=1, bias=False)
        self.conv10 = nn.Conv2d(width_2, n_classes, 3, padding=1, bias=False)
        self.conv11 = nn.Conv2d(width_1, n_classes, 3, padding=1, bias=False)

    @staticmethod
    def _stage(channels_in, channels_out, n_convs):
        """Return ``n_convs`` 3x3 conv + ReLU pairs followed by a 2x max-pool."""
        layers = []
        for index in range(n_convs):
            source = channels_in if index == 0 else channels_out
            layers.append(nn.Conv2d(source, channels_out, 3, padding=1, bias=False))
            layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False))
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_like(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        return interpolate(tensor, size=reference.shape[2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return per-pixel class scores with the same height/width as ``x``."""
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        fused = self.conv6(x5)

        # Resize to the skip tensor's real size rather than by a fixed factor
        # of 2: max-pooling floors odd sizes, so doubling would not match the
        # skip tensor whenever an input side is not a multiple of 64. For
        # multiples of 64 the target sizes are exactly the doubled ones.
        skip_paths = (
            (self.conv7, x5),
            (self.conv8, x4),
            (self.conv9, x3),
            (self.conv10, x2),
            (self.conv11, x1),
        )
        for score_layer, features in skip_paths:
            fused = self._upsample_like(fused, features) + score_layer(features)

        # Final step back to the resolution of the input image.
        return self._upsample_like(fused, x)
这里对原版的全卷积神经网络FCN做了简单的优化,增加了更多的特征融合过程,从而将低层特征和高层特征进行了更加有效的融合。