UNet多分类记录

因为要做个简单的实验,这里把UNet多分类记录下。

1.数据准备

以公共数据CamVid为例,一级目录

UNet多分类记录_第1张图片

二级目录

UNet多分类记录_第2张图片

2.数据加载

文件data.py,脚本前面是数据增强部分,我自己的数据加载函数从128行开始,包括read_own_data、own_data_loader、own_data_test_loader函数以及ImageFolder加载类。

# -*- coding: utf-8 -*-
import torch
import torch.utils.data as data
from torch.autograd import Variable as V
from PIL import Image
 
import cv2
import numpy as np
import os
import scipy.misc as misc
 
def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)
        hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1]+1)
        hue_shift = np.uint8(hue_shift)
        h += hue_shift
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)
        image = cv2.merge((h, s, v))
        #image = cv2.merge((s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
 
    return image
 
def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0),
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5,
                           mask_interpolation=cv2.INTER_NEAREST):
    """Apply the same random shift/scale/rotation/aspect warp to image and mask.

    With probability ``u`` a perspective matrix is built from values drawn
    uniformly from the ``*_limit`` ranges and applied to both arrays:
    - shift is a fraction of width/height;
    - scale and aspect are offsets from 1;
    - rotation is in degrees.
    With the default ``BORDER_CONSTANT`` mode, uncovered pixels become 0.

    Args:
        image: image array of shape (H, W) or (H, W, C); warped bilinearly.
        mask: label mask with the same height/width as ``image``.
        shift_limit, scale_limit, rotate_limit, aspect_limit: (low, high) ranges.
        borderMode: OpenCV border mode used for both warps.
        u: probability of applying the transform.
        mask_interpolation: interpolation used for the mask. Defaults to
            nearest-neighbour so class ids are never blended into in-between
            (non-existent) ids along region borders, as bilinear would do.

    Returns:
        (image, mask), transformed or unchanged.
    """
    if np.random.random() < u:
        # Works for grayscale (H, W) as well as colour (H, W, C) input.
        height, width = image.shape[:2]

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        # np.math was removed in NumPy 2.0; use the NumPy functions directly.
        theta = angle / 180 * np.pi
        cc = np.cos(theta) * sx
        ss = np.sin(theta) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Send the four image corners through rotation/scale about the centre
        # plus the shift, then fit the perspective transform between them.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height]])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

        mat = cv2.getPerspectiveTransform(box0.astype(np.float32), box1.astype(np.float32))
        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                    borderMode=borderMode, borderValue=(0, 0, 0))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=mask_interpolation,
                                   borderMode=borderMode, borderValue=(0, 0, 0))

    return image, mask
 
def randomHorizontalFlip(image, mask, u=0.5):
    """Mirror image and mask left-right together, with probability ``u``.

    Returns the ``(image, mask)`` pair, flipped or untouched.
    """
    if not np.random.random() < u:
        return image, mask
    # flipCode=1 flips around the vertical axis (left <-> right).
    return cv2.flip(image, 1), cv2.flip(mask, 1)
 
def randomVerticleFlip(image, mask, u=0.5):
    """Mirror image and mask top-bottom together, with probability ``u``.

    Returns the ``(image, mask)`` pair, flipped or untouched.
    """
    if not np.random.random() < u:
        return image, mask
    # flipCode=0 flips around the horizontal axis (top <-> bottom).
    return cv2.flip(image, 0), cv2.flip(mask, 0)
 
def randomRotate90(image, mask, u=0.5):
    """Rotate image and mask by 90 degrees (``np.rot90``, k=1) with probability ``u``.

    Returns the ``(image, mask)`` pair, rotated or untouched. The rotated
    arrays are NumPy views, not copies.
    """
    if not np.random.random() < u:
        return image, mask
    return np.rot90(image, 1), np.rot90(mask, 1)
 
 
