[Image Segmentation] Implementing SegNet in PyTorch

Paper: https://arxiv.org/pdf/1511.00561.pdf
Official code: https://github.com/alexgkendall/caffe-segnet
The full code from this post is on GitHub: https://github.com/chen-zhoujian/SegNet-pytorch

Encoder

The first 13 convolutional layers of VGG-16
The BN layers use momentum=0.1
The locations of the max values are saved during downsampling
PyTorch code:

import torch
import torch.nn as nn
import torch.nn.functional as F

bn_momentum = 0.1  # momentum for every BatchNorm layer, as noted above


class Encoder(nn.Module):
    def __init__(self, input_channels):
        super(Encoder, self).__init__()

        self.enco1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=bn_momentum),
            nn.ReLU()
        )
        self.enco2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=bn_momentum),
            nn.ReLU()
        )
        self.enco3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.ReLU()
        )
        self.enco4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU()
        )
        self.enco5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU()
        )

    def forward(self, x):
        ids = []  # pooling indices, saved for the decoder's unpooling

        x = self.enco1(x)
        x, id1 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)  # save the max-value locations
        ids.append(id1)
        x = self.enco2(x)
        x, id2 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        ids.append(id2)
        x = self.enco3(x)
        x, id3 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        ids.append(id3)
        x = self.enco4(x)
        x, id4 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        ids.append(id4)
        x = self.enco5(x)
        x, id5 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        ids.append(id5)

        return x, ids
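
A quick shape sanity check for the encoder (five 2x2 poolings shrink 224 down to 7):

encoder = Encoder(input_channels=3)
x = torch.randn(1, 3, 224, 224)
features, ids = encoder(x)
print(features.shape)  # torch.Size([1, 512, 7, 7])
print(len(ids))        # 5 pooling-index tensors, one per stage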

Encoder + Decoder (SegNet)

The decoder also has 13 layers, mirroring the encoder's 13
During upsampling, the pooling indices saved by the encoder during downsampling are fed back in
The paper puts a classification (softmax) layer at the end of SegNet; since the loss function here is nn.CrossEntropyLoss(), which handles that internally, it is omitted (see the sketch below)
(for details see: https://blog.csdn.net/zziahgf/article/details/80196376)
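
As a quick aside, a minimal sketch (with made-up tensor shapes) of why the explicit softmax can be dropped: nn.CrossEntropyLoss on raw logits equals log_softmax followed by nn.NLLLoss.

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(2, 3, 4, 4)           # (N, C, H, W): raw network output
target = torch.randint(0, 3, (2, 4, 4))    # (N, H, W): ground-truth class indices

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
print(torch.allclose(ce, nll))  # True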
PyTorch code:

class SegNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(SegNet, self).__init__()

        self.weights_new = self.state_dict()  # empty at this point; it is filled by load_weights() below
        self.encoder = Encoder(input_channels)

        self.deco1 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU()
        )
        self.deco2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.ReLU()
        )
        self.deco3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=bn_momentum),
            nn.ReLU()
        )
        self.deco4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=bn_momentum),
            nn.ReLU()
        )
        self.deco5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=bn_momentum),
            nn.ReLU(),
            nn.Conv2d(64, output_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        x, ids = self.encoder(x)

        x = F.max_unpool2d(x, ids[4], kernel_size=2, stride=2)
        x = self.deco1(x)
        x = F.max_unpool2d(x, ids[3], kernel_size=2, stride=2)
        x = self.deco2(x)
        x = F.max_unpool2d(x, ids[2], kernel_size=2, stride=2)
        x = self.deco3(x)
        x = F.max_unpool2d(x, ids[1], kernel_size=2, stride=2)
        x = self.deco4(x)
        x = F.max_unpool2d(x, ids[0], kernel_size=2, stride=2)
        x = self.deco5(x)

        return x
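
To make the unpooling mechanism concrete, a tiny standalone example: the pooled maximum is written back to the exact position it came from, and everything else is zero-filled.

import torch
import torch.nn.functional as F

x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])
pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
restored = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2)
print(pooled)    # tensor([[[[4.]]]])
print(restored)  # tensor([[[[0., 0.], [0., 4.]]]])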

Loading the pretrained weights

The encoder uses VGG-16's weights; the decoder does not
Since the encoder does not include VGG-16's last 3 fully connected layers, the weight dict needs some processing first
Pretrained weights download: https://download.pytorch.org/models/vgg16_bn-6c64b313.pth
(this load_weights function is a member of the SegNet class)
PyTorch code:

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        # drop the three fully connected layers, which have no counterpart in the encoder
        del weights["classifier.0.weight"]
        del weights["classifier.0.bias"]
        del weights["classifier.3.weight"]
        del weights["classifier.3.bias"]
        del weights["classifier.6.weight"]
        del weights["classifier.6.bias"]

        # collect the encoder's parameter names, skipping BN's num_batches_tracked
        # counters, which this VGG checkpoint does not contain
        names = []
        for key, value in self.encoder.state_dict().items():
            if "num_batches_tracked" in key:
                continue
            names.append(key)

        # the remaining VGG "features" entries come in the same order as the
        # encoder's conv/BN parameters, so a positional zip lines them up
        for name, (key, value) in zip(names, weights.items()):
            self.weights_new[name] = value

        self.encoder.load_state_dict(self.weights_new)
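
A short usage sketch (assuming the checkpoint linked above has been downloaded into the working directory):

model = SegNet(input_channels=3, output_channels=2)
model.load_weights("vgg16_bn-6c64b313.pth")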

The above covers building the SegNet network and loading the encoder's pretrained weights.

Building the MyDataset class

Opens the txt file that lists the paths of the training images and labels
(creating this txt file is covered further below)
Preprocesses the training images and labels
Reference blog: https://blog.csdn.net/sinat_42239797/article/details/90641659
PyTorch code:

import cv2 as cv
import torch
import torch.utils.data as Data


class MyDataset(Data.Dataset):
    def __init__(self, txt_path):
        super(MyDataset, self).__init__()

        paths = open(txt_path, "r")

        image_label = []
        for line in paths:
            path = line.strip().split()  # strip the newline, then split on the space
            image_label.append((path[0], path[1]))
        paths.close()

        self.image_label = image_label

    def __getitem__(self, item):
        image, label = self.image_label[item]

        image = cv.imread(image)
        image = cv.resize(image, (224, 224))
        image = image / 255.0  # normalize the input to [0, 1]
        image = torch.Tensor(image)
        image = image.permute(2, 0, 1)  # HWC -> CHW, the layout the network expects

        label = cv.imread(label, 0)
        # nearest-neighbor keeps label ids intact; bilinear would blend class ids
        label = cv.resize(label, (224, 224), interpolation=cv.INTER_NEAREST)
        label = torch.Tensor(label)

        return image, label

    def __len__(self):
        return len(self.image_label)
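
A quick sanity check of the dataset and its batch shapes (assumes a valid train.txt, created in the next section):

train_data = MyDataset(txt_path="train.txt")
loader = Data.DataLoader(train_data, batch_size=2, shuffle=True)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([2, 3, 224, 224])
print(labels.shape)  # torch.Size([2, 224, 224])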

Preparing the dataset

This post uses the Cityscapes dataset; for the download and an introduction, see: https://blog.csdn.net/avideointerfaces/article/details/104139298
Preprocessing scripts for the dataset: https://github.com/mcordts/cityscapesScripts
On processing the Cityscapes labels, see my post: https://blog.csdn.net/chenzhoujian_/article/details/106874950
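
To verify the converted labels, you can inspect which trainIds a label image actually contains (the path below is illustrative):

import cv2 as cv
import numpy as np

label = cv.imread("cityscapes/gtFine/train/aachen/aachen_000000_000019_gtFine_labelTrainIds.png", 0)
print(np.unique(label))  # the trainIds present in the image (255 marks ignored pixels)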

Creating the txt files

Writes the paths of the training images and their labels into a txt file
Mind the dataset's file paths and the sizes of the training, validation, and test splits
Code:

import glob


def make_txt(split, num, txt_name):
    # write up to `num` "image_path label_path" pairs for the given split
    paths = glob.glob("cityscapes\\leftImg8bit\\" + split + "\\*\\*")

    txt = open(txt_name, "w")
    for i, path in enumerate(paths):
        if i == num:
            break
        label = path.replace("leftImg8bit", "gtFine").replace("gtFine.png", "gtFine_labelTrainIds.png")
        txt.write(path + " " + label + "\n")
    txt.close()


train_num = 400
test_num = 100
val_num = 100

make_txt("train", train_num, "train.txt")
# make_txt("test", test_num, "test.txt")
# make_txt("val", val_num, "val.txt")

Contents of the txt file: each line holds an image path and the corresponding label path, separated by a space, e.g. (illustrative):

cityscapes\leftImg8bit\train\aachen\aachen_000000_000019_leftImg8bit.png cityscapes\gtFine\train\aachen\aachen_000000_000019_gtFine_labelTrainIds.png

Training

Hyperparameters such as the number of epochs, training-set size, batch size, and learning rate can be tuned as needed
(this code requires a GPU; if you do not have one, adjust it accordingly)
PyTorch code:

from SegNet import *  # SegNet.py is assumed to also import torch, numpy as np, argparse and time, which this script uses


def train(SegNet):

    SegNet = SegNet.cuda()
    SegNet.load_weights(PRE_TRAINING)

    train_loader = Data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = torch.optim.SGD(SegNet.parameters(), lr=LR, momentum=MOMENTUM)

    # the class weights counter the pixel imbalance between classes (computed at the end of this post)
    loss_func = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(CATE_WEIGHT)).float()).cuda()

    SegNet.train()
    for epoch in range(EPOCH):
        for step, (b_x, b_y) in enumerate(train_loader):
            b_x = b_x.cuda()
            b_y = b_y.cuda()
            b_y = b_y.view(-1, 224, 224)  # -1 also handles a smaller final batch
            output = SegNet(b_x)
            loss = loss_func(output, b_y.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("Epoch:{0} || Step:{1} || Loss:{2}".format(epoch, step, format(loss.item(), ".4f")))

    torch.save(SegNet.state_dict(), WEIGHTS + "SegNet_weights" + str(time.time()) + ".pth")


parser = argparse.ArgumentParser()
parser.add_argument("--class_num", type=int, default=2, help="number of classes to train")
parser.add_argument("--epoch", type=int, default=4, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=0.01, help="learning rate")
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--category_weight", type=float, default=[0.7502381287857225, 1.4990483912788268], help="per-class weights for the loss function")
parser.add_argument("--train_txt", type=str, default="train.txt", help="txt file with training image and label paths")
parser.add_argument("--pre_training_weight", type=str, default="vgg16_bn-6c64b313.pth", help="path to the encoder's pretrained weights")
parser.add_argument("--weights", type=str, default="./weights/", help="where to save the trained weights")
opt = parser.parse_args()
print(opt)

CLASS_NUM = opt.class_num
EPOCH = opt.epoch
BATCH_SIZE = opt.batch_size
LR = opt.learning_rate
MOMENTUM = opt.momentum
CATE_WEIGHT = opt.category_weight
TXT_PATH = opt.train_txt
PRE_TRAINING = opt.pre_training_weight
WEIGHTS = opt.weights


train_data = MyDataset(txt_path=TXT_PATH)

SegNet = SegNet(3, CLASS_NUM)
train(SegNet)
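
For reference, a typical invocation (the script name train.py is an assumption; the post does not name it):

python train.py --epoch 4 --batch_size 2 --learning_rate 0.01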

Computing mIoU

For an introduction to mIoU, see this blog:
https://blog.csdn.net/lingzhou33/article/details/87901365
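
In short, the per-class IoU is |prediction ∩ target| / |prediction ∪ target|. A toy example with made-up 2x2 masks:

import numpy as np

pred   = np.array([[1, 1], [0, 1]])
target = np.array([[1, 0], [0, 1]])
inter = np.sum((pred == 1) & (target == 1))              # 2
union = np.sum(pred == 1) + np.sum(target == 1) - inter  # 3 + 2 - 2 = 3
print(inter / union)  # 0.666..., and mIoU is the mean of this over the classes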
PyTorch code:

from SegNet import *


parser = argparse.ArgumentParser()
parser.add_argument("--class_num", type=int, default=2, help="number of classes to predict")
parser.add_argument("--weights", type=str, default="weights/SegNet_weights1592624668.4279704.pth", help="path to the trained weights")
parser.add_argument("--val_paths", type=str, default="val.txt", help="txt file with validation image and label paths")
opt = parser.parse_args()
print(opt)

CLASS_NUM = opt.class_num
WEIGHTS = opt.weights
VAL_PATHS = opt.val_paths


SegNet = SegNet(3, CLASS_NUM)
SegNet.load_state_dict(torch.load(WEIGHTS))
SegNet.eval()

paths = open(VAL_PATHS, "r")
mIoU = []
for index, line in enumerate(paths):
    path = line.strip().split()

    image = cv.imread(path[0])
    image = cv.resize(image, (224, 224))
    image = image / 255.0  # normalize the input
    image = torch.Tensor(image)
    image = image.permute(2, 0, 1)  # HWC -> CHW, the layout the network expects
    image = torch.unsqueeze(image, dim=0)

    with torch.no_grad():  # no gradients needed for evaluation
        output = SegNet(image)
    output = torch.squeeze(output)
    output = output.argmax(dim=0)
    # nearest-neighbor when resizing a class map, so no new ids are invented
    predict = cv.resize(output.numpy().astype(np.uint8), (2048, 1024), interpolation=cv.INTER_NEAREST)

    label = cv.imread(path[1])
    target = label[:, :, 0]

    # hand-rolled per-class IoU (class 0, the background, is skipped)
    intersection = []
    union = []
    iou = 0
    for i in range(1, CLASS_NUM):
        intersection.append(np.sum(predict[target == i] == i))
        union.append(np.sum(predict == i) + np.sum(target == i) - intersection[i-1])
        iou += intersection[i-1] / union[i-1]

    # alternative with numpy logical ops (binary foreground-vs-background IoU)
    # intersection = np.logical_and(target, predict)
    # union = np.logical_or(target, predict)
    # iou = np.sum(intersection) / np.sum(union)

    mIoU.append(iou / (CLASS_NUM - 1))  # average over the CLASS_NUM - 1 evaluated classes
    print("miou_{0}:{1}".format(index, format(mIoU[index], ".4f")))

paths.close()

file = open("result.txt", "a")

print()
print("mIoU:{}".format(format(np.mean(mIoU), ".4f")))

file.write("Evaluated on: " + str(time.asctime(time.localtime(time.time()))) + "\n")
file.write("Weights used: " + WEIGHTS + "\n")
file.write("mIoU: " + str(format(np.mean(mIoU), ".4f")) + "\n")

file.close()

Testing images

PyTorch code:

from SegNet import *


def test(SegNet):

    SegNet.load_state_dict(torch.load(WEIGHTS))
    SegNet.eval()

    paths = os.listdir(SAMPLES)

    for path in paths:

        image_src = cv.imread(SAMPLES + path)
        image = cv.resize(image_src, (224, 224))

        image = image / 255.0
        image = torch.Tensor(image)
        image = image.permute(2, 0, 1)
        image = torch.unsqueeze(image, dim=0)

        with torch.no_grad():  # inference only, no gradients
            output = SegNet(image)
        output = torch.squeeze(output)
        output = output.argmax(dim=0)
        # resize the class map back to the Cityscapes resolution (2048x1024); nearest-neighbor keeps ids intact
        output_np = cv.resize(output.numpy().astype(np.uint8), (2048, 1024), interpolation=cv.INTER_NEAREST)

        image_seg = np.zeros((1024, 2048, 3))
        image_seg = np.uint8(image_seg)

        colors = COLORS  # note: OpenCV loads images as BGR, so the colors are applied in BGR order

        for c in range(CLASS_NUM):
            image_seg[:, :, 0] += np.uint8((output_np == c)) * np.uint8(colors[c][0])
            image_seg[:, :, 1] += np.uint8((output_np == c)) * np.uint8(colors[c][1])
            image_seg[:, :, 2] += np.uint8((output_np == c)) * np.uint8(colors[c][2])

        image_seg = Image.fromarray(np.uint8(image_seg))
        old_image = Image.fromarray(np.uint8(image_src))

        image = Image.blend(old_image, image_seg, 0.3)

        # restore the original pixels wherever the background / void class was predicted
        image_np = np.array(image)
        image_np[output_np == 0] = image_src[output_np == 0]
        image = Image.fromarray(image_np)
        image.save(OUTPUTS + path)

        print(path + " is done!")


parser = argparse.ArgumentParser()
parser.add_argument("--class_num", type=int, default=2, help="number of classes to predict")
parser.add_argument("--weights", type=str, default="weights/SegNet_weights1592624668.4279704.pth", help="path to the trained weights")
parser.add_argument("--colors", type=int, default=[[0, 0, 0], [0, 255, 0]], help="overlay color for each class")
parser.add_argument("--samples", type=str, default="samples/", help="folder with the images to test")
parser.add_argument("--outputs", type=str, default="outputs/", help="folder for the results")
opt = parser.parse_args()
print(opt)

CLASS_NUM = opt.class_num
WEIGHTS = opt.weights
COLORS = opt.colors
SAMPLES = opt.samples
OUTPUTS = opt.outputs


SegNet = SegNet(3, CLASS_NUM)
test(SegNet)

Test results

Two classes were tested: road and everything else

Computing the class-imbalance weights

Mind the number of classes
The computed weights are passed to the loss function (a short sketch of the formula follows below)
For details see this blog: https://blog.csdn.net/fanzonghao/article/details/85263553
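
The scheme is median frequency balancing, as used in the SegNet paper: weight_c = median_freq / freq_c. With two classes whose pixel frequencies sum to 1, the median frequency is exactly 0.5. A sketch with hypothetical frequencies, chosen to roughly reproduce the defaults in the training script:

freqs = [0.6665, 0.3335]            # hypothetical pixel shares of class 0 and class 1
weights = [0.5 / f for f in freqs]  # -> approximately [0.750, 1.499]
print(weights)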

Code:

import cv2 as cv
import numpy as np

paths = open("train.txt", "r")

CLASS_NUM = 2
SUM = [[] for i in range(CLASS_NUM)]
SUM_ = 0

for line in paths:
    path = line.strip().split()
    img = cv.imread(path[1], 0)
    img_np = np.array(img)
    for i in range(CLASS_NUM):
        SUM[i].append(np.sum(img_np == i))

paths.close()


for index, counts in enumerate(SUM):
    print("pixel count of class {}:".format(index), sum(counts))


for counts in SUM:
    SUM_ += sum(counts)

# with two classes whose pixel frequencies sum to 1, the median frequency is 1/CLASS_NUM = 0.5
median = 1 / CLASS_NUM

for index, counts in enumerate(SUM):
    print("weight_{}:".format(index), median / (sum(counts) / SUM_))
