The dataset for this project comes mainly from Baidu's road/lane segmentation challenge and can be downloaded from the official site. The full code is given below.
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBNReLU(nn.Sequential):
def __init__(self, inplanes, planes, kernel_size, padding, dilation):
super(ConvBNReLU, self).__init__(
nn.Conv2d(inplanes, planes, kernel_size, stride=1,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(planes),
nn.ReLU(inplace=True),
)
class ASPP(nn.Module):
"""
Atrous Spatial Pyramid Pooling.
Ref:
Rethinking Atrous Convolution for Semantic Image Segmentation
"""
def __init__(self, inplanes=2048, planes=256, stride=16):
super(ASPP, self).__init__()
if stride == 8:
dilation = [12, 24, 36]
elif stride == 16:
dilation = [6, 12, 18]
else:
raise NotImplementedError
self.block1 = ConvBNReLU(inplanes, planes, 1, 0, 1) # inchannel,outchannel,kernel,padding,dilation
self.block2 = ConvBNReLU(inplanes, planes, 3, dilation[0], dilation[0])
self.block3 = ConvBNReLU(inplanes, planes, 3, dilation[1], dilation[1])
self.block4 = ConvBNReLU(inplanes, planes, 3, dilation[2], dilation[2])
self.block5 = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),  # image-pooling branch; the original paper pools down to 1x1
ConvBNReLU(inplanes, planes, 1, 0, 1),
)
self.conv = ConvBNReLU(planes * 5, planes, 1, 0, 1)
self.dropout = nn.Dropout(0.5)
self._init_weight()
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
h1 = self.block1(x)
h2 = self.block2(x)
h3 = self.block3(x)
h4 = self.block4(x)
h5 = self.block5(x)
        h5 = F.interpolate(h5, size=x.size()[2:], mode='bilinear', align_corners=True)  # upsampling like this can stand in for CRF post-processing
x = torch.cat((h1, h2, h3, h4, h5), dim=1)
x = self.conv(x)
        x = self.dropout(x)  # BN and dropout are not used back-to-back here, so this is fine
return x
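A minimal shape check for the ASPP module (the sizes here are illustrative, matching what a stride-16 backbone would produce for a 512x512 input):

aspp = ASPP(inplanes=2048, planes=256, stride=16)
feat = torch.randn(2, 2048, 32, 32)
print(aspp(feat).shape)  # torch.Size([2, 256, 32, 32]) -- spatial size is preserved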
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""
Ref:
Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
"""
def __init__(self, planes=128, num_classes=3):
super(Decoder, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(planes, planes, 1, bias=False),
nn.BatchNorm2d(planes),
nn.ReLU(inplace=True),
)
self.conv2 = nn.Sequential(
nn.Conv2d(planes + 256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),  # 1x1 kernel
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x1, x2, output_size):
"""
:param x1:
:param x2: low level feature
:return:
"""
out1 = self.conv1(x2)
out0 = F.interpolate(x1, size=x2.size()[2:], mode='bilinear', align_corners=True)
out = torch.cat((out0, out1), dim=1)
out = self.conv2(out)
out = F.interpolate(out, size=output_size, mode='bilinear', align_corners=True)
return out
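A quick sanity check wiring an ASPP-sized tensor into the decoder (shapes are illustrative; x2 stands in for the stride-4 low-level feature):

decoder = Decoder(planes=128, num_classes=8)
x1 = torch.randn(2, 256, 32, 32)    # plays the role of the ASPP output
x2 = torch.randn(2, 128, 128, 128)  # plays the role of the low-level feature
print(decoder(x1, x2, (512, 512)).shape)  # torch.Size([2, 8, 512, 512])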
import torch
import math
import torch.nn as nn
class SeparableConv2d(nn.Module):
"""
Depth Separable Convolution.
"""
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
super(SeparableConv2d, self).__init__()
padding = (kernel_size - 1) * dilation // 2
self.depth_wise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding,
dilation, groups=inplanes, bias=bias)
# self.bn = nn.BatchNorm2d(inplanes)
# inchannel outchannel kernel stride padding dilation groups bias
self.point_wise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias)
def forward(self, x):
x = self.depth_wise(x)
# x = self.bn(x)
x = self.point_wise(x)
return x
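To see what the depthwise/pointwise factorization buys, compare parameter counts against a plain 3x3 convolution at the same width (256 -> 256 channels, chosen only for illustration):

count = lambda m: sum(p.numel() for p in m.parameters())
dense = nn.Conv2d(256, 256, 3, padding=1, bias=False)
sep = SeparableConv2d(256, 256, 3)
print(count(dense), count(sep))  # 589824 vs 67840, roughly 8.7x fewer parameters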
class BasicConv2d(nn.Module):
def __init__(self, inplanes, planes, stride=1, dilation=1):
super(BasicConv2d, self).__init__()
self.features = nn.Sequential(
SeparableConv2d(inplanes, planes, 3, stride=stride, dilation=dilation),
nn.ReLU(inplace=True),
SeparableConv2d(planes, planes, 3, stride=1, dilation=dilation),
nn.ReLU(inplace=True),
SeparableConv2d(planes, planes, 3, stride=1, dilation=dilation)
)
self.downsample = None
if inplanes != planes or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False),
nn.BatchNorm2d(planes),
)
def forward(self, x):
identity = x
x = self.features(x)
if self.downsample is not None:
identity = self.downsample(identity)
x = x + identity
return x
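A shape check for the residual block; when the stride or channel width changes, the 1x1 downsample branch keeps the skip connection compatible:

block = BasicConv2d(64, 128, stride=2)
x = torch.randn(1, 64, 64, 64)
print(block(x).shape)  # torch.Size([1, 128, 32, 32])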
class AlignedXception(nn.Module):
"""
Ref:
Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
"""
def __init__(self, stride=16):
"""
        :param stride: output stride, i.e. the overall image down-sampling factor.
                       The default is 16 (DeepLab v3+); 8 (DeepLab v3) and 32 are also supported.
"""
super(AlignedXception, self).__init__()
if stride == 8:
self.stride = [1, 1]
self.dilation = [4, 4]
elif stride == 16:
self.stride = [2, 1]
self.dilation = [2, 2]
elif stride == 32:
self.stride = [2, 2]
self.dilation = [1, 1]
else:
raise NotImplementedError
# Entry flow
self.stem = nn.Sequential(
nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.stage1 = nn.Sequential(
BasicConv2d(64, 128, 2),
nn.ReLU(inplace=True),
)
self.stage2 = BasicConv2d(128, 256, 2)
self.stage3 = BasicConv2d(256, 728, self.stride[0])
# Middle flow
layers = []
for _ in range(16):
layers.append(BasicConv2d(728, 728, stride=1, dilation=self.dilation[0]))
self.stage4 = nn.Sequential(*layers)
# Exit flow
self.stage5 = nn.Sequential(
BasicConv2d(728, 1024, stride=self.stride[1], dilation=self.dilation[1]),
nn.ReLU(inplace=True),
SeparableConv2d(1024, 1536, dilation=self.dilation[1]),
nn.BatchNorm2d(1536),
nn.ReLU(inplace=True),
SeparableConv2d(1536, 1536, dilation=self.dilation[1]),
nn.BatchNorm2d(1536),
nn.ReLU(inplace=True),
SeparableConv2d(1536, 2048, dilation=self.dilation[1]),
nn.BatchNorm2d(2048),
nn.ReLU(inplace=True),
)
self._init_weight()
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
"""
:param x:
:return:
result: Output two feature map to skip connect.
"""
x = self.stem(x)
x = self.stage1(x)
low_level_features = x
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
return x, low_level_features
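With the stride-16 configuration, a 512x512 input (size chosen for illustration) yields the 2048-channel feature for the ASPP head at 1/16 resolution, plus the 128-channel stride-4 feature for the decoder:

backbone = AlignedXception(stride=16)
high, low = backbone(torch.randn(1, 3, 512, 512))
print(high.shape, low.shape)  # torch.Size([1, 2048, 32, 32]) torch.Size([1, 128, 128, 128])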
import argparse


def str2bool(v):
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so booleans are parsed explicitly here
    return str(v).lower() in ('true', '1', 'yes')


def get_parser():
    parser = argparse.ArgumentParser(description="segmentation test")
    parser.add_argument('--project_name', type=str, default="image segmentation")
    parser.add_argument('--use_cuda', type=str2bool, default=True)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--data_base', type=str, default='E:/Datasets2/lane_seg')
    parser.add_argument('--resume', type=str2bool, default=False)
    parser.add_argument('--pretrained_model', type=str, default='./weights/')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--milestones', type=int, nargs='+', default=[50, 80])  # fixed: type=list cannot parse CLI input
    parser.add_argument('--epoches', type=int, default=200)
    parser.add_argument('--save_path', type=str, default="./weights/")
    args = parser.parse_args()
    return args
import os

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from utils import *  # the custom transforms (FixedResize, CutOut, Translate, ...) shown later in this post
class BaiDuLaneDataset(Dataset):
labels = {
'void': {
'id': 0, 'trainId': 0, 'category': 'void', 'catId': 0, 'ignoreInEval': False,
'color': [0, 0, 0]},
's_w_d': {
'id': 200, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': False,
'color': [70, 130, 180]},
's_y_d': {
'id': 204, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': False,
'color': [220, 20, 60]},
'ds_w_dn': {
'id': 213, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': True,
'color': [128, 20, 128]},
'ds_y_dn': {
'id': 209, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': False,
'color': [255, 0, 0]},
'sb_y_do': {
'id': 206, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': True,
'color': [0, 0, 60]},
'sb_w_do': {
'id': 207, 'trainId': 1, 'category': 'dividing', 'catId': 1, 'ignoreInEval': True,
'color': [0, 60, 100]},
'b_w_g': {
'id': 201, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': False,
'color': [0, 0, 142]},
'b_y_g': {
'id': 203, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': False,
'color': [119, 11, 32]},
'db_w_g': {
'id': 211, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': True,
'color': [244, 35, 232]},
'db_y_g': {
'id': 208, 'trainId': 2, 'category': 'guiding', 'catId': 2, 'ignoreInEval': True,
'color': [0, 0, 160]},
'db_w_s': {
'id': 216, 'trainId': 3, 'category': 'stopping', 'catId': 3, 'ignoreInEval': True,
'color': [153, 153, 153]},
's_w_s': {
'id': 217, 'trainId': 3, 'category': 'stopping', 'catId': 3, 'ignoreInEval': False,
'color': [220, 220, 0]},
'ds_w_s': {
'id': 215, 'trainId': 3, 'category': 'stopping', 'catId': 3, 'ignoreInEval': True,
'color': [250, 170, 30]},
's_w_c': {
'id': 218, 'trainId': 4, 'category': 'chevron', 'catId': 4, 'ignoreInEval': True,
'color': [102, 102, 156]},
's_y_c': {
'id': 219, 'trainId': 4, 'category': 'chevron', 'catId': 4, 'ignoreInEval': True,
'color': [128, 0, 0]},
's_w_p': {
'id': 210, 'trainId': 5, 'category': 'parking', 'catId': 5, 'ignoreInEval': False,
'color': [128, 64, 128]},
's_n_p': {
'id': 232, 'trainId': 5, 'category': 'parking', 'catId': 5, 'ignoreInEval': True,
'color': [238, 232, 170]},
'c_wy_z': {
'id': 214, 'trainId': 6, 'category': 'zebra', 'catId': 6, 'ignoreInEval': False,
'color': [190, 153, 153]},
'a_w_u': {
'id': 202, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
'color': [0, 0, 230]},
'a_w_t': {
'id': 220, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
'color': [128, 128, 0]},
'a_w_tl': {
'id': 221, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
'color': [128, 78, 160]},
'a_w_tr': {
'id': 222, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
'color': [150, 100, 100]},
'a_w_tlr': {
'id': 231, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
'color': [255, 165, 0]},
'a_w_l': {
'id': 224, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
'color': [180, 165, 180]},
'a_w_r': {
'id': 225, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
'color': [107, 142, 35]},
'a_w_lr': {
'id': 226, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': False,
'color': [201, 255, 229]},
'a_n_lu': {
'id': 230, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
'color': [0, 191, 255]},
'a_w_tu': {
'id': 228, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
'color': [51, 255, 51]},
'a_w_m': {
'id': 229, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
'color': [250, 128, 224]},
'a_y_t': {
'id': 233, 'trainId': 7, 'category': 'thru/turn', 'catId': 7, 'ignoreInEval': True,
'color': [127, 255, 0]},
'b_n_sr': {
'id': 205, 'trainId': 8, 'category': 'reduction', 'catId': 8, 'ignoreInEval': False,
'color': [255, 128, 0]},
'd_wy_za': {
'id': 212, 'trainId': 8, 'category': 'attention', 'catId': 8, 'ignoreInEval': True,
'color': [0, 255, 255]},
'r_wp_np': {
'id': 227, 'trainId': 8, 'category': 'no parking', 'catId': 8, 'ignoreInEval': False,
'color': [178, 132, 190]},
'vom_wy_n': {
'id': 223, 'trainId': 8, 'category': 'others', 'catId': 8, 'ignoreInEval': True,
'color': [128, 128, 64]},
'cm_n_n': {
'id': 250, 'trainId': 8, 'category': 'others', 'catId': 8, 'ignoreInEval': False,
'color': [102, 0, 204]},
'noise': {
'id': 249, 'trainId': 0, 'category': 'ignored', 'catId': 0, 'ignoreInEval': True,
'color': [0, 153, 153]},
'ignored': {
'id': 255, 'trainId': 0, 'category': 'ignored', 'catId': 0, 'ignoreInEval': True,
'color': [255, 255, 255]},
}
@staticmethod
def get_file_list(file_path, ext):
file_list = []
if ext == '':
dirs = ['ColorImage_road02/ColorImage', 'ColorImage_road03/ColorImage', 'ColorImage_road04/ColorImage']
elif ext == 'Label':
dirs = ['Gray_Label/Label_road02', 'Gray_Label/Label_road03', 'Gray_Label/Label_road04']
else:
raise NotImplementedError
for d in dirs:
f_path = os.path.join(file_path, d, ext)
dir_path = os.listdir(f_path)
dir_path = sorted(dir_path)
for dir in dir_path:
dir = os.path.join(d, ext, dir)
camera_file = os.listdir(file_path + '/' + dir)
camera_file = sorted(camera_file)
for file in camera_file:
path = os.path.join(dir, file)
for x in sorted(os.listdir(file_path + '/' + path)):
file_list.append(path + '/' + x)
return file_list
def __init__(self, root_file, phase='train', output_size=(846, 255), num_classes=8, adjust_factor=(0.3, 2.),
radius=(0., 1.)):
super(BaiDuLaneDataset, self).__init__()
assert phase in ['train', 'val', 'test']
self.root_file = root_file
img_ext = ''
label_ext = 'Label'
self.img_list = self.get_file_list(self.root_file, img_ext)
self.label_list = self.get_file_list(self.root_file, label_ext)
self.output_size = output_size
self.factor = adjust_factor
self.radius = radius
self.transform = self.preprocess(phase)
self.num_classes = num_classes
self.phase = phase
num_data = len(self.img_list)
assert num_data == len(self.label_list)
np.random.seed(2020)
data_list = np.random.permutation(num_data)
self.img_list = np.array(self.img_list)[data_list].tolist()
self.label_list = np.array(self.label_list)[data_list].tolist()
if phase == 'train':
self.img_list = self.img_list[0:int(0.7 * num_data)]
self.label_list = self.label_list[0:int(0.7 * num_data)]
elif phase == 'val':
self.img_list = self.img_list[int(0.7 * num_data):int(0.9 * num_data)]
self.label_list = self.label_list[int(0.7 * num_data):int(0.9 * num_data)]
elif phase == 'test':
self.img_list = self.img_list[int(0.9 * num_data):]
self.label_list = self.label_list[int(0.9 * num_data):]
else:
raise NotImplementedError
def __getitem__(self, item):
img = cv2.imread(self.root_file + '/' + self.img_list[item], cv2.IMREAD_UNCHANGED)
target = cv2.imread(self.root_file + '/' + self.label_list[item], cv2.IMREAD_UNCHANGED)
        # make sure the image and its label actually correspond
        assert os.path.basename(self.img_list[item]).replace('.jpg', '') == \
               os.path.basename(self.label_list[item]).replace('_bin.png', '')
        offset = 690  # crop off the top of the frame, which carries no lane information
        img = img[offset:, :]
        # keep the label aligned with the cropped image
        if self.phase != 'test':
            target = target[offset:, :]
        # print(self.img_list[item])
        # print(self.label_list[item])
        target = self.encode_label_map(target)  # labels are stored as gray values; map them to train IDs
        # convert from OpenCV (numpy) arrays to PIL images
img = Image.fromarray(img)
target = Image.fromarray(target)
sample = {
'image': img, 'label': target}
if self.transform is not None:
sample = self.transform(sample)
return sample
    def __len__(self):
        if self.phase == 'train':
            return len(self.img_list)
        else:
            return 10  # outside training, only 10 samples are used, presumably for quick visual checks
    '''
    A hand-rolled batch generator that plays the same role as __getitem__.
    '''
    def data_generator(self, batch_size):
        index = np.arange(0, len(self.img_list))
        while len(index):
            # sample a batch without replacement, then remove it from the pool
            # (fixed: np.delete(index, select) removed entries by position, not by value)
            select = np.random.choice(index, min(batch_size, len(index)), replace=False)
            index = np.setdiff1d(index, select)
            images = []
            targets = []
            for item in select:
                img = cv2.imread(self.root_file + '/' + self.img_list[item], cv2.IMREAD_UNCHANGED)
                target = cv2.imread(self.root_file + '/' + self.label_list[item], cv2.IMREAD_UNCHANGED)
                sample = {
                    'image': img, 'label': target}
                if self.transform is not None:  # fixed: was self.transforms
                    sample = self.transform(sample)
                images.append(sample['image'])
                targets.append(sample['label'])
            yield {
                'image': images, 'label': targets}
    # Loop over the label dict; for each class, write its train ID onto the matching pixels, then return the mask
def encode_label_map(self, mask):
for value in self.labels.values():
pixel = value['id']
if value['ignoreInEval']:
# 0: category as background
mask[mask == pixel] = 0
else:
trainId = value['trainId']
                if trainId > 4:  # train IDs 4 and 5 are merged here; this has to be understood from the domain side -- categories 4 and 5 are probably treated as the same thing
trainId -= 1
mask[mask == pixel] = trainId
return mask
    # Map train IDs back to the original gray label values
def decode_label_map(self, mask):
mask[mask == 1] = 200
mask[mask == 2] = 201
mask[mask == 3] = 216
mask[mask == 4] = 210
mask[mask == 5] = 214
mask[mask == 6] = 202
mask[mask == 7] = 205
return mask
    # Map train IDs to a color image for visualization
def decode_color_map(self, mask):
new_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
new_mask[mask == 0] = [0, 0, 0]
new_mask[mask == 1] = [70, 130, 180]
new_mask[mask == 2] = [0, 0, 142]
new_mask[mask == 3] = [153, 153, 153]
new_mask[mask == 4] = [128, 64, 128]
new_mask[mask == 5] = [190, 153, 153]
new_mask[mask == 6] = [0, 0, 230]
new_mask[mask == 7] = [255, 128, 0]
return new_mask
def preprocess(self, phase):
if phase == 'train':
preprocess = transforms.Compose([
FixedResize(self.output_size),
Translate(50, 255),
# RandomScale(),
CutOut(64),
RandomHorizontalFlip(),
AdjustColor(self.factor),
RandomGaussianBlur(self.radius),
Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)),
ToTensor(),
])
elif phase == 'val':
preprocess = transforms.Compose([
FixedResize(self.output_size),
Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)),
ToTensor(),
])
elif phase == 'test':
preprocess = transforms.Compose([
FixedResize(self.output_size, is_resize=False),
Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)),
ToTensor(),
])
else:
raise NotImplementedError
return preprocess
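To make the trainId remapping concrete, here is a small check of encode_label_map on a hand-built gray mask (expected values follow the label table and the merge rule above; since labels is a class attribute, the instance can be created without scanning the dataset):

gray = np.array([[200, 210, 205, 255]], dtype=np.uint8)
ds = BaiDuLaneDataset.__new__(BaiDuLaneDataset)  # skip __init__, no files needed
print(ds.encode_label_map(gray))  # [[1 4 7 0]]: 210 (trainId 5) -> 4, 205 (trainId 8) -> 7, 255 is ignored -> 0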
import numpy as np
import torch.nn as nn
from common import Decoder, ASPP, AlignedXception
def conv3x3(in_channels, out_channels, stride=1, dilation=1):
kernel_size = np.asarray((3, 3))
upsampled_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
full_padding = (upsampled_kernel_size - 1) // 2
full_padding, kernel_size = tuple(full_padding), tuple(kernel_size)
return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=full_padding, dilation=dilation, bias=False)
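The padding formula yields 'same' padding for any dilation. A quick check with dilation=2 (input size arbitrary):

import torch
conv = conv3x3(64, 64, dilation=2)
print(conv(torch.randn(1, 64, 33, 33)).shape)  # torch.Size([1, 64, 33, 33]) -- spatial size unchanged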
# This configuration follows DeepLab v3+
class DeepLab(nn.Module):
    def __init__(self, backbone="aligned_xception", stride=16, num_classes=8, pretrained=False):
        super(DeepLab, self).__init__()
        # note: the backbone and pretrained arguments are accepted but not used yet
        self.backbone = AlignedXception(stride)
        planes = 128
        self.aspp = ASPP(2048, 256, stride)  # fixed: the stride argument was hard-coded to 16
self.decoder = Decoder(planes=planes, num_classes=num_classes)
def forward(self, x):
x1, x2 = self.backbone(x)
x1 = self.aspp(x1)
x = self.decoder(x1, x2, x.size()[2:])
return x
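An end-to-end smoke test of the assembled network; the decoder upsamples back to whatever input resolution is given (the size below is illustrative):

import torch
model = DeepLab(stride=16, num_classes=8)
print(model(torch.randn(1, 3, 256, 846)).shape)  # torch.Size([1, 8, 256, 846])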
#
# Ref:https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/dataloaders/custom_transforms.py
#
import random

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as FF
from PIL import Image, ImageOps, ImageFilter
class Normalize(object):
"""Normalize a tensor image with mean and standard deviation.
Args:
mean (tuple): means for each channel.
std (tuple): standard deviations for each channel.
"""
def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
self.mean = mean
self.std = std
def __call__(self, sample):
img = sample['image']
mask = sample['label']
img = np.array(img).astype(np.float32)
mask = np.array(mask).astype(np.float32)
img /= 255.0
img -= self.mean
img /= self.std
return {
'image': img, 'label': mask}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
# swap color axis because
# numpy image: H x W x C
# torch image: C X H X W
img = sample['image']
mask = sample['label']
img = np.array(img).astype(np.float32).transpose((2, 0, 1))
mask = np.array(mask).astype(np.float32)
img = torch.from_numpy(img).float()
mask = torch.from_numpy(mask).float()
return {
'image': img, 'label': mask}
class RandomHorizontalFlip(object):
def __call__(self, sample):
img = sample['image']
mask = sample['label']
if np.random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
return {
'image': img, 'label': mask}
class RandomRotate(object):
def __init__(self, degree):
self.degree = degree
def __call__(self, sample):
img = sample['image']
mask = sample['label']
rotate_degree = random.uniform(-1 * self.degree, self.degree)
img = img.rotate(rotate_degree, Image.BILINEAR)
mask = mask.rotate(rotate_degree, Image.NEAREST)
return {
'image': img, 'label': mask}
class RandomGaussianBlur(object):
def __init__(self, radius=(0., 1.)):
self.radius = radius
def __call__(self, sample):
img = sample['image']
mask = sample['label']
if np.random.random() < 0.5:
img = img.filter(ImageFilter.GaussianBlur(
radius=random.uniform(*self.radius)))
return {
'image': img, 'label': mask}
class RandomScaleCrop(object):
def __init__(self, base_size, crop_size, fill=0):
self.base_size = base_size
self.crop_size = crop_size
self.fill = fill
def __call__(self, sample):
img = sample['image']
mask = sample['label']
# random scale (short edge)
short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
w, h = img.size
if h > w:
ow = short_size
oh = int(1.0 * h * ow / w)
else:
oh = short_size
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
if short_size < self.crop_size:
padh = self.crop_size - oh if oh < self.crop_size else 0
padw = self.crop_size - ow if ow < self.crop_size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
# random crop crop_size
w, h = img.size
x1 = random.randint(0, w - self.crop_size)
y1 = random.randint(0, h - self.crop_size)
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
return {
'image': img, 'label': mask}
class FixScaleCrop(object):
def __init__(self, crop_size):
self.crop_size = crop_size
def __call__(self, sample):
img = sample['image']
mask = sample['label']
w, h = img.size
if w > h:
oh = self.crop_size
ow = int(1.0 * w * oh / h)
else:
ow = self.crop_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# center crop
w, h = img.size
x1 = int(round((w - self.crop_size) / 2.))
y1 = int(round((h - self.crop_size) / 2.))
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
return {
'image': img, 'label': mask}
class FixedResize(object):
def __init__(self, size, is_resize=True):
if isinstance(size, tuple):
self.size = size
else:
            self.size = (size, size)  # note: PIL's resize expects (w, h)
self.is_resize = is_resize
def __call__(self, sample):
img = sample['image']
mask = sample['label']
# assert img.size == mask.size
img = img.resize(self.size, Image.BILINEAR)
if self.is_resize:
mask = mask.resize(self.size, Image.NEAREST)
return {
'image': img, 'label': mask}
class AdjustColor(object):
def __init__(self, factor=(0.3, 2.)):
self.factor = factor
def __call__(self, sample):
img = sample['image']
mask = sample['label']
assert img.size == mask.size
brightness_factor = np.random.uniform(*self.factor)
contrast_factor = np.random.uniform(*self.factor)
saturation_factor = np.random.uniform(*self.factor)
img = FF.adjust_brightness(img, brightness_factor)
img = FF.adjust_contrast(img, contrast_factor)
img = FF.adjust_saturation(img, saturation_factor)
return {
'image': img, 'label': mask}
class CutOut(object):
def __init__(self, mask_size):
self.mask_size = mask_size
def __call__(self, sample):
img = sample['image']
mask = sample['label']
image = np.array(img)
mask = np.array(mask)
mask_size_half = self.mask_size // 2
offset = 1 if self.mask_size % 2 == 0 else 0
h, w = image.shape[:2]
# find mask center coordinate
cxmin, cxmax = mask_size_half, w + offset - mask_size_half
cymin, cymax = mask_size_half, h + offset - mask_size_half
cx = np.random.randint(cxmin, cxmax)
cy = np.random.randint(cymin, cymax)
# left-top point
xmin, ymin = cx - mask_size_half, cy - mask_size_half
# right-bottom point
xmax, ymax = xmin + self.mask_size, ymin + self.mask_size
xmin, ymin, xmax, ymax = max(0, xmin), max(0, ymin), min(w, xmax), min(h, ymax)
if random.uniform(0, 1) < 0.5:
image[ymin:ymax, xmin:xmax] = (0, 0, 0)
return {
'image': Image.fromarray(image), 'label': Image.fromarray(mask)}
class RandomScale(object):
def __call__(self, sample):
image = sample['image']
mask = sample['label']
image = np.array(image)
mask = np.array(mask)
scale = np.random.uniform(0.7, 1.5)
h, w = image.shape[:2]
aug_image = image.copy()
aug_mask = mask.copy()
aug_image = cv2.resize(aug_image, (int(scale * w), int(scale * h)))
aug_mask = cv2.resize(aug_mask, (int(scale * w), int(scale * h)))
if scale < 1.:
new_h, new_w, _ = aug_image.shape
pre_h_pad = int((h - new_h) / 2)
pre_w_pad = int((w - new_w) / 2)
pad_list = [[pre_h_pad, h - new_h - pre_h_pad], [pre_w_pad, w - new_w - pre_w_pad], [0, 0]]
aug_image = np.pad(aug_image, pad_list, mode="constant", constant_values=0)
aug_mask = np.pad(aug_mask, pad_list[:2], mode="constant", constant_values=255)
if scale >= 1.:
new_h, new_w = aug_image.shape[:2]
pre_h_crop = int((new_h - h) / 2)
pre_w_crop = int((new_w - w) / 2)
post_h_crop = h + pre_h_crop
post_w_crop = w + pre_w_crop
aug_image = aug_image[pre_h_crop:post_h_crop, pre_w_crop:post_w_crop]
aug_mask = aug_mask[pre_h_crop:post_h_crop, pre_w_crop:post_w_crop]
return {
'image': Image.fromarray(aug_image), 'label': Image.fromarray(aug_mask)}
class Translate(object):
    def __init__(self, t=50, ignore_index=255):
        self.t = t
        self.ignore_index = ignore_index
def __call__(self, sample):
image = sample['image']
target = sample['label']
image = np.array(image)
target = np.array(target)
if np.random.random() > 0.5:
x = random.uniform(-self.t, self.t)
y = random.uniform(-self.t, self.t)
M = np.float32([[1, 0, x],
[0, 1, y]])
h, w = image.shape[:2]
image = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR, borderValue=(0, 0, 0))
            target = cv2.warpAffine(target, M, (w, h), flags=cv2.INTER_NEAREST, borderValue=(self.ignore_index,))
return {
'image': Image.fromarray(image), 'label': Image.fromarray(target)}
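A smoke test chaining a few of the augmentations on synthetic PIL inputs (sizes are illustrative, following the dataset's (w, h) = (846, 255) convention):

img = Image.fromarray(np.random.randint(0, 255, (255, 846, 3), dtype=np.uint8))
mask = Image.fromarray(np.zeros((255, 846), dtype=np.uint8))
sample = {'image': img, 'label': mask}
for t in [RandomHorizontalFlip(), AdjustColor((0.8, 1.2)), CutOut(64), Translate(50, 255)]:
    sample = t(sample)
print(np.array(sample['image']).shape)  # (255, 846, 3)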
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class SegmentationLosses(nn.Module):
def __init__(self, num_classes=8, mode='CE', weights=None,
ignore_index=255, gamma=2, alpha=0.5, reduction='mean'):
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.mode = mode
self.weights = weights
self.ignore_index = ignore_index
self.reduction = reduction
self.num_classes = num_classes
def forward(self, preds, target):
""""""
H1, W1 = preds.size()[2:]
H2, W2 = target.size()[1:]
assert H1 == H2 and W1 == W2
if self.mode == 'CE':
return self.CrossEntropyLoss(preds, target)
elif self.mode == 'FL':
return self.FocalLoss(preds, target)
elif self.mode == 'Dice':
return self.GeneralizedSoftDiceLoss(preds, target)
elif self.mode == 'Dice2':
            return self.BatchSoftDiceLoss(preds, target)
elif self.mode == 'CE || Dice':
loss = self.CrossEntropyLoss(preds, target) + \
self.GeneralizedSoftDiceLoss(preds, target)
return loss
else:
raise NotImplementedError
def CrossEntropyLoss(self, preds, target):
"""
:param preds: Tensor of shape [N, C, H, W]
:param target: Tensor of shape [N, H, W]
:return:
"""
device = target.device
# if self.weights is not None:
# weight = self.weights.to(device)
# else:
# arr = target.data.cpu().numpy().reshape(-1)
# weight = np.bincount(arr)
# weight = weight.astype(np.float)
# # weight = weight.sum() / weight
# weight = weight / weight.sum()
# median = np.median(weight)
# for i in range(weight.shape[0]):
# if int(weight[i]) == 0:
# continue
# weight[i] = median / weight[i]
# weight = torch.from_numpy(weight).to(device).float()
        # fall back to unweighted CE when no class weights are given
        # (fixed: calling .to(device) on None used to crash here)
        weight = self.weights.to(device) if self.weights is not None else None
        return F.cross_entropy(preds, target, weight=weight, ignore_index=self.ignore_index)
def FocalLoss(self, preds, target):
"""
        FL = -alpha * (1 - pt) ** gamma * log(pt)
:param preds: Tensor of shape [N, C, H, W]
:param target: Tensor of shape [N, H, W]
:return:
"""
        # note: F.cross_entropy is mean-reduced, so pt below is a batch-level
        # quantity; this is an approximation of per-pixel focal loss
        logits = -F.cross_entropy(preds, target.long(),
                                  ignore_index=self.ignore_index)
        pt = torch.exp(logits)
if self.alpha is not None:
logits *= self.alpha
loss = -((1 - pt) ** self.gamma) * logits
return loss
def GeneralizedSoftDiceLoss(self, preds, target):
"""
Paper:
https://arxiv.org/pdf/1606.04797.pdf
:param preds: Tensor of shape [N, C, H, W]
:param target: Tensor of shape [N, H, W]
:return:
"""
# overcome ignored label
ignore = target.data.cpu() == self.ignore_index
label = target.clone()
label[ignore] = 0
lb_one_hot = torch.zeros_like(preds).scatter_(1, label.unsqueeze(1), 1)
ignore = ignore.nonzero()
_, M = ignore.size()
a, *b = ignore.chunk(M, dim=1)
lb_one_hot[[a, torch.arange(lb_one_hot.size(1)).long(), *b]] = 0
lb_one_hot = lb_one_hot.detach()
# compute loss
probs = torch.sigmoid(preds)
numer = torch.sum((probs * lb_one_hot), dim=(2, 3))
        denom = torch.sum(probs.pow(1) + lb_one_hot.pow(1), dim=(2, 3))  # kept linear here; the V-Net paper squares these terms
        if self.weights is not None:
            numer = numer * self.weights.view(1, -1)  # fixed: was self.weight
            denom = denom * self.weights.view(1, -1)
numer = torch.sum(numer, dim=1)
denom = torch.sum(denom, dim=1)
smooth = 1
loss = 1 - (2 * numer + smooth) / (denom + smooth)
if self.reduction == 'mean':
loss = loss.mean()
return loss
    def BatchSoftDiceLoss(self, preds, target):
        """
        Same as GeneralizedSoftDiceLoss, but the intersection and union are
        summed over the whole batch before forming the ratio.
        :param preds: Tensor of shape [N, C, H, W]
        :param target: Tensor of shape [N, H, W]
        :return:
        """
# overcome ignored label
ignore = target.data.cpu() == self.ignore_index
target = target.clone()
target[ignore] = 0
lb_one_hot = torch.zeros_like(preds).scatter_(1, target.unsqueeze(1), 1)
ignore = ignore.nonzero()
_, M = ignore.size()
a, *b = ignore.chunk(M, dim=1)
lb_one_hot[[a, torch.arange(lb_one_hot.size(1)).long(), *b]] = 0
lb_one_hot = lb_one_hot.detach()
# compute loss
probs = torch.sigmoid(preds)
numer = torch.sum((probs * lb_one_hot), dim=(2, 3))
denom = torch.sum(probs.pow(1) + lb_one_hot.pow(1), dim=(2, 3))
        if self.weights is not None:
            numer = numer * self.weights.view(1, -1)  # fixed: was self.weight
            denom = denom * self.weights.view(1, -1)
numer = torch.sum(numer)
denom = torch.sum(denom)
smooth = 1
loss = 1 - (2 * numer + smooth) / (denom + smooth)
return loss
if __name__ == '__main__':
criteria = SegmentationLosses(mode='CE')
# logits = torch.randn(16, 19, 14, 14)
im = torch.randn(16, 3, 14, 14)
label = torch.randint(0, 19, (16, 14, 14)).long()
net = torch.nn.Conv2d(3, 19, 3, 1, 1)
    label[2, 3, 3] = 255  # inject an ignored pixel
    print(label.dtype)
logits = net(im)
loss = criteria(logits, label)
loss.backward()
print(loss)
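The Dice variants can be smoke-tested the same way; note that GeneralizedSoftDiceLoss applies a sigmoid to the logits rather than a softmax over classes:

criteria = SegmentationLosses(mode='Dice')
logits = torch.randn(2, 8, 16, 16)
label = torch.randint(0, 8, (2, 16, 16)).long()
print(criteria(logits, label))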
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.backends import cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import get_parser
from datalist import BaiDuLaneDataset
from model import DeepLab
from utils import SegmentationLosses
metric_loss = float('inf')  # fixed: starting at 0 meant the best-loss checkpoint was never saved
class train():
def __init__(self):
self.args = get_parser()
print(f"-----------{self.args.project_name}-------------")
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        torch.manual_seed(self.args.seed)  # seed the CPU RNG in either case
        if use_cuda:
            torch.cuda.manual_seed(self.args.seed)
            torch.cuda.manual_seed_all(self.args.seed)
self.device = torch.device('cuda' if use_cuda else 'cpu')
train_kwargs = {
'num_workers': 0, 'pin_memory': True} if use_cuda else {
}
test_kwargs = {
'num_workers': 0, 'pin_memory': False} if use_cuda else {
}
        '''
        Build the DataLoaders
        '''
self.train_dataset = BaiDuLaneDataset(root_file=self.args.data_base, phase='train')
self.test_dataset = BaiDuLaneDataset(root_file=self.args.data_base, phase='test')
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=10, shuffle=True, **train_kwargs)  # reshuffle each epoch
self.test_dataloader = DataLoader(self.test_dataset, batch_size=10, **test_kwargs)
        '''
        Define the model
        '''
self.model = DeepLab().to(self.device)
        '''
        CUDA acceleration
        '''
if use_cuda:
self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
cudnn.benchmark = True
        '''
        Optionally load pretrained model weights
        '''
        if self.args.resume and self.args.pretrained_model:
            checkpoint = torch.load(self.args.pretrained_model)
            # checkpoints saved below wrap the weights in 'model_state_dict'
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            self.model.load_state_dict(state_dict, strict=False)
            print("load pretrained model successful!")
        else:
            print("initialize net weights from scratch!")
        '''
        Build the loss function, choose the optimizer,
        and set up the learning-rate schedule
        '''
weights = torch.FloatTensor([0.00289, 0.2411, 1.068, 2.547, 7.544, 0.2689, 0.9043, 1.572])
        self.criterion = SegmentationLosses(mode='CE', weights=weights).to(self.device)  # weighted cross-entropy
self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.args.milestones,
gamma=0.1)
        for epoch in range(1, self.args.epoches + 1):  # fixed off-by-one: run exactly `epoches` epochs
            self.train(epoch)
            self.test(epoch)  # `epoch % 1 == 0` was always true, so just evaluate every epoch
torch.cuda.empty_cache()
print("model finish training")
def train(self, epoch):
global metric_loss
self.model.train()
average_loss = []
pbar = tqdm(self.train_dataloader, desc=f'Train Epoch{epoch}/{self.args.epoches}')
for data in pbar:
img, target = data['image'], data['label']
img, target = img.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(img)
loss = self.criterion(outputs, target.long()).cpu()
average_loss.append(loss.item())
loss.backward()
self.optimizer.step()
pbar.set_description(
f'Train Epoch:{epoch}/{self.args.epoches} '
f'train_loss:{round(np.mean(average_loss), 2)} '
f'learning_rate:{self.optimizer.state_dict()["param_groups"][0]["lr"]}')
self.scheduler.step()
        if np.mean(average_loss) < metric_loss and self.args.save_path:
            metric_loss = np.mean(average_loss)
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': round(np.mean(average_loss), 2)
            },
                self.args.save_path + f'Epoch-{epoch}-loss-{metric_loss}.pth')  # fixed: use save_path instead of a hard-coded directory
            print("model saved")
def test(self, epoch):
self.model.eval()
with torch.no_grad():
pbar = tqdm(self.test_dataloader, desc=f'Test Epoch{epoch}/{self.args.epoches}')
for data in pbar:
img, target = data['image'], data['label']
img, target = img.to(self.device), target.to(self.device)
outputs = self.model(img)
outputs = F.interpolate(outputs, size=(1020, 3384), mode='bilinear', align_corners=True)
preds = outputs.data.max(1)[1].cpu().numpy()
pbar.set_description(
f'【Test Epoch】:{epoch}/{self.args.epoches} '
)
            # take one image from the last batch to visualize the result
temp = img
img = img.cpu().numpy()
img = np.transpose(img[0], axes=[1, 2, 0])
img *= (0.229, 0.224, 0.225)
img += (0.485, 0.456, 0.406)
img *= 255.0
img = img.astype(np.uint8)
img = cv2.resize(img, (3384, 1710))
mask = np.zeros((temp.size(0), 690, 3384))
preds = np.hstack((mask.astype(preds.dtype), preds))
pred = preds[0].astype(np.uint8)
pred = self.test_dataset.decode_color_map(pred)
result = np.vstack((pred, img))
            os.makedirs("./result", exist_ok=True)  # make sure the output directory exists
            cv2.imwrite("./result/epoch-" + str(epoch) + "_result_predict.jpg", result)
train = train()
As shown in the figure, this is the result after training for a few epochs. By swapping out the datalist and the model, this same skeleton can be reused for other segmentation projects.