def default_loader(img_path, mask_path):
    """Load and augment one *binary* (image, mask) pair at 448x448.

    The grayscale mask is inverted (``255 - mask``) and, after augmentation,
    thresholded at 0.5 of full scale. Pixels that were dark in the label file
    therefore end up as 1.

    Returns:
        img: float32 array (C, 448, 448), scaled to [-1.6, 1.6].
        mask: float32 array (1, 448, 448) containing only 0.0 / 1.0.
    """
    size = (448, 448)
    img = cv2.resize(cv2.imread(img_path), size)

    raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    mask = 255. - cv2.resize(raw_mask, size)

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    # Geometric flips/rotation, applied in a fixed order to image and mask alike.
    for augment in (randomHorizontalFlip, randomVerticleFlip, randomRotate90):
        img, mask = augment(img, mask)

    mask = np.expand_dims(mask, axis=2)

    # HWC -> CHW; image scaled from [0, 255] to [-1.6, 1.6], mask to [0, 1].
    img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
    mask = (mask >= 0.5).astype(np.float32)
    return img, mask
 
def read_own_data(root_path, mode = 'train'):
    """Collect paired image/label file paths for one dataset split.

    Every file in ``<root_path>/<mode>/labels`` is paired with the file of the
    same name in ``<root_path>/<mode>/images``. The image file is assumed to
    exist; it is not checked here.

    Args:
        root_path: dataset root directory.
        mode: split sub-directory name, e.g. ``'train'``, ``'val'``, ``'test'``.

    Returns:
        (images, masks): two lists of paths in the same, sorted order.
    """
    image_root = os.path.join(root_path, mode, 'images')
    gt_root = os.path.join(root_path, mode, 'labels')

    # os.listdir order is arbitrary; sort it so sample order is reproducible.
    names = sorted(os.listdir(gt_root))
    images = [os.path.join(image_root, name) for name in names]
    masks = [os.path.join(gt_root, name) for name in names]
    return images, masks
 
def own_data_loader(img_path, mask_path):
    """Load, resize to 512x512 and augment one training (image, label) pair.

    The label is read as a single-channel image whose pixel values are kept
    as-is (class ids for multi-class data); no thresholding is applied.

    Returns:
        img: float32 array (C, 512, 512), scaled from [0, 255] to [-1.6, 1.6].
        mask: float32 array (1, 512, 512) with the raw label values.
    """
    target_size = (512, 512)
    img = cv2.resize(cv2.imread(img_path), target_size, interpolation=cv2.INTER_NEAREST)
    mask = cv2.resize(cv2.imread(mask_path, 0), target_size, interpolation=cv2.INTER_NEAREST)

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    for augment in (randomHorizontalFlip, randomVerticleFlip, randomRotate90):
        img, mask = augment(img, mask)

    # Give the mask an explicit channel axis: (H, W) -> (H, W, 1).
    mask = mask[:, :, np.newaxis]

    scaled_img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    float_mask = np.array(mask, np.float32)

    # HWC -> CHW, the layout PyTorch convolutions expect.
    return (np.array(scaled_img, np.float32).transpose(2, 0, 1),
            np.array(float_mask, np.float32).transpose(2, 0, 1))
 
def own_data_test_loader(img_path, mask_path):
    """Load one evaluation (image, label) pair at 512x512 without augmentation.

    Returns:
        img: float32 array (C, 512, 512), scaled from [0, 255] to [-1.6, 1.6].
        mask: float32 array (1, 512, 512) with the raw (unthresholded) label values.
    """
    size = (512, 512)
    img = cv2.resize(cv2.imread(img_path), size, interpolation=cv2.INTER_NEAREST)
    mask = cv2.resize(cv2.imread(mask_path, 0), size, interpolation=cv2.INTER_NEAREST)

    scaled_img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    float_mask = np.array(mask[..., np.newaxis], np.float32)

    # HWC -> CHW for PyTorch.
    return (np.array(scaled_img, np.float32).transpose(2, 0, 1),
            np.array(float_mask, np.float32).transpose(2, 0, 1))
 
class ImageFolder(data.Dataset):
    """Segmentation dataset over ``<root>/<mode>/images`` and ``<root>/<mode>/labels``.

    Items from the ``'test'`` split are loaded without augmentation. Every
    other split goes through the augmenting training loader.
    """

    def __init__(self, root_path, mode='train'):
        self.root = root_path
        self.mode = mode
        self.images, self.labels = read_own_data(self.root, self.mode)

    def __getitem__(self, index):
        """Return ``(img, mask)`` as float32 tensors of shape (C, H, W) and (1, H, W)."""
        loader = own_data_test_loader if self.mode == 'test' else own_data_loader
        img, mask = loader(self.images[index], self.labels[index])
        # Return tensors for every split. Previously the test split returned
        # numpy arrays, so direct indexing behaved differently per mode.
        return torch.Tensor(img), torch.Tensor(mask)

    def __len__(self):
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)

