论文地址:https://arxiv.org/pdf/1511.00561.pdf
官方代码:https://github.com/alexgkendall/caffe-segnet
这篇博客中的代码可在github上找到:https://github.com/chen-zhoujian/SegNet-pytorch
VGG-16的前13层
BN层的momentum=0.1(下面的代码直接使用全局变量bn_momentum,需要在构建网络之前先定义 bn_momentum = 0.1)
下采样时保留最大值的位置
pytorch代码:
class Encoder(nn.Module):
    """SegNet encoder: the 13 convolutional layers of VGG-16, each followed by BatchNorm and ReLU.

    The convolutions are grouped into five stages of 2, 2, 3, 3 and 3 layers. Every stage is
    followed by a 2x2, stride-2 max pool whose argmax indices are kept so the decoder can
    max-unpool back to the same positions.

    BatchNorm momentum is read from the module-level name ``bn_momentum``, which is not defined
    in this class and must exist before the encoder is constructed.
    """

    @staticmethod
    def _conv_stack(widths):
        """Return an ``nn.Sequential`` of Conv3x3 -> BatchNorm -> ReLU for each consecutive pair in ``widths``."""
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out, momentum=bn_momentum),
                nn.ReLU(),
            ])
        return nn.Sequential(*layers)

    def __init__(self, input_channels):
        super(Encoder, self).__init__()
        # Each stage is a flat Sequential (conv, bn, relu, conv, bn, relu, ...), so the state_dict
        # keys are enco<k>.<index>.<param>; the positional weight loading depends on this order.
        self.enco1 = self._conv_stack([input_channels, 64, 64])
        self.enco2 = self._conv_stack([64, 128, 128])
        self.enco3 = self._conv_stack([128, 256, 256, 256])
        self.enco4 = self._conv_stack([256, 512, 512, 512])
        self.enco5 = self._conv_stack([512, 512, 512, 512])

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(features, pool_indices)``: the output of the fifth pooling, and a list of
            the five max-pool index tensors ordered from the shallowest stage to the deepest.
        """
        pool_indices = []
        for stage in (self.enco1, self.enco2, self.enco3, self.enco4, self.enco5):
            x = stage(x)
            # Keep the argmax positions so the decoder can unpool into them.
            x, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
            pool_indices.append(indices)
        return x, pool_indices
解码器也是13层,对应编码器的13层
上采样时输入编码器下采样时保留的位置
论文中SegNet网络的最后加了一个softmax分类层;由于nn.CrossEntropyLoss()内部已经包含了log-softmax运算,所以这里不再单独添加该层,网络最后一层卷积直接输出每个像素各类别的得分
(详情参考:https://blog.csdn.net/zziahgf/article/details/80196376)
pytorch代码:
class SegNet(nn.Module):
    """SegNet: a VGG-16 encoder plus a mirrored 13-convolution decoder.

    Decoder stages contain 3, 3, 3, 2 and 2 convolutions. Before each stage the feature map is
    upsampled with ``F.max_unpool2d`` using the pooling indices recorded by the matching encoder
    stage (deepest first). The final 3x3 convolution emits ``output_channels`` raw per-pixel
    scores with no BatchNorm, ReLU or softmax.

    BatchNorm momentum is read from the module-level name ``bn_momentum``, which must be defined
    elsewhere before construction.
    """

    @staticmethod
    def _conv_stack(widths):
        """Return a list of Conv3x3 -> BatchNorm -> ReLU layers for each consecutive pair in ``widths``."""
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out, momentum=bn_momentum),
                nn.ReLU(),
            ])
        return layers

    def __init__(self, input_channels, output_channels):
        super(SegNet, self).__init__()
        # Snapshot taken before any submodule is registered, so it starts out empty;
        # load_weights() writes encoder tensors into it.
        self.weights_new = self.state_dict()
        self.encoder = Encoder(input_channels)
        self.deco1 = nn.Sequential(*self._conv_stack([512, 512, 512, 512]))
        self.deco2 = nn.Sequential(*self._conv_stack([512, 512, 512, 256]))
        self.deco3 = nn.Sequential(*self._conv_stack([256, 256, 256, 128]))
        self.deco4 = nn.Sequential(*self._conv_stack([128, 128, 64]))
        # The last stage ends in a plain convolution producing the class scores.
        self.deco5 = nn.Sequential(
            *self._conv_stack([64, 64]),
            nn.Conv2d(64, output_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Return per-pixel class scores with ``output_channels`` channels."""
        x, pool_indices = self.encoder(x)
        decoder = (self.deco1, self.deco2, self.deco3, self.deco4, self.deco5)
        # The deepest pooling pairs with the first decoder stage, so walk the indices backwards.
        for stage, indices in zip(decoder, reversed(pool_indices)):
            x = F.max_unpool2d(x, indices, kernel_size=2, stride=2)
            x = stage(x)
        return x
