PyTorch 之图像分割(单个目标分割,Single Object Segmentation)

示例数据为 Fetal Head Circumference(胎儿头围)超声图像数据集
下载地址: https://zenodo.org/record/1322001#.YTHD2Y4zaUl

Feta-Head-Circumference.png

模型结构 U-Net
U-Net(注意:下文实现的 SegNet 是不含跳跃连接的编码器-解码器结构,可视为 U-Net 的简化版)

扩展阅读:https://github.com/pranjalrai-iitd/Fetal-head-segmentation-and-circumference-measurement-from-ultrasound-images

引入包

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.pylab as plab
from PIL import Image, ImageDraw
import numpy as np
import pandas as pd
import os
import copy
import collections
from sklearn.model_selection import ShuffleSplit
from scipy import ndimage as ndi
from skimage.segmentation import mark_boundaries

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.transforms as transforms
from torchvision import models,utils, datasets
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from albumentations import (HorizontalFlip, VerticalFlip, Compose, Resize,)
from torchsummary import summary

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Worker-process count meant for DataLoader's num_workers:
# none on Windows ('nt'), four on every other platform.
workers = 4 if os.name != 'nt' else 0

数据初探

# Location of the training data
path_train = "./data/sos/training_set/"

# Collect the PNG files once, then separate images from annotation files by name.
png_files = [pp for pp in os.listdir(path_train) if pp.endswith('.png')]
imgs_list = [pp for pp in png_files if "Annotation" not in pp]
annts_list = [pp for pp in png_files if "Annotation" in pp]
print("number of images:", len(imgs_list))
print("number of annotations:", len(annts_list))
"""
number of images: 999
number of annotations: 999
"""

# Pick a few images to look at (seeded so the selection is repeatable)
np.random.seed(2019)
rnd_imgs = np.random.choice(imgs_list, size=4)
print('The random images are: ', rnd_imgs)
# The random images are:  ['166_2HC.png' '434_HC.png' '244_HC.png' '826_3HC.png']
# Display an image with the outline of its mask drawn on top
def show_img_mask(img, mask):
    """Show ``img`` with the boundary of ``mask`` marked on it.

    When ``img`` is a tensor, both ``img`` and ``mask`` are first converted
    with ``to_pil_image``; otherwise they are used as given. The overlay
    produced by ``mark_boundaries`` is passed to ``plt.imshow``. Nothing is
    returned.
    """
    if torch.is_tensor(img):
        img, mask = to_pil_image(img), to_pil_image(mask)

    image_arr = np.array(img)
    label_arr = np.array(mask)
    green = (0, 1, 0)
    overlay = mark_boundaries(
        image_arr,
        label_arr,
        color=green,
        outline_color=green,
    )
    plt.imshow(overlay)

# Plot each sampled image, its filled mask, and the overlay side by side
for fn in rnd_imgs:
    img_path = os.path.join(path_train, fn)
    annt_path = img_path.replace(".png", "_Annotation.png")

    img = Image.open(img_path)
    annt_edges = Image.open(annt_path)
    # Fill the interior of the annotation (presumably an outline only,
    # hence the name `annt_edges`) to obtain a solid mask.
    mask = ndi.binary_fill_holes(annt_edges)

    plt.figure(figsize=(10, 10))
    # Panels 1 and 2: raw image and filled mask, both in grayscale
    for col, panel in enumerate((img, mask), start=1):
        plt.subplot(1, 3, col)
        plt.imshow(panel, cmap="gray")

    # Panel 3: image with the mask boundary overlaid
    plt.subplot(1, 3, 3)
    show_img_mask(img, mask)
data status

构建Dataset,Transforms,DataLoader

# Target size (height, width) applied to every image and mask
h = 128
w = 192

# Training pipeline: resize, then horizontal and vertical flips, each with p=0.5
train_steps = [
    Resize(h, w),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
]
transform_train = Compose(train_steps)

# Validation pipeline: resize only
transform_val = Resize(h, w)