3.训练

里面用到了两个比较好用的包:segmentation_models_pytorch 和 pytorch_toolbelt。其中 segmentation_models_pytorch 实现了很多常见的语义分割模型,同时支持二分类和多分类,直接 pip install segmentation_models_pytorch 就可以安装。另外注意,训练的类别数在脚本开头的 n_classes = 12 处修改;预测脚本中模型的类别数(classes)也要设置成相同的值。

# -*- coding: utf-8 -*-
import time
import warnings
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.swa_utils import AveragedModel, SWALR
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss, SoftCrossEntropyLoss, LovaszLoss
from pytorch_toolbelt import losses as L
from data import ImageFolder
from sklearn import metrics

# Silence every warning for the whole run (this also hides deprecation warnings).
warnings.filterwarnings('ignore')
torch.backends.cudnn.enabled = True

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Number of segmentation classes; sets the model head size and the mIoU range.
n_classes = 12

def cal_cm(y_true, y_pred, labels=None):
    """Confusion matrix of two label maps, compared pixel by pixel.

    Args:
        y_true: ground-truth label array (numpy array or CPU tensor); flattened.
        y_pred: predicted label array with the same number of elements; flattened.
        labels: optional class ids forwarded to ``sklearn.metrics.confusion_matrix``.
            Pass ``list(range(n_classes))`` to get a fixed n_classes x n_classes
            matrix even when some classes are absent from this sample.

    Returns:
        The confusion matrix (rows = true class, columns = predicted class).
    """
    # reshape(-1) always yields a 1-D array; reshape(1, -1).squeeze() collapsed
    # single-element inputs to 0-d, which sklearn rejects.
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    return metrics.confusion_matrix(y_true, y_pred, labels=labels)


def iou_mean(pred, target, n_classes = n_classes, ignore_index=0):
    """Mean intersection-over-union between predicted and ground-truth label maps.

    Args:
        pred: tensor of predicted class ids (e.g. an ``argmax`` over the class axis).
        target: tensor of ground-truth class ids with the same number of
            elements as ``pred``; it is moved to ``pred``'s device.
        n_classes: total number of classes, including class 0. Ids
            ``0 .. n_classes - 1`` are scored.
        ignore_index: class id excluded from the average. Defaults to 0, the
            previous behaviour of skipping the background class. ``None``
            scores every class.

    Returns:
        The mean IoU over the scored classes that occur in ``pred`` or
        ``target``. Classes absent from both are skipped rather than counted
        as 0. Previously the sum was divided by ``n_classes`` regardless, which
        biased mIoU downward. Returns 0.0 when no scored class occurs at all.
    """
    pred = pred.reshape(-1)
    target = target.reshape(-1).to(pred.device)

    ious = []
    for cls in range(n_classes):
        if cls == ignore_index:
            continue
        pred_inds = pred == cls
        target_inds = target == cls
        # Cast to long before summing to avoid overflow on large maps.
        intersection = (pred_inds & target_inds).long().sum().item()
        union = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection
        if union > 0:
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0

def multi_acc(pred, label):
    """Pixel accuracy of class scores ``pred`` against integer labels ``label``.

    ``pred`` has the class axis at dim 1. ``label`` must have ``pred``'s
    shape with that axis removed. Returns a 0-d tensor holding the fraction
    of positions whose top-scoring class equals the label.
    """
    log_probs = pred.log_softmax(dim=1)
    predicted = log_probs.max(dim=1).indices
    hits = (predicted == label).int()
    return hits.sum() / hits.numel()