编码器使用VGG-16的权重,解码器不用
因为编码器中不包括VGG-16后面的3个全连接层,所以需要先丢弃预训练权重中这3层(classifier)的参数,再把剩下的卷积层和BN层参数按顺序载入编码器
附预训练权重下载地址:https://download.pytorch.org/models/vgg16_bn-6c64b313.pth
(这个load_weights函数是SegNet类的成员函数)
pytorch代码:
def load_weights(self, weights_path):
    """Initialise ``self.encoder`` from an ImageNet-pretrained VGG-16-BN checkpoint.

    The checkpoint's non-``classifier.*`` tensors (VGG's convolution/BatchNorm weights and
    running statistics) are matched to the encoder's parameters and buffers by position. The
    ``classifier.*`` tensors (VGG's three fully connected layers) have no counterpart and are
    dropped. ``num_batches_tracked`` counters are never copied; the encoder keeps its own.

    Args:
        weights_path: path to the checkpoint, e.g. ``vgg16_bn-6c64b313.pth``.

    Raises:
        RuntimeError: if the number or the shapes of the pretrained tensors do not match the
            encoder (the same exception type ``load_state_dict`` raises on a mismatch).
    """
    # Load onto the CPU; load_state_dict copies each tensor onto the encoder's device.
    weights = torch.load(weights_path, map_location="cpu")
    pretrained = [
        value
        for key, value in weights.items()
        if not key.startswith("classifier.") and not key.endswith("num_batches_tracked")
    ]
    # Start from the encoder's full state so buffers we do not overwrite
    # (num_batches_tracked) are present and strict loading succeeds.
    encoder_state = self.encoder.state_dict()
    names = [key for key in encoder_state if not key.endswith("num_batches_tracked")]
    if len(names) != len(pretrained):
        # zip() would silently truncate and misalign the remaining layers.
        raise RuntimeError(
            "expected {} pretrained encoder tensors, found {} in {}".format(
                len(names), len(pretrained), weights_path
            )
        )
    for name, value in zip(names, pretrained):
        if encoder_state[name].shape != value.shape:
            raise RuntimeError(
                "shape mismatch for {}: encoder {} vs pretrained {}".format(
                    name, tuple(encoder_state[name].shape), tuple(value.shape)
                )
            )
        encoder_state[name] = value
    self.weights_new = encoder_state
    self.encoder.load_state_dict(encoder_state)