# Dataset of images paired with filled segmentation masks
class FetalDataset(Dataset):
    """Image / mask pairs read from a directory of ``*.png`` files.

    Every ``.png`` file in ``path_data`` whose name does not contain
    ``"Annotation"`` is treated as an image. Its annotation is expected at
    the same path with the trailing ``.png`` replaced by ``_Annotation.png``.
    """

    def __init__(self, path_data, transform=None):
        """Index the image files found in ``path_data``.

        Args:
            path_data: directory holding the images and their annotations.
            transform: optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping
                with ``'image'`` and ``'mask'`` entries.
        """
        # List the directory that was actually passed in. Sorting makes the
        # index -> file mapping independent of the OS directory order, so
        # separate instances indexed by the same split always agree.
        imgs_list = sorted(
            pp for pp in os.listdir(path_data)
            if "Annotation" not in pp and pp.endswith('.png')
        )
        self.path_imgs = [os.path.join(path_data, fn) for fn in imgs_list]
        # Swap only the trailing extension; str.replace would also rewrite a
        # '.png' occurring elsewhere in the path.
        suffix_len = len('.png')
        self.path_annts = [
            path_img[:-suffix_len] + '_Annotation.png'
            for path_img in self.path_imgs
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images found."""
        return len(self.path_imgs)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair for sample ``idx``."""
        # Context managers release the file handles once the pixels are copied.
        with Image.open(self.path_imgs[idx]) as img_file:
            image = np.array(img_file)
        with Image.open(self.path_annts[idx]) as annt_file:
            annt_edges = np.array(annt_file)

        # Fill the enclosed region of the annotation to get a solid 0/1 mask.
        mask = ndi.binary_fill_holes(annt_edges).astype('uint8')

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        image = to_tensor(image)
        # to_tensor scales uint8 input by 1/255; multiply back so the mask
        # values are 0 and 1 again.
        mask = 255 * to_tensor(mask)

        return image, mask

# Two views of the same directory: augmented for training, resize-only for validation
fetal_train_ds = FetalDataset(path_train, transform=transform_train)
fetal_val_ds = FetalDataset(path_train, transform=transform_val)
# print(len(fetal_train_ds))
# print(len(fetal_val_ds))

# Split the sample indices once into training and validation parts (80/20)
sss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices = range(len(fetal_train_ds))

# n_splits=1, so the generator yields exactly one (train, val) index pair
train_index, val_index = next(sss.split(indices))

train_ds = Subset(fetal_train_ds, train_index)
print(len(train_ds))

val_ds = Subset(fetal_val_ds, val_index)
print(len(val_ds))

# Show the first training sample with its mask outline
plt.figure(figsize=(5,5))
img, mask = next(iter(train_ds))
show_img_mask(img, mask)
    
# Build the data loaders (only the training loader shuffles)
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False)

# Fetch one training batch and print its shapes and dtypes
img, mask = next(iter(train_dl))
print(img.shape, img.dtype)
# torch.Size([8, 1, 128, 192]) torch.float32
print(mask.shape, mask.dtype)
# torch.Size([8, 1, 128, 192]) torch.float32
"""
799
200
torch.Size([8, 1, 128, 192]) torch.float32
torch.Size([8, 1, 128, 192]) torch.float32
"""
转换后图片

模型定义