def train(EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
          model_path, swa_model_path, loss, early_stop):
    """Train a multi-class ``smp.Unet`` and keep the checkpoint with the best val mIoU.

    Args:
        EPOCHES: maximum number of epochs.
        BATCH_SIZE: batch size for both the train and val loaders.
        data_root: dataset root passed to ``ImageFolder`` (``train`` / ``val`` splits).
        channels: number of input image channels.
        optimizer_name: ``"sgd"`` selects SGD; anything else selects AdamW.
        model_path: file the best ``state_dict`` is written to.
        swa_model_path: currently unused; kept for interface compatibility.
        loss: ``"SoftCE_dice"`` selects Dice + soft CE; anything else selects
            Lovasz + soft CE.
        early_stop: stop once this many epochs pass without a val mIoU improvement.

    Returns:
        (train_loss_epochs, val_mIoU_epochs, lr_epochs): per-epoch lists.
    """
    train_data_loader = torch.utils.data.DataLoader(
        ImageFolder(data_root, mode='train'),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0)
    # A fixed validation order keeps the per-batch-averaged mIoU reproducible.
    val_data_loader = torch.utils.data.DataLoader(
        ImageFolder(data_root, mode='val'),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0)

    # Other encoders/architectures can be swapped in here, e.g.
    # smp.UnetPlusPlus(encoder_name="efficientnet-b7", ...).
    model = smp.Unet(
        encoder_name="resnext50_32x4d",   # e.g. mobilenet_v2, efficientnet-b7, resnet34
        encoder_weights="imagenet",       # ImageNet-pretrained encoder initialisation
        in_channels=channels,             # 1 for grayscale, 3 for RGB, ...
        classes=n_classes,                # number of output classes
        # Output raw logits: the smp / pytorch_toolbelt losses below apply
        # softmax themselves, so a softmax here would be applied twice.
        # argmax-based metrics and predictions are unaffected by this choice.
        activation=None,
    )
    model.to(DEVICE)
    # To start from an existing checkpoint, uncomment (model_path = checkpoint):
    # model.load_state_dict(torch.load(model_path))

    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4, weight_decay=1e-3, momentum=0.9)
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-3, weight_decay=1e-3)
    # Cosine annealing with warm restarts: first restart after T_0 epochs,
    # each following cycle T_mult times longer, lr never below eta_min.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=2, T_mult=2, eta_min=1e-5)

    # Soft cross-entropy = cross-entropy with label smoothing.
    soft_ce = SoftCrossEntropyLoss(smooth_factor=0.1)
    if loss == "SoftCE_dice":
        # Dice loss partly counters class imbalance but can make training less stable.
        loss_fn = L.JointLoss(first=DiceLoss(mode='multiclass'), second=soft_ce,
                              first_weight=0.8, second_weight=0.2)
    else:
        # Lovasz loss is a convex surrogate that directly optimises mIoU.
        loss_fn = L.JointLoss(first=LovaszLoss(mode='multiclass'), second=soft_ce,
                              first_weight=0.5, second_weight=0.5)
    # Follow DEVICE instead of hard-coding .cuda() so CPU-only runs work.
    loss_fn = loss_fn.to(DEVICE)

    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    for epoch in range(1, EPOCHES + 1):
        start_time = time.time()

        model.train()
        losses = []
        for image, target in tqdm(train_data_loader, ncols=20, total=len(train_data_loader)):
            image = image.to(DEVICE)
            target = target.to(DEVICE).long()  # class ids as int64 for the losses
            output = model(image)
            batch_loss = loss_fn(output, target)
            losses.append(batch_loss.item())
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        scheduler.step()

        # Validation: eval mode (BatchNorm uses running stats, dropout off)
        # and no autograd graph, which also saves GPU memory.
        model.eval()
        val_acc, val_iou = [], []
        with torch.no_grad():
            for val_img, val_mask in tqdm(val_data_loader, ncols=20, total=len(val_data_loader)):
                val_img = val_img.to(DEVICE)
                label = val_mask.to(DEVICE).squeeze(1)
                predict = model(val_img)
                val_acc.append(multi_acc(predict, label).item())
                val_iou.append(iou_mean(torch.argmax(predict, dim=1), label, n_classes))

        mean_loss = np.mean(losses)
        mean_acc = np.mean(val_acc)
        mean_iou = np.mean(val_iou)
        train_loss_epochs.append(mean_loss)
        val_mIoU_epochs.append(mean_iou)
        lr_epochs.append(optimizer.param_groups[0]['lr'])

        print('Epoch:' + str(epoch) + ' Loss:' + str(mean_loss) + ' Val_Acc:' + str(mean_acc)
              + ' Val_IOU:' + str(mean_iou) + ' Time_use:' + str((time.time() - start_time) / 60.0))

        if best_miou < mean_iou:
            best_miou = mean_iou
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break

    return train_loss_epochs, val_mIoU_epochs, lr_epochs
 
 
 
if __name__ == '__main__':
    import os

    EPOCHES = 100
    BATCH_SIZE = 4
    loss = "SoftCE_dice"
    # loss = "SoftCE_Lovasz"
    channels = 3
    optimizer_name = "adamw"

    data_root = "./data/CamVid/"
    model_path = "./weights/CamVid_" + loss + '.pth'
    swa_model_path = model_path + "_swa.pth"
    # NOTE: early_stop (400) exceeds EPOCHES (100), so early stopping never
    # triggers with these settings; lower it to enable it.
    early_stop = 400
    # torch.save() does not create missing directories; make sure ./weights exists.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    train_loss_epochs, val_mIoU_epochs, lr_epochs = train(
        EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
        model_path, swa_model_path, loss, early_stop)

    # Plot the per-epoch training curves.
    import matplotlib.pyplot as plt
    epochs = range(1, len(train_loss_epochs) + 1)
    plt.plot(epochs, train_loss_epochs, 'r', label = 'train loss')
    plt.plot(epochs, val_mIoU_epochs, 'b', label = 'val mIoU')
    plt.title('train loss and val mIoU')
    plt.legend()
    plt.savefig("train loss and val mIoU.png", dpi = 300)
    plt.figure()
    plt.plot(epochs, lr_epochs, 'r', label = 'learning rate')
    plt.title('learning rate')
    plt.legend()
    plt.savefig("learning rate.png", dpi = 300)
    plt.show()

4.预测

# -*- coding: utf-8 -*-
import os
import glob
import time
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
from torch.optim.swa_utils import AveragedModel
from data import ImageFolder
 
# Run inference on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
         
def test_1(channels, model_path, output_dir, test_path, n_classes=12):
    """Predict a label map for every image in ``test_path`` and save it to ``output_dir``.

    For each image:
    - resize to 512x512 (the training resolution);
    - scale with ``/255 * 3.2 - 1.6`` and run the model;
    - resize the per-pixel argmax class ids back to the original size with
      nearest-neighbour interpolation;
    - write them under the same file name as a single-channel uint8 image.

    Args:
        channels: number of input channels the model was trained with.
        model_path: checkpoint path. Paths containing ``"swa"`` are loaded into
            an ``AveragedModel`` wrapper.
        output_dir: directory for the predicted masks; created if missing.
        test_path: directory containing the input images. Files OpenCV cannot
            read are skipped.
        n_classes: number of output classes. It must match the checkpoint and
            be at most 256, since masks are stored as uint8.
    """
    # Other architectures (e.g. smp.UnetPlusPlus / smp.DeepLabV3Plus) can be used
    # here as long as they match the one used for training.
    model = smp.Unet(
        encoder_name="resnext50_32x4d",
        # All weights come from the checkpoint below, so skip the ImageNet download.
        encoder_weights=None,
        in_channels=channels,
        classes=n_classes,
        activation='softmax',
    )

    if "swa" in model_path:
        model = AveragedModel(model)
    model.to(DEVICE)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    os.makedirs(output_dir, exist_ok=True)
    for name in os.listdir(test_path):
        img = cv2.imread(os.path.join(test_path, name))
        if img is None:
            # Not a readable image (e.g. a sub-directory or stray file).
            continue
        h, w = img.shape[:2]
        # Resize to the training input size; the prediction is resized back below.
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
        image = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        image = np.expand_dims(image.transpose(2, 0, 1), axis=0)  # HWC -> 1xCxHxW
        image = torch.Tensor(image).to(DEVICE)

        with torch.no_grad():
            output = model(image)
        # OpenCV resize/imwrite do not accept int64 arrays, so store class ids as uint8.
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(os.path.join(output_dir, name), pred)
        
if __name__ == "__main__":
    # CamVid example: read test images, write predicted masks next to them.
    model_path = "./weights/CamVid_SoftCE_dice.pth"
    test_path = "./data/CamVid/test/"
    output_dir = './data/CamVid/test_pre/'
    test_1(channels=3, model_path=model_path, output_dir=output_dir, test_path=test_path)

结果

UNet多分类记录_第3张图片   UNet多分类记录_第4张图片

                                      原图                                                             标签

UNet多分类记录_第5张图片

预测结果 

注意:可视化时各类别的颜色是随机分配的,预测图与标签图颜色不一致并不代表分错了。这个结果明显还没有训练到收敛(我提前停止了训练),继续训练效果还能提升。

你可能感兴趣的:(语义分割,pytorch,unet,语义分割)