以上是SegNet网络的搭建和编码器预训练权重的加载
打开写有训练图片和标签路径的txt文件
(txt文件的生成方法见后文)
对训练图片和标签进行预处理
参考博客:https://blog.csdn.net/sinat_42239797/article/details/90641659
pytorch代码:
class MyDataset(Data.Dataset):
    """Dataset of (image, label) pairs listed in a text file.

    Each non-blank line of ``txt_path`` holds an image path and a label path separated by
    whitespace. ``__getitem__`` returns:
      * image: float tensor of shape (3, 224, 224) scaled to [0, 1], in OpenCV's BGR channel order;
      * label: float tensor of shape (224, 224) holding the grayscale label values (class ids).
    """

    def __init__(self, txt_path):
        super(MyDataset, self).__init__()
        image_label = []
        # `with` closes the list file; the original handle was never closed.
        with open(txt_path, "r") as paths:
            for line_no, line in enumerate(paths, start=1):
                fields = line.split()
                if not fields:
                    continue  # blank line (e.g. trailing newline)
                if len(fields) < 2:
                    raise ValueError(
                        "{}:{}: expected '<image> <label>', got {!r}".format(txt_path, line_no, line)
                    )
                image_label.append((fields[0], fields[1]))
        self.image_label = image_label

    def __getitem__(self, item):
        image_path, label_path = self.image_label[item]
        image = cv.imread(image_path)
        if image is None:
            raise FileNotFoundError("cannot read image: " + image_path)
        image = cv.resize(image, (224, 224))
        image = image / 255.0  # normalise input to [0, 1]
        image = torch.Tensor(image)
        # OpenCV arrays are (height, width, channel); the network expects (channel, height, width).
        image = image.permute(2, 0, 1)
        label = cv.imread(label_path, 0)
        if label is None:
            raise FileNotFoundError("cannot read label: " + label_path)
        # Nearest-neighbour keeps every pixel a valid class id; bilinear would blend ids at edges.
        label = cv.resize(label, (224, 224), interpolation=cv.INTER_NEAREST)
        label = torch.Tensor(label)
        return image, label

    def __len__(self):
        return len(self.image_label)
选用cityscapes数据集,数据集的下载和介绍参考这篇博客:https://blog.csdn.net/avideointerfaces/article/details/104139298
数据集的预处理脚本:https://github.com/mcordts/cityscapesScripts
关于cityscapes数据集标签的处理可参考我写的这篇博客:https://blog.csdn.net/chenzhoujian_/article/details/106874950
将训练的图片和标签写入txt文件
注意下数据集文件的路径和训练集、验证集、测试集的数量
代码如下:
import glob
import os