# Encoder-decoder segmentation model
class SegNet(nn.Module):
    """Convolutional encoder-decoder producing one output map per class.

    The encoder applies five 3x3 convolutions, the first four each followed
    by 2x2 max pooling. The decoder applies four bilinear 2x upsamplings,
    each followed by a 3x3 convolution. No activation is applied after the
    final convolution. With four poolings and four upsamplings, the output
    has the input's spatial size when height and width are multiples of 16.

    ``params`` must provide ``'input_shape'`` (a 3-tuple whose first entry is
    the input channel count), ``'initial_filters'`` and ``'num_outputs'``.
    """

    def __init__(self, params):
        super().__init__()
        in_channels, _, _ = params['input_shape']
        base = params['initial_filters']
        out_channels = params['num_outputs']

        def conv3x3(c_in, c_out):
            # 3x3 convolution that keeps the spatial size (stride 1, padding 1)
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Encoder: the channel count doubles at every stage
        self.conv1 = conv3x3(in_channels, base)
        self.conv2 = conv3x3(base, 2 * base)
        self.conv3 = conv3x3(2 * base, 4 * base)
        self.conv4 = conv3x3(4 * base, 8 * base)
        self.conv5 = conv3x3(8 * base, 16 * base)

        # One parameter-free upsampling layer shared by all decoder stages
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: the channel count halves at every stage
        self.conv_up1 = conv3x3(16 * base, 8 * base)
        self.conv_up2 = conv3x3(8 * base, 4 * base)
        self.conv_up3 = conv3x3(4 * base, 2 * base)
        self.conv_up4 = conv3x3(2 * base, base)

        self.conv_out = nn.Conv2d(base, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an input batch to un-activated output maps."""
        # Encoder stages: conv -> ReLU -> 2x2 max pool
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)

        # Bottleneck: conv -> ReLU, no pooling
        x = F.relu(self.conv5(x))

        # Decoder stages: upsample -> conv -> ReLU
        for conv in (self.conv_up1, self.conv_up2, self.conv_up3, self.conv_up4):
            x = F.relu(conv(self.upsample(x)))

        return self.conv_out(x)
        
# Model hyper-parameters: input shape (C, H, W), first-layer filter count, output maps
params_model = dict(
    input_shape=(1, 128, 192),
    initial_filters=16,
    num_outputs=1,
)

# Build the network and move it to the selected device
model = SegNet(params_model).to(device)
# print(model)
# """
# SegNet(
#   (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (upsample): Upsample(scale_factor=2.0, mode=bilinear)
#   (conv_up1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_up2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_up3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_up4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# )
# """
# Print a per-layer summary of the model
summary(model, input_size=params_model["input_shape"])

# ----------------------------------------------------------------
#         Layer (type)               Output Shape         Param #
# ================================================================
#             Conv2d-1         [-1, 16, 128, 192]             160
#             Conv2d-2           [-1, 32, 64, 96]           4,640
#             Conv2d-3           [-1, 64, 32, 48]          18,496
#             Conv2d-4          [-1, 128, 16, 24]          73,856
#             Conv2d-5           [-1, 256, 8, 12]         295,168
#           Upsample-6          [-1, 256, 16, 24]               0
#             Conv2d-7          [-1, 128, 16, 24]         295,040
#           Upsample-8          [-1, 128, 32, 48]               0
#             Conv2d-9           [-1, 64, 32, 48]          73,792
#          Upsample-10           [-1, 64, 64, 96]               0
#            Conv2d-11           [-1, 32, 64, 96]          18,464
#          Upsample-12         [-1, 32, 128, 192]               0
#            Conv2d-13         [-1, 16, 128, 192]           4,624
#            Conv2d-14          [-1, 1, 128, 192]             145
# ================================================================
# Total params: 784,385
# Trainable params: 784,385
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.09
# Forward/backward pass size (MB): 22.88
# Params size (MB): 2.99
# Estimated Total Size (MB): 25.96
# ----------------------------------------------------------------

定义损失函数 Dice metric

Dice系数, 根据 Lee Raymond Dice命名,是一种集合相似度度量函数,通常用于计算两个样本的相似度(值范围为 [0, 1]):


dice coefficient:$$\mathrm{Dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}$$

其中 |X∩Y| 表示 X 与 Y 交集的元素个数,|X| 和 |Y| 分别表示 X 和 Y 的元素个数。分子中的系数 2 是因为分母 |X| + |Y| 把 X 与 Y 的共同元素重复计算了一次。

Dice 系数差异函数(Dice loss):


Dice loss:$$L_{\mathrm{Dice}} = 1 - \frac{2\,|X \cap Y|}{|X| + |Y|}$$
## Loss functions
# The Dice coefficient is a set-similarity measure, commonly used to compare
# two samples; its value lies in [0, 1].
# https://blog.csdn.net/JMU_Ma/article/details/97533768  , https://zhuanlan.zhihu.com/p/86704421
def dice_loss(pred, target, smooth = 1e-5):
    """Return the summed Dice loss and summed Dice coefficient of a batch.

    Args:
        pred: tensor reduced over dims 2 and 3 (assumed to be probabilities
            laid out as (N, C, H, W) -- confirm against the caller).
        target: tensor broadcastable with ``pred``, reduced the same way.
        smooth: small constant added to numerator and denominator so the
            ratio is defined when both ``pred`` and ``target`` are all zero.

    Returns:
        ``(loss_sum, dice_sum)``: the per-map values ``1 - dice`` and
        ``dice``, each summed over the remaining dimensions.
    """
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))

    # Add `smooth` once to the numerator, not twice: 2 * (I + s) / (U + s)
    # evaluates to 2 for two empty masks, which would make the loss negative.
    dice = (2.0 * intersection + smooth) / (union + smooth)
    loss = 1.0 - dice

    return loss.sum(), dice.sum()

def loss_func(pred, target):
    """Return summed BCE-with-logits plus the summed Dice loss.

    ``pred`` is passed to the BCE term as-is and through a sigmoid before
    the Dice term; ``target`` is used unchanged by both.
    """
    probs = torch.sigmoid(pred)
    dice_term = dice_loss(probs, target)[0]
    bce_term = F.binary_cross_entropy_with_logits(pred, target, reduction='sum')
    return bce_term + dice_term

模型设计及训练

定义几个计算辅助函数
# Read the current learning rate
def get_lr(opt):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns ``None`` when ``opt.param_groups`` is empty.
    """
    return next((group['lr'] for group in opt.param_groups), None)
# Evaluation metric
def metrics_batch(pred, target):
    """Return the summed Dice coefficient of ``sigmoid(pred)`` against ``target``."""
    return dice_loss(torch.sigmoid(pred), target)[1]

# Loss and metric for a single batch
def loss_batch(loss_func, output, target, opt=None):
    """Compute the loss and Dice metric of one batch, optionally stepping ``opt``.

    Args:
        loss_func: callable ``loss_func(output, target)`` returning a scalar
            tensor.
        output: model output for the batch (passed through a sigmoid before
            the metric is computed).
        target: ground-truth tensor for the batch.
        opt: optimizer; when given, gradients are zeroed, the loss is
            back-propagated and one optimizer step is taken.

    Returns:
        ``(loss, metric)`` as Python floats.
    """
    loss = loss_func(output, target)

    # The metric is for reporting only, so keep it out of the autograd graph.
    with torch.no_grad():
        pred = torch.sigmoid(output)
        _, metric_b = dice_loss(pred, target)

    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Return plain floats: a device tensor here would end up in the metric
    # history, where it cannot be plotted while it lives on the GPU.
    return loss.item(), metric_b.item()

# Loss and metric over one pass of a data loader
def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):
    """Run ``model`` over ``dataset_dl`` and return per-sample averages.

    Args:
        model: network called as ``model(xb)``.
        loss_func: loss callable forwarded to ``loss_batch``.
        dataset_dl: iterable of ``(xb, yb)`` batches.
        sanity_check: when ``True``, stop after the first batch.
        opt: optimizer forwarded to ``loss_batch``; ``None`` means no
            parameter update.

    Returns:
        ``(loss, metric)``: the accumulated batch loss and metric divided by
        the number of samples actually processed.

    Raises:
        ZeroDivisionError: if ``dataset_dl`` yields no batches.
    """
    running_loss = 0.0
    running_metric = 0.0
    # Count the samples actually seen. Dividing by len(dataset) under-reports
    # the averages when sanity_check stops after one batch or when the loader
    # drops samples.
    n_seen = 0

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)

        output = model(xb)
        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b

        if metric_b is not None:
            # float() accepts both Python numbers and one-element tensors,
            # so the running total is always a plain float.
            running_metric += float(metric_b)

        n_seen += xb.size(0)

        if sanity_check is True:
            break

    loss = running_loss / float(n_seen)
    metric = running_metric / float(n_seen)

    return loss, metric
模型定义
# Encoder-decoder segmentation model
class SegNet(nn.Module):
    """Convolutional encoder-decoder producing one output map per class.

    The encoder applies five 3x3 convolutions, the first four each followed
    by 2x2 max pooling. The decoder applies four bilinear 2x upsamplings,
    each followed by a 3x3 convolution. No activation is applied after the
    final convolution. With four poolings and four upsamplings, the output
    has the input's spatial size when height and width are multiples of 16.

    ``params`` must provide ``'input_shape'`` (a 3-tuple whose first entry is
    the input channel count), ``'initial_filters'`` and ``'num_outputs'``.
    """

    def __init__(self, params):
        super().__init__()
        in_channels, _, _ = params['input_shape']
        base = params['initial_filters']
        out_channels = params['num_outputs']

        def conv3x3(c_in, c_out):
            # 3x3 convolution that keeps the spatial size (stride 1, padding 1)
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Encoder: the channel count doubles at every stage
        self.conv1 = conv3x3(in_channels, base)
        self.conv2 = conv3x3(base, 2 * base)
        self.conv3 = conv3x3(2 * base, 4 * base)
        self.conv4 = conv3x3(4 * base, 8 * base)
        self.conv5 = conv3x3(8 * base, 16 * base)

        # One parameter-free upsampling layer shared by all decoder stages
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: the channel count halves at every stage
        self.conv_up1 = conv3x3(16 * base, 8 * base)
        self.conv_up2 = conv3x3(8 * base, 4 * base)
        self.conv_up3 = conv3x3(4 * base, 2 * base)
        self.conv_up4 = conv3x3(2 * base, base)

        self.conv_out = nn.Conv2d(base, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an input batch to un-activated output maps."""
        # Encoder stages: conv -> ReLU -> 2x2 max pool
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)

        # Bottleneck: conv -> ReLU, no pooling
        x = F.relu(self.conv5(x))

        # Decoder stages: upsample -> conv -> ReLU
        for conv in (self.conv_up1, self.conv_up2, self.conv_up3, self.conv_up4):
            x = F.relu(conv(self.upsample(x)))

        return self.conv_out(x)
        
