Implementing the FCN Semantic Segmentation Model in PyTorch

This post walks through a basic implementation of FCN, the simplest semantic segmentation model, in PyTorch, focusing on the training script and the model definition.
My environment is PyTorch 1.4 and the dataset is Cityscapes. Let's look at a few of the important code blocks:

import argparse
import logging

import torch

# FCN, get_args and train_net are defined elsewhere in the same script (see below).

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    torch.cuda.set_device(1)  # pick which GPU to run on (only meaningful when CUDA is available)
    
    net = FCN(n_channels=3, n_classes=4, bilinear=True)
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    train_net(net=net,
              epochs=args.epochs,
              batch_size=args.batchsize,
              lr=args.lr,
              device=device,
              img_scale=args.scale,
              val_percent=args.val / 100)

logging is used to print log messages while the code runs;
the argparse module collects the parameters the user needs to supply;
the running device is set here; I use a GPU, and the GPU index can be chosen with torch.cuda.set_device();
net = FCN(n_channels=3, n_classes=4, bilinear=True) instantiates the FCN network: n_channels means the input images have 3 channels, n_classes is the number of output classes, and bilinear=True means bilinear interpolation is used in the model's upsampling steps;
args.load indicates whether a pretrained model should be loaded;
net.to(device=device) moves the model onto the CUDA device for computation;
train_net() is called with the training hyperparameters.

def get_args():
    parser = argparse.ArgumentParser(description='Train the FCN ',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=1,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()

This is the function that defines the command-line arguments;
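
With this parser, the script can be launched from the command line, for example as below. The file name train.py and the checkpoints/ directory are assumptions on my part; substitute your own paths:

python train.py -e 20 -b 2 -l 0.01 -s 0.5 -v 10
python train.py -e 20 -b 2 -f checkpoints/epoch5.pth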

import setproctitle

setproctitle.setproctitle("xxx")

You can install the setproctitle package to set the name under which the job appears while it is running, so the task is not killed by mistake.

dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

Here the data is split into a training set and a validation set, and DataLoaders are built to feed the data;
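
BasicDataset itself is not shown above. Below is a minimal sketch of what such a dataset class could look like, assuming images and masks sit in two directories with matching file names, masks store one class index per pixel, and both are PNG files; the layout and preprocessing here are my assumptions, not the original implementation.

import os
from glob import glob

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class BasicDataset(Dataset):
    """Minimal image/mask dataset: one mask file per image, same base name."""

    def __init__(self, imgs_dir, masks_dir, scale=1.0):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.ids = sorted(os.path.splitext(os.path.basename(p))[0]
                          for p in glob(os.path.join(imgs_dir, '*')))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        name = self.ids[i]
        img = Image.open(os.path.join(self.imgs_dir, name + '.png')).convert('RGB')
        mask = Image.open(os.path.join(self.masks_dir, name + '.png'))

        if self.scale != 1.0:
            w, h = img.size
            new_size = (int(w * self.scale), int(h * self.scale))
            img = img.resize(new_size, Image.BILINEAR)
            mask = mask.resize(new_size, Image.NEAREST)  # nearest keeps class ids intact

        img = np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0  # HWC -> CHW, [0, 1]
        mask = np.asarray(mask, dtype=np.int64)                             # class index per pixel

        return {'image': torch.from_numpy(img),
                'mask': torch.from_numpy(mask)}

The 'image'/'mask' dictionary keys returned here are reused in the training-loop sketch later in this post.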

optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)

Set up the optimizer and the learning-rate scheduler for training; ReduceLROnPlateau lowers the learning rate when the metric passed to scheduler.step() stops improving.
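
ReduceLROnPlateau only acts when scheduler.step() is fed the monitored metric (typically the validation loss) after each validation round. Here is a small self-contained demo of that behaviour, using a dummy parameter instead of the real network:

import torch
from torch import optim

# Tiny standalone demo: the LR drops by 10x once the monitored metric stops improving
# for more than `patience` validation rounds.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.RMSprop([param], lr=0.1, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:   # the loss stops improving after the second round
    scheduler.step(val_loss)                  # pass the monitored metric after each validation
    print(optimizer.param_groups[0]['lr'])    # 0.1, 0.1, 0.1, 0.1, 0.01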

criterion = nn.CrossEntropyLoss()

Define the loss function; in semantic segmentation, the most commonly used loss is cross-entropy;
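
For a multi-class segmentation model, nn.CrossEntropyLoss expects raw logits of shape (N, n_classes, H, W) and an integer label map of shape (N, H, W). A quick self-contained check with random tensors (the sizes are arbitrary):

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

# Fake batch: 2 images, 4 classes, 64x64 pixels.
logits = torch.randn(2, 4, 64, 64)          # raw network output, no softmax needed
target = torch.randint(0, 4, (2, 64, 64))   # one class index per pixel, dtype int64

loss = criterion(logits, target)
print(loss.item())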

optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(net.parameters(), 0.1)
optimizer.step()

Zero the gradients, backpropagate the loss, clip the gradients, and take an optimizer step;
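
Put together, one pass over a training batch looks roughly like the sketch below. This assumes each batch from train_loader is a dict with 'image' and 'mask' entries (as in the BasicDataset sketch above) and that net, criterion and optimizer are the objects created earlier; it is a sketch, not the original training loop.

net.train()
for batch in train_loader:
    imgs = batch['image'].to(device=device, dtype=torch.float32)
    true_masks = batch['mask'].to(device=device, dtype=torch.long)

    masks_pred = net(imgs)                    # (N, n_classes, H, W) logits
    loss = criterion(masks_pred, true_masks)  # cross-entropy against (N, H, W) labels

    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(net.parameters(), 0.1)  # clip gradients to stabilise training
    optimizer.step()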

if save:
    try:
        os.mkdir(dir_checkpoint)
        logging.info('Created checkpoint directory')
    except OSError:
        pass
    torch.save(net.state_dict(),
               dir_checkpoint + f'epoch{epoch + 1}.pth')
    logging.info(f'Checkpoint {epoch + 1} saved !')

This is the checkpoint-saving step: the model weights are written to a .pth file.

from torch import nn
from torch.nn.functional import interpolate


class FCN(nn.Module):

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(FCN, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Channel widths of the encoder stages.
        in_channels = 16
        in_channels_2 = 32
        in_channels_3 = 64
        in_channels_4 = 128

        # Encoder: conv1-conv6 each end in a 2x2 max-pool, so every stage halves the
        # spatial resolution (a 64x reduction in total).
        self.conv1 = nn.Sequential(nn.Conv2d(n_channels, in_channels, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False))
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels, in_channels_2, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_2, in_channels_2, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False))                               
        self.conv3 = nn.Sequential(nn.Conv2d(in_channels_2, in_channels_2, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_2, in_channels_2, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_2, in_channels_2, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False))
        self.conv4 = nn.Sequential(nn.Conv2d(in_channels_2, in_channels_3, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_3, in_channels_3, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_3, in_channels_3, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False)) 

        self.conv5 = nn.Sequential(nn.Conv2d(in_channels_3, in_channels_3, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_3, in_channels_3, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.Conv2d(in_channels_3, in_channels_3, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False))                                                               
        self.conv6 = nn.Sequential(nn.Conv2d(in_channels_3, in_channels_4, 3, padding=1, bias=False),
                                       nn.ReLU(),
                                       nn.MaxPool2d(2, stride=2, padding=0, return_indices=False, ceil_mode=False),
                                       nn.Conv2d(in_channels_4, in_channels_4, 7, padding=3, bias=False),
                                       nn.ReLU(),
                                       nn.Dropout2d(0.5),
                                       nn.Conv2d(in_channels_4, in_channels_4, 1, padding=0, bias=False),
                                       nn.ReLU(),
                                       nn.Dropout2d(0.5),
                                       nn.Conv2d(in_channels_4, n_classes, 1, padding=0, bias=False))  
        # Score layers: 3x3 convs that map the encoder features at each scale down to
        # n_classes channels so they can be added to the upsampled predictions in forward().
        self.conv7 = nn.Conv2d(in_channels_3, n_classes, 3, padding=1, bias=False)
        self.conv8 = nn.Conv2d(in_channels_3, n_classes, 3, padding=1, bias=False)
        self.conv9 = nn.Conv2d(in_channels_2, n_classes, 3, padding=1, bias=False)
        self.conv10 = nn.Conv2d(in_channels_2, n_classes, 3, padding=1, bias=False)
        self.conv11 = nn.Conv2d(in_channels, n_classes, 3, padding=1, bias=False)

                                                                                                      
        
    def forward(self, x):
        # Encoder pass: each stage halves the spatial resolution.
        x1 = self.conv1(x)    # 1/2
        x2 = self.conv2(x1)   # 1/4
        x3 = self.conv3(x2)   # 1/8
        x4 = self.conv4(x3)   # 1/16
        x5 = self.conv5(x4)   # 1/32
        x6 = self.conv6(x5)   # 1/64, already reduced to n_classes channels

        # Decoder pass: repeatedly upsample by 2 and add the scored features of the
        # matching encoder stage (skip connections).
        x7 = interpolate(x6, scale_factor=2, mode='bilinear', align_corners=True)  # 1/32
        x7 = x7 + self.conv7(x5)

        x7 = interpolate(x7, scale_factor=2, mode='bilinear', align_corners=True)  # 1/16
        x7 = x7 + self.conv8(x4)

        x7 = interpolate(x7, scale_factor=2, mode='bilinear', align_corners=True)  # 1/8
        x7 = x7 + self.conv9(x3)

        x7 = interpolate(x7, scale_factor=2, mode='bilinear', align_corners=True)  # 1/4
        x7 = x7 + self.conv10(x2)

        x7 = interpolate(x7, scale_factor=2, mode='bilinear', align_corners=True)  # 1/2
        x7 = x7 + self.conv11(x1)

        x7 = interpolate(x7, scale_factor=2, mode='bilinear', align_corners=True)  # back to input size
        return x7
       

Compared with the original FCN, this version makes a simple modification: it adds more feature-fusion steps, merging the upsampled predictions with features from every encoder stage, so that low-level and high-level features are combined more effectively.
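
As a quick sanity check of the architecture, you can push a dummy tensor through the network; the input height and width must both be divisible by 64, since the encoder downsamples six times. This snippet is only a shape test I added, not part of the original training code:

import torch

net = FCN(n_channels=3, n_classes=4, bilinear=True)
x = torch.randn(1, 3, 256, 512)   # dummy crop with sides divisible by 64
with torch.no_grad():
    y = net(x)
print(y.shape)                    # expected: torch.Size([1, 4, 256, 512])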