def _make_txt(split, num, root, txt_path):
    """Write up to ``num`` "<image path> <label path>" lines for one Cityscapes split.

    Images are found at ``<root>/leftImg8bit/<split>/<city>/*``. Each label path is derived by
    replacing ``leftImg8bit`` with ``gtFine`` and pointing at the ``*_gtFine_labelTrainIds.png`` file.

    Args:
        split: sub-directory name, e.g. "train", "val" or "test".
        num: maximum number of pairs to write (``None`` writes all, 0 writes none).
        root: dataset root directory.
        txt_path: output text file; it is overwritten.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if num is not None and num < 0:
        raise ValueError("num must be non-negative, got {}".format(num))
    # os.path.join keeps the pattern valid on every OS; sorting makes the chosen subset
    # reproducible because glob's order is arbitrary.
    pattern = os.path.join(root, "leftImg8bit", split, "*", "*")
    image_paths = sorted(glob.glob(pattern))[:num]
    with open(txt_path, "w") as txt:
        for path in image_paths:
            label = path.replace("leftImg8bit", "gtFine").replace("gtFine.png", "gtFine_labelTrainIds.png")
            txt.write(path + " " + label + "\n")


def make_train_txt(num, root="cityscapes", txt_path="train.txt"):
    """Write up to ``num`` training image/label pairs to ``txt_path``."""
    _make_txt("train", num, root, txt_path)


def make_test_txt(num, root="cityscapes", txt_path="test.txt"):
    """Write up to ``num`` test image/label pairs to ``txt_path``."""
    _make_txt("test", num, root, txt_path)


def make_val_txt(num, root="cityscapes", txt_path="val.txt"):
    """Write up to ``num`` validation image/label pairs to ``txt_path``."""
    _make_txt("val", num, root, txt_path)


train_num = 400
test_num = 100
val_num = 100

if __name__ == "__main__":
    # Switch these flags to choose which lists to (re)generate.
    if True:
        make_train_txt(train_num)
    if False:
        make_test_txt(test_num)
    if False:
        make_val_txt(val_num)
迭代次数、训练集数量、批训练大小、学习率等超参数可自行调节
(这个代码是需要调用GPU的,没有GPU的需要自行修改下)
pytorch代码:
from SegNet import *
def train(SegNet):
    """Train ``SegNet`` on the global ``train_data`` and save the final weights.

    Reads these module-level settings: PRE_TRAINING, BATCH_SIZE, LR, MOMENTUM, CATE_WEIGHT,
    EPOCH and WEIGHTS. The model is moved to the GPU, so a CUDA device is required. The
    state_dict is saved as ``<WEIGHTS>/SegNet_weights<timestamp>.pth``; the directory is
    created if missing.
    """
    import os  # local: the original module-level imports are not visible here

    SegNet = SegNet.cuda()
    # Encoder starts from pretrained VGG-16-BN; load_state_dict copies the weights onto the GPU.
    SegNet.load_weights(PRE_TRAINING)
    train_loader = Data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.SGD(SegNet.parameters(), lr=LR, momentum=MOMENTUM)
    class_weight = torch.from_numpy(np.array(CATE_WEIGHT)).float()
    loss_func = nn.CrossEntropyLoss(weight=class_weight).cuda()
    SegNet.train()
    for epoch in range(EPOCH):
        for step, (b_x, b_y) in enumerate(train_loader):
            b_x = b_x.cuda()
            # Labels already arrive batched as (N, H, W). Reshaping with BATCH_SIZE would crash
            # on a smaller final batch. CrossEntropyLoss needs int64 class ids.
            b_y = b_y.cuda().long()
            output = SegNet(b_x)
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("Epoch:{0} || Step:{1} || Loss:{2}".format(epoch, step, format(loss.item(), ".4f")))
    # torch.save does not create missing directories.
    os.makedirs(WEIGHTS or ".", exist_ok=True)
    torch.save(SegNet.state_dict(), os.path.join(WEIGHTS, "SegNet_weights" + str(time.time()) + ".pth"))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--class_num", type=int, default=2, help="训练的类别的种类")
    parser.add_argument("--epoch", type=int, default=4, help="训练迭代次数")
    parser.add_argument("--batch_size", type=int, default=2, help="批训练大小")
    parser.add_argument("--learning_rate", type=float, default=0.01, help="学习率大小")
    parser.add_argument("--momentum", type=float, default=0.9)
    # nargs="+" so one weight per class can be given on the command line
    # (e.g. --category_weight 0.75 1.5); without it the option accepted only a single float.
    parser.add_argument("--category_weight", type=float, nargs="+", default=[0.7502381287857225, 1.4990483912788268], help="损失函数中类别的权重")
    parser.add_argument("--train_txt", type=str, default="train.txt", help="训练的图片和标签的路径")
    parser.add_argument("--pre_training_weight", type=str, default="vgg16_bn-6c64b313.pth", help="编码器预训练权重路径")
    parser.add_argument("--weights", type=str, default="./weights/", help="训练好的权重保存路径")
    opt = parser.parse_args()
    print(opt)
    if len(opt.category_weight) != opt.class_num:
        parser.error("--category_weight needs exactly --class_num values")
    CLASS_NUM = opt.class_num
    EPOCH = opt.epoch
    BATCH_SIZE = opt.batch_size
    LR = opt.learning_rate
    MOMENTUM = opt.momentum
    CATE_WEIGHT = opt.category_weight
    TXT_PATH = opt.train_txt
    PRE_TRAINING = opt.pre_training_weight
    WEIGHTS = opt.weights
    train_data = MyDataset(txt_path=TXT_PATH)
    # Use a separate name so the SegNet class is not shadowed by the instance.
    model = SegNet(3, CLASS_NUM)
    train(model)
mIoU的介绍可参考这篇博客:
https://blog.csdn.net/lingzhou33/article/details/87901365
pytorch代码:
from SegNet import *
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--class_num", type=int, default=2, help="预测的类别的种类")
    parser.add_argument("--weights", type=str, default="weights/SegNet_weights1592624668.4279704.pth", help="训练好的权重路径")
    parser.add_argument("--val_paths", type=str, default="val.txt", help="验证集的图片和标签的路径")
    opt = parser.parse_args()
    print(opt)
    CLASS_NUM = opt.class_num
    WEIGHTS = opt.weights
    VAL_PATHS = opt.val_paths
    model = SegNet(3, CLASS_NUM)
    # Checkpoints written by the GPU training script hold CUDA tensors; evaluation runs on the CPU.
    model.load_state_dict(torch.load(WEIGHTS, map_location="cpu"))
    model.eval()
    mIoU = []
    with open(VAL_PATHS, "r") as paths:
        for index, line in enumerate(paths):
            path = line.split()
            if not path:
                continue  # blank line
            image = cv.imread(path[0])
            label = cv.imread(path[1])
            if image is None or label is None:
                raise FileNotFoundError("cannot read " + line.strip())
            image = cv.resize(image, (224, 224)) / 255.0  # normalise input to [0, 1]
            # (height, width, channel) -> (1, channel, height, width) for the network.
            image = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)
            with torch.no_grad():  # inference only: don't build an autograd graph
                output = model(image)
            output = output[0].argmax(dim=0).numpy().astype(np.uint8)
            target = label[:, :, 0]
            # Resize to the label's own resolution; nearest-neighbour keeps valid class ids.
            predict = cv.resize(output, (target.shape[1], target.shape[0]), interpolation=cv.INTER_NEAREST)
            # IoU of each foreground class 1..CLASS_NUM-1. Class 0 (background / "other") is not
            # scored, so the mean divides by the number of scored classes, not by CLASS_NUM.
            ious = []
            for i in range(1, CLASS_NUM):
                pred_i = predict == i
                target_i = target == i
                union = np.count_nonzero(pred_i | target_i)
                if union == 0:
                    continue  # class absent from both prediction and label: IoU undefined
                ious.append(np.count_nonzero(pred_i & target_i) / union)
            if not ious:
                continue  # nothing to score in this image
            iou = float(np.mean(ious))
            mIoU.append(iou)
            print("miou_{0}:{1}".format(index, format(iou, ".4f")))
    mean_iou = float(np.mean(mIoU)) if mIoU else float("nan")
    print("\n")
    print("mIoU:{}".format(format(mean_iou, ".4f")))
    with open("result.txt", "a") as file:
        file.write("评价日期:" + str(time.asctime(time.localtime(time.time()))) + "\n")
        file.write("使用的权重:" + WEIGHTS + "\n")
        file.write("mIoU: " + format(mean_iou, ".4f") + "\n")
pytorch代码:
from SegNet import *
def test(SegNet):
    """Segment every image in ``SAMPLES`` and save a colour overlay of the prediction to ``OUTPUTS``.

    Reads the module-level settings CLASS_NUM, WEIGHTS, COLORS, SAMPLES and OUTPUTS. Each
    pixel of class ``c`` is painted ``COLORS[c]`` (an RGB triple) and blended at 30% over the
    photo. Pixels predicted as class 0 keep the untouched photo. Runs on the CPU. Files that
    OpenCV cannot read are skipped.
    """
    # Checkpoints from GPU training hold CUDA tensors; map them onto the CPU model.
    SegNet.load_state_dict(torch.load(WEIGHTS, map_location="cpu"))
    SegNet.eval()
    os.makedirs(OUTPUTS, exist_ok=True)
    for path in os.listdir(SAMPLES):
        image_src = cv.imread(os.path.join(SAMPLES, path))
        if image_src is None:
            print(path + " is skipped: not a readable image")
            continue
        height, width = image_src.shape[:2]
        # Same preprocessing as training: OpenCV BGR, 224x224, scaled to [0, 1], (1, C, H, W).
        image = cv.resize(image_src, (224, 224)) / 255.0
        image = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)
        with torch.no_grad():  # inference only
            output = SegNet(image)
        classes = output[0].argmax(dim=0).numpy().astype(np.uint8)
        # Back to the source resolution (not a fixed 2048x1024, so Image.blend sizes always match);
        # nearest-neighbour keeps valid class ids.
        output_np = cv.resize(classes, (width, height), interpolation=cv.INTER_NEAREST)
        image_seg = np.zeros((height, width, 3), dtype=np.uint8)
        for c in range(CLASS_NUM):
            image_seg[output_np == c] = COLORS[c]
        # PIL treats arrays as RGB but OpenCV loads BGR; convert so the saved photo keeps its colours.
        image_rgb = cv.cvtColor(image_src, cv.COLOR_BGR2RGB)
        blended = Image.blend(Image.fromarray(image_rgb), Image.fromarray(image_seg), 0.3)
        image_np = np.array(blended)
        # Leave background / void pixels (class 0) as the original photo.
        image_np[output_np == 0] = image_rgb[output_np == 0]
        Image.fromarray(image_np).save(os.path.join(OUTPUTS, path))
        print(path + " is done!")
if __name__ == "__main__":
    import json  # --colors is passed as a JSON list, e.g. --colors "[[0,0,0],[0,255,0]]"

    parser = argparse.ArgumentParser()
    parser.add_argument("--class_num", type=int, default=2, help="预测的类别的种类")
    parser.add_argument("--weights", type=str, default="weights/SegNet_weights1592624668.4279704.pth", help="训练好的权重路径")
    # type=int could never produce the nested list this option needs; parse JSON instead.
    parser.add_argument("--colors", type=json.loads, default=[[0, 0, 0], [0, 255, 0]], help="类别覆盖的颜色")
    parser.add_argument("--samples", type=str, default="samples//", help="用于测试的图片文件夹的路径")
    parser.add_argument("--outputs", type=str, default="outputs//", help="保存结果的文件夹的路径")
    opt = parser.parse_args()
    print(opt)
    if len(opt.colors) < opt.class_num:
        parser.error("--colors must provide one [R, G, B] triple per class")
    CLASS_NUM = opt.class_num
    WEIGHTS = opt.weights
    COLORS = opt.colors
    SAMPLES = opt.samples
    OUTPUTS = opt.outputs
    # Use a separate name so the SegNet class is not shadowed by the instance.
    model = SegNet(3, CLASS_NUM)
    test(model)
测试两类,road和其它
注意下类别的数量
计算出的权重用于损失函数
详情参考这篇博客:https://blog.csdn.net/fanzonghao/article/details/85263553
代码:
import cv2 as cv
import numpy as np
TXT_PATH = "train.txt"
CLASS_NUM = 2


def count_class_pixels(label, class_num):
    """Count the pixels of each class id 0..class_num-1 in a label map.

    Pixels holding any other value (e.g. an ignore id such as 255) are not counted.

    Returns:
        An int64 numpy array of length ``class_num``.
    """
    label = np.asarray(label)
    return np.array([np.count_nonzero(label == c) for c in range(class_num)], dtype=np.int64)


def median_frequency_weights(counts):
    """Return median-frequency-balancing loss weights ``median(freq) / freq[c]``.

    ``freq[c]`` is class c's share of all counted pixels. For two classes the median is exactly
    1/2. A class with zero pixels gets an infinite weight, and numpy emits a RuntimeWarning.
    """
    counts = np.asarray(counts, dtype=np.float64)
    freqs = counts / counts.sum()
    return np.median(freqs) / freqs


if __name__ == "__main__":
    totals = np.zeros(CLASS_NUM, dtype=np.int64)
    with open(TXT_PATH, "r") as paths:
        for line in paths:
            path = line.split()
            if not path:
                continue  # blank line
            img = cv.imread(path[1], 0)  # label map read as a single channel of class ids
            if img is None:
                raise FileNotFoundError("cannot read label: " + path[1])
            totals += count_class_pixels(img, CLASS_NUM)
    for index, count in enumerate(totals):
        print("类别{}的数量:".format(index), count)
    for index, weight in enumerate(median_frequency_weights(totals)):
        print("weight_{}:".format(index), weight)