# Model hyper-parameters: input shape (C, H, W), first-layer filter count, output maps
params_model = dict(
    input_shape=(1, 128, 192),
    initial_filters=16,
    num_outputs=1,
)

# Build the network and move it to the selected device
model = SegNet(params_model).to(device)
# print(model)
# """
# SegNet(
#   (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (upsample): Upsample(scale_factor=2.0, mode=bilinear)
#   (conv_up1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_up2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_up3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_up4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# )
# """
# Print a per-layer summary of the model
summary(model, input_size=params_model["input_shape"])

# ----------------------------------------------------------------
#         Layer (type)               Output Shape         Param #
# ================================================================
#             Conv2d-1         [-1, 16, 128, 192]             160
#             Conv2d-2           [-1, 32, 64, 96]           4,640
#             Conv2d-3           [-1, 64, 32, 48]          18,496
#             Conv2d-4          [-1, 128, 16, 24]          73,856
#             Conv2d-5           [-1, 256, 8, 12]         295,168
#           Upsample-6          [-1, 256, 16, 24]               0
#             Conv2d-7          [-1, 128, 16, 24]         295,040
#           Upsample-8          [-1, 128, 32, 48]               0
#             Conv2d-9           [-1, 64, 32, 48]          73,792
#          Upsample-10           [-1, 64, 64, 96]               0
#            Conv2d-11           [-1, 32, 64, 96]          18,464
#          Upsample-12         [-1, 32, 128, 192]               0
#            Conv2d-13         [-1, 16, 128, 192]           4,624
#            Conv2d-14          [-1, 1, 128, 192]             145
# ================================================================
# Total params: 784,385
# Trainable params: 784,385
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.09
# Forward/backward pass size (MB): 22.88
# Params size (MB): 2.99
# Estimated Total Size (MB): 25.96
# ----------------------------------------------------------------

模型训练与验证

