2018-04-05

Polyp Detection with Semantic Segmentation

Model Overview

U-Net evolved from the fully convolutional network (FCN) and consists of two parts: an encoder that extracts features from the image, and a decoder that upsamples those features back to full resolution and classifies each pixel. Skip connections pass encoder features to the decoder at matching resolutions, which is why U-Net is widely used in medical image segmentation. Its architecture is shown below:


[Figure: U-Net architecture]
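To make the encoder-decoder idea concrete, here is a minimal PyTorch sketch (an illustration only, not the model used in this post) of one U-Net stage: a downsampling block, an upsampling block, and the skip connection that concatenates encoder features into the decoder.

    import torch
    import torch.nn as nn

    class UNetStage(nn.Module):
        # one encoder/decoder stage pair joined by a skip connection
        def __init__(self, ch):
            super().__init__()
            self.down = nn.Sequential(
                nn.Conv2d(ch, 2 * ch, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
            self.up = nn.ConvTranspose2d(2 * ch, ch, 2, stride=2)
            self.fuse = nn.Conv2d(2 * ch, ch, 3, padding=1)  # applied after concatenation

        def forward(self, x):
            skip = x                      # encoder features saved for the decoder
            x = self.down(x)              # halve resolution, double channels
            x = self.up(x)                # upsample back to the skip's resolution
            x = torch.cat([x, skip], 1)   # skip connection: concatenate along channels
            return self.fuse(x)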

The model in this post modifies the U-Net architecture by using VGG16 as the feature-extraction encoder. The network is shown below:


[Figure: modified network architecture (U-Net with VGG16 encoder)]

Training uses the Adam optimizer, with Dice loss as the loss function:
    loss = smp.utils.losses.DiceLoss()
    metrics = [smp.utils.metrics.IoU(), smp.utils.metrics.Accuracy()]
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
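For reference, Dice loss measures the overlap between the predicted and ground-truth masks. A minimal sketch of the idea (smp's DiceLoss adds activation and channel handling on top of this):

    import torch

    def dice_loss(pred, target, eps=1.0):
        # pred, target: (N, C, H, W) tensors with values in [0, 1]
        inter = (pred * target).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
        dice = (2 * inter + eps) / (union + eps)  # per-image, per-channel Dice score
        return 1 - dice.mean()                    # loss = 1 - mean Dice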
The complete training and evaluation script:

from utils import *
from test2 import *
import hiddenlayer as hl
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render CJK labels
plt.rcParams['axes.unicode_minus'] = False    # keep minus signs rendering correctly

DATA_DIR = '../input/kvasirseg'

def one_hot_encode(label, label_values):
    semantic_map = []
    for colour in label_values:
        # np.equal compares each pixel's RGB value against this class's RGB value element-wise
        equality = np.equal(label, colour)
        # np.all collapses the per-channel booleans to one boolean per pixel: this class's mask
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    # np.stack stacks all class masks; the final depth equals num_classes
    semantic_map = np.stack(semantic_map, axis=-1)

    return semantic_map
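# a tiny worked example of the encoding (toy values, not real data):
#   toy = np.array([[[255, 255, 255], [0, 0, 0]],
#                   [[0, 0, 0], [0, 0, 0]]])          # 2x2 RGB mask, top-left pixel = polyp
#   one_hot_encode(toy, [[255, 255, 255], [0, 0, 0]]).shape  -> (2, 2, 2)
#   # channel 0 is the polyp mask: [[True, False], [False, False]]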

class EndoscopyDataset(torch.utils.data.Dataset):
    def __init__(
            self,
            df,
            class_rgb_values=None,
            augmentation=None,
    ):
        self.image_paths = df['image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation

    def __getitem__(self, i):

        # read image and mask (OpenCV loads BGR; convert to RGB)
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        # one-hot-encode the mask before augmentation so flips apply to every channel
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float32')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # normalize with per-channel dataset statistics: (x / 255 - mean) / std
        image = image / 255

        mean = np.array([0.5588, 0.3219, 0.2358])
        image = image - mean

        std = np.array([0.3057, 0.2146, 0.1775])
        image = image / std

        # HWC -> CHW, as PyTorch expects
        image = np.transpose(image, (2, 0, 1)).astype('float32')
        mask = np.transpose(mask, (2, 0, 1)).astype('float32')

        return image, mask

    def __len__(self):
        # number of samples in the dataset
        return len(self.image_paths)

def visualize(**images):
    """
    Plot images in one row
    """
    n_images = len(images)
    plt.figure(figsize=(20, 8))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        # get title from the parameter names
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()

def get_training_augmentation():
    train_transform = [
        album.Resize(height=512, width=512, interpolation=cv2.INTER_CUBIC, always_apply=True),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(train_transform)


def get_validation_augmentation():
    # resize to 512x512, which is divisible by 32 as the U-Net encoder requires
    test_transform = [
        album.Resize(height=512, width=512, interpolation=cv2.INTER_CUBIC, always_apply=True),
    ]
    return album.Compose(test_transform)

def load_data(data_path):
    metadata_df = pd.read_csv(os.path.join(data_path, 'metadata.csv'))
    metadata_df = metadata_df[['image_id', 'image_path', 'mask_path']]
    metadata_df['image_path'] = metadata_df['image_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    # shuffle with a fixed seed so the 90/10 split is reproducible across runs
    # (an unseeded shuffle here would hold out different images on every run)
    metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)
    # hold out 10% of the rows for validation
    valid_df = metadata_df.sample(frac=0.1, random_state=42)
    train_df = metadata_df.drop(valid_df.index)
    return train_df, valid_df
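# quick sanity check of the pipeline so far (assuming the Kvasir-SEG metadata file is in place):
#   train_df, valid_df = load_data(DATA_DIR)
#   ds = EndoscopyDataset(valid_df,
#                         class_rgb_values=[[255, 255, 255], [0, 0, 0]],
#                         augmentation=get_validation_augmentation())
#   image, mask = ds[0]
#   image.shape, mask.shape  ->  (3, 512, 512), (2, 512, 512)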


if __name__ =="__main__":
    # set True to train; False loads the saved model and visualizes its predictions
    TRAINING = False

    # class RGB values in the masks: white = polyp, black = background
    select_class_rgb_values = [[255, 255, 255], [0, 0, 0]]

    ENCODER = 'vgg16'
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['polyp', 'background']
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation

    # load best saved model checkpoint from the current run
    if os.path.exists('./my_best_model.pth'):
        model = torch.load('./my_best_model.pth')
        print('Loaded UNet model from previous run.')
    else:
        # create segmentation model with pretrained encoder
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=len(CLASSES),
            activation=ACTIVATION,
        )
        print('Create UNet model.')
    model.to(DEVICE)

    # hl_graph = hl.build_graph(model, torch.zeros([1, 3, 512, 512]).to(DEVICE))
    # hl_graph.save('./model.png', format='png')

    # print(model)
    # smp's preprocessing function for this encoder (not used below; the dataset normalizes manually)
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    loss = smp.utils.losses.DiceLoss()
    metrics = [smp.utils.metrics.IoU(), smp.utils.metrics.Accuracy()]

    # loss = smp.utils.losses.DiceLoss(ignore_channels=[1])
    # metrics = [smp.utils.metrics.IoU(ignore_channels=[1])]
    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    # optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-4)

    # Set num of epochs
    EPOCHS = 20

    if TRAINING:



        # epoch indices and per-epoch metric curves for plotting
        x = list(range(1, EPOCHS + 1))
        y_acc_train, y_acc_valid = [], []
        y_loss_train, y_loss_valid = [], []

        train_epoch = smp.utils.train.TrainEpoch(
            model,
            loss=loss,
            metrics=metrics,
            optimizer=optimizer,
            device=DEVICE,
            verbose=True,
        )

        valid_epoch = smp.utils.train.ValidEpoch(
            model,
            loss=loss,
            metrics=metrics,
            device=DEVICE,
            verbose=True,
        )

        best_iou_score = 0.0
        metric_name = 'iou_score'
        loss_name = 'dice_loss'
        train_logs_list, valid_logs_list = [], []
        train_df, valid_df = load_data(DATA_DIR)

        # the datasets and loaders do not change between epochs, so build them once
        train_dataset = EndoscopyDataset(
            train_df,
            augmentation=get_training_augmentation(),
            class_rgb_values=select_class_rgb_values,
        )
        valid_dataset = EndoscopyDataset(
            valid_df,
            augmentation=get_validation_augmentation(),
            class_rgb_values=select_class_rgb_values,
        )
        # Get train and val data loaders
        train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=2)
        valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=2)

        for i in range(EPOCHS):

            # Perform training & validation
            print('\nEpoch: {}/{}'.format(i + 1, EPOCHS))
            train_logs = train_epoch.run(train_loader)
            valid_logs = valid_epoch.run(valid_loader)
            train_logs_list.append(train_logs)
            valid_logs_list.append(valid_logs)

            y_acc_train.append(train_logs[metric_name])
            y_acc_valid.append(valid_logs[metric_name])

            y_loss_train.append(train_logs[loss_name])
            y_loss_valid.append(valid_logs[loss_name])

            # Save model if a better val IoU score is obtained
            if best_iou_score < valid_logs[metric_name]:
                best_iou_score = valid_logs[metric_name]
                torch.save(model, './my_best_model.pth')
                print('Model saved!')

        plt.figure()
        plt.plot(x, y_acc_train, '-ro', label="Train IoU", linestyle='--')
        plt.plot(x, y_acc_valid, '-bo', label="Validation IoU")
        plt.xticks(np.arange(1, EPOCHS + 1, 1))
        plt.yticks(np.arange(0, 1.1, 0.1))
        plt.xlabel("Epoch")
        plt.ylabel("IoU")

        # add legend
        plt.legend(loc='lower right')
        plt.show()

        plt.figure()
        plt.plot(x, y_loss_train, '-ro', label="Train loss", linestyle='--')
        plt.plot(x, y_loss_valid, '-bo', label="Validation loss")
        plt.xticks(np.arange(1, EPOCHS + 1, 1))
        plt.yticks(np.arange(0, 1.1, 0.1))
        plt.xlabel("Epoch")
        plt.ylabel("Loss")

        # add legend
        plt.legend(loc='upper right')
        plt.show()
    else:
            # build the normalized test dataset plus a raw copy for visualization
            train_df, valid_df = load_data(DATA_DIR)
            test_dataset = EndoscopyDataset(
                valid_df,
                augmentation=get_validation_augmentation(),
                class_rgb_values=select_class_rgb_values,
            )
            test_dataset_vis = EndoscopyDataset(
                valid_df,
                class_rgb_values=select_class_rgb_values,
            )
            for idx in range(len(test_dataset)):
                image, gt_mask = test_dataset[idx]

                # un-normalize the visualization copy back to uint8 RGB
                MEAN = np.array([0.5588, 0.3219, 0.2358])
                STD = np.array([0.3057, 0.2146, 0.1775])
                image_vis = np.transpose(test_dataset_vis[idx][0], (1, 2, 0))
                image_vis = ((image_vis * STD + MEAN) * 255.0).astype('uint8')

                x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)

                # predict the test image (no gradient tracking needed at inference)
                with torch.no_grad():
                    pred_mask = model(x_tensor)
                pred_mask = pred_mask.squeeze().cpu().numpy()

                # convert pred_mask from `CHW` format to `HWC` format
                pred_mask = np.transpose(pred_mask, (1, 2, 0))
                # collapse the one-hot prediction and colour-code it for display
                pred_mask = colour_code_segmentation(reverse_one_hot(pred_mask), select_class_rgb_values)

                # Convert gt_mask from `CHW` format to `HWC` format
                gt_mask = np.transpose(gt_mask, (1, 2, 0))
                gt_mask = colour_code_segmentation(reverse_one_hot(gt_mask), select_class_rgb_values)
                visualize(
                    original_image=image_vis,
                    ground_truth_mask=gt_mask,
                    predicted_mask=pred_mask,
                )

utils.py

import os

import albumentations as album
import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

np.set_printoptions(threshold=np.inf)


# With **, keyword arguments arrive as a dict (e.g. getImg(pred_mask=pred_mask)),
# so `images` here is a dict mapping names to images.
def getImg(**images):
    return list(images.values())


# Perform one hot encoding on label
def one_hot_encode(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes
    # Arguments
        label: the HxWx3 RGB segmentation label image
        label_values: list of RGB values, one per class

    # Returns
        An array with the same height and width as the input, but
        with a depth of num_classes (one boolean mask per class)
    """
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)

    return semantic_map


# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """
    Transform an HxWxnum_classes one-hot array into a single-channel
    HxW array where each pixel value is the class index.
    # Arguments
        image: the one-hot format image

    # Returns
        An HxW array where each pixel value is the class index
        (the argmax over the class dimension).
    """
    x = np.argmax(image, axis=-1)
    return x


# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.
    # Arguments
        image: single channel array where each value represents the class key.
        label_values

    # Returns
        Colour coded image for segmentation visualization
    """
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]

    return x
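# a small round-trip check tying these helpers together (toy values, not real data):
#   onehot = one_hot_encode(np.array([[[255, 255, 255], [0, 0, 0]]]),
#                           [[255, 255, 255], [0, 0, 0]])
#   idx = reverse_one_hot(onehot)                                  # [[0, 1]]: class indices
#   colour_code_segmentation(idx, [[255, 255, 255], [0, 0, 0]])    # back to the original colours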


def get_training_augmentation():
    train_transform = [
        album.Resize(height=528, width=624, interpolation=cv2.INTER_CUBIC, always_apply=True),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(train_transform)


def get_validation_augmentation():
    # resize to a fixed size (note: 528x624 is not divisible by 32; the main script
    # overrides these helpers with 512x512 versions, as the U-Net encoder requires)
    test_transform = [
        album.Resize(height=528, width=624, interpolation=cv2.INTER_CUBIC, always_apply=True),
    ]
    return album.Compose(test_transform)


def to_tensor(x, **kwargs):
    # HWC -> CHW, float32, as PyTorch expects
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Construct preprocessing transform
    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    """
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))

    return album.Compose(_transform)
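# get_preprocessing is not exercised by the script above (which normalizes manually);
# a sketch of how it would plug into the dataset, assuming
# `import segmentation_models_pytorch as smp` and a train_df from load_data(...):
#   preprocessing_fn = smp.encoders.get_preprocessing_fn('vgg16', 'imagenet')
#   dataset = EndoscopyDataset(
#       train_df,
#       class_rgb_values=[[255, 255, 255], [0, 0, 0]],
#       augmentation=get_training_augmentation(),
#       preprocessing=get_preprocessing(preprocessing_fn),  # encoder normalization + CHW tensors
#   )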


class EndoscopyDataset(torch.utils.data.Dataset):
    def __init__(self, df, class_rgb_values=None, augmentation=None, preprocessing=None):
        self.image_paths = df['image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):

        # read images and masks
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)

        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        # number of samples in the dataset
        return len(self.image_paths)

def load_data(data_path):
    metadata_df = pd.read_csv(os.path.join(data_path, 'metadata.csv'))
    metadata_df = metadata_df[['image_id', 'image_path', 'mask_path']]
    metadata_df['image_path'] = metadata_df['image_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    # shuffle with a fixed seed so the 90/10 split is reproducible across runs
    metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)
    valid_df = metadata_df.sample(frac=0.1, random_state=42)
    train_df = metadata_df.drop(valid_df.index)
    return train_df, valid_df