模型训练主函数
# Main training / validation loop
def train_val(model, params):
    """Train ``model`` and keep the weights with the lowest validation loss.

    ``params`` must contain ``"num_epochs"``, ``"loss_func"``,
    ``"optimizer"``, ``"train_dl"``, ``"val_dl"``, ``"sanity_check"``,
    ``"lr_scheduler"`` and ``"path2weights"``.

    Each epoch runs one training pass and one validation pass. When the
    validation loss improves, the weights are copied in memory and saved to
    ``path2weights``. When the scheduler changes the learning rate, the best
    weights so far are loaded back into the model.

    Returns:
        ``(model, loss_history, metric_history)``: the model with the best
        weights loaded and two dicts with per-epoch ``"train"`` and ``"val"``
        lists.
    """
    # Read every setting up front, in a fixed order, so a missing key fails
    # before any training starts.
    (num_epochs, loss_func, opt, train_dl, val_dl,
     sanity_check, lr_scheduler, path2weights) = [
        params[key] for key in (
            "num_epochs", "loss_func", "optimizer", "train_dl", "val_dl",
            "sanity_check", "lr_scheduler", "path2weights",
        )
    ]

    loss_history = {phase: [] for phase in ("train", "val")}
    metric_history = {phase: [] for phase in ("train", "val")}

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

        # Training pass (the optimizer is supplied, so parameters are updated)
        model.train()
        train_loss, train_metric = loss_epoch(
            model, loss_func, train_dl, sanity_check, opt)
        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)

        # Validation pass (no optimizer, no gradients)
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(
                model, loss_func, val_dl, sanity_check)
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)

        improved = val_loss < best_loss
        if improved:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")

        lr_scheduler.step(val_loss)
        lr_changed = current_lr != get_lr(opt)
        if lr_changed:
            # Restart from the best weights whenever the learning rate changes.
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts)

        print("train loss: %.6f, accuracy: %.2f" %(train_loss, 100*train_metric))
        print("val loss: %.6f, accuracy: %.2f" %(val_loss, 100*val_metric))
        print("-"*10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
模型训练
# Optimizer and learning-rate schedule
opt = optim.Adam(model.parameters(), lr=3e-4)
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases -- confirm against the installed version.
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)

path_models = "./models/sos/"
# makedirs also creates the missing parent "./models" (os.mkdir would raise)
# and exist_ok makes re-running this cell harmless.
os.makedirs(path_models, exist_ok=True)

params_train={
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": loss_func,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": False,
    "lr_scheduler": lr_scheduler,
    "path2weights": os.path.join(path_models, "weights.pt"),
}

# Train, then keep the best model together with its loss / metric histories
model, loss_hist, metric_hist = train_val(model,params_train)

可视化结果

num_epochs = params_train["num_epochs"]
# Shared x axis: epochs numbered from 1
epochs_axis = range(1, num_epochs + 1)

# Loss curves for the training and validation sets
plt.title("Train-Val Loss")
for phase in ("train", "val"):
    plt.plot(epochs_axis, loss_hist[phase], label=phase)
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

# Metric curves for the training and validation sets
plt.title("Train-Val Accuracy")
for phase in ("train", "val"):
    plt.plot(epochs_axis, metric_hist[phase], label=phase)
plt.ylabel("Accuracy")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
Train-Val Loss.png
Train-Val Accuracy.png

部署测试

# Run the trained model on a few test images
# The network must be instantiated before its weights can be loaded; `model`
# already exists from the cells above, so it is reused here.
np.random.seed(2019)
path_test = './data/sos/test_set/'
# Keep only PNG images, matching the filter used for the training set.
imgs_list = [pp for pp in os.listdir(path_test)
             if "Annotation" not in pp and pp.endswith('.png')]

rnd_imgs = np.random.choice(imgs_list, 4)
print(rnd_imgs)

model_weights_path = './models/sos/weights.pt'
# map_location lets weights saved on a GPU load on a CPU-only machine too.
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.eval()


for fn in rnd_imgs:
    path_img = os.path.join(path_test, fn)
    img = Image.open(path_img)
    # PIL's resize takes (width, height)
    img = img.resize((w,h))
    img_t = to_tensor(img).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(img_t)
        pred = torch.sigmoid(pred)[0]
    # Threshold the first output channel at 0.5 to get a binary mask
    mask_pred = (pred[0]>=0.5)

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(img, cmap="gray")

    plt.subplot(1, 3, 2)
    plt.imshow(mask_pred.cpu(), cmap="gray")

    plt.subplot(1, 3, 3)
    show_img_mask(img, mask_pred.cpu())
test data result

你可能感兴趣的:(Pytorch之图像分割(单个目标分割,Single Object Segmentation))