Polyp Detection with Semantic Segmentation
Model Overview
U-Net evolved from the fully convolutional network (FCN) and consists of two parts, an encoder and a decoder: the encoder extracts features from the image, and the decoder uses upsampling to classify each pixel. U-Net is widely used in medical image segmentation. Its network structure is shown in the figure below:
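To make the encoder-decoder structure concrete, below is a minimal single-level U-Net sketch in PyTorch. It is purely illustrative (the class and layer names are hypothetical, and this is not the model used in this article): the encoder halves the spatial resolution, the decoder upsamples back, and a skip connection concatenates the matching encoder features before per-pixel classification.

import torch
import torch.nn as nn

class MiniUNet(nn.Module):
    """Illustrative one-level U-Net: one downsampling step, one upsampling step."""
    def __init__(self, in_ch=3, base=16, n_classes=2):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU())
        self.down = nn.MaxPool2d(2)  # encoder: halve the resolution
        self.mid = nn.Sequential(nn.Conv2d(base, base * 2, 3, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(base * 2, base, 2, stride=2)  # decoder: double the resolution
        # after concatenating the skip connection, channels = base + base
        self.dec = nn.Sequential(nn.Conv2d(base * 2, base, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(base, n_classes, 1)  # per-pixel class scores

    def forward(self, x):
        e = self.enc(x)                         # encoder features, kept for the skip connection
        m = self.mid(self.down(e))
        d = self.up(m)
        d = self.dec(torch.cat([d, e], dim=1))  # skip connection: concatenate encoder features
        return self.head(d)

# sanity check: a 512x512 input yields per-pixel scores at the same resolution
scores = MiniUNet()(torch.zeros(1, 3, 512, 512))
print(scores.shape)  # torch.Size([1, 2, 512, 512])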
The model in this article modifies the standard U-Net by using VGG16 as the encoder for feature extraction. The network model is shown in the figure below.
Training uses the Adam optimizer, and the loss function is Dice loss:
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(), smp.utils.metrics.Accuracy()]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
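For reference, Dice loss measures the overlap between the predicted mask X and the ground-truth mask Y: loss = 1 - 2|X ∩ Y| / (|X| + |Y|). A minimal NumPy sketch of the soft variant follows; the training code uses the library's smp.utils.losses.DiceLoss, and this function is only for illustration.

import numpy as np

def soft_dice_loss(pred, target, eps=1e-7):
    """Illustrative soft Dice loss; pred and target are float arrays in [0, 1]."""
    intersection = (pred * target).sum()
    return 1.0 - (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

The complete training script is listed below.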
from utils import *
import hiddenlayer as hl
from test2 import *
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

DATA_DIR = '../input/kvasirseg'

plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can display Chinese labels
plt.rcParams['axes.unicode_minus'] = False    # render the minus sign correctly with this font
def one_hot_encode(label, label_values):
    semantic_map = []
    for colour in label_values:
        # np.equal compares each pixel's RGB value against this class's RGB value,
        # producing a per-channel boolean array.
        equality = np.equal(label, colour)
        # np.all collapses the RGB booleans into a single boolean per pixel,
        # i.e. the label mask for this class.
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    # np.stack stacks the per-class masks; the final depth equals num_classes.
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map
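A quick sanity check of one_hot_encode on a toy 2x2 RGB mask, using the same white/black palette as the polyp masks below:

import numpy as np

label = np.array([[[255, 255, 255], [0, 0, 0]],
                  [[0, 0, 0], [255, 255, 255]]])  # 2x2 RGB mask
label_values = [[255, 255, 255], [0, 0, 0]]       # polyp (white), background (black)
encoded = one_hot_encode(label, label_values)
print(encoded.shape)    # (2, 2, 2): H x W x num_classes
print(encoded[..., 0])  # [[ True False] [False  True]] -- the polyp channel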
class EndoscopyDataset(torch.utils.data.Dataset):
    def __init__(
            self,
            df,
            class_rgb_values=None,
            augmentation=None,
    ):
        self.image_paths = df['image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation

    def __getitem__(self, i):
        # read image and mask
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float32')
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # normalize the image with the dataset's per-channel mean and std
        image = image / 255
        mean = np.array([0.5588, 0.3219, 0.2358])
        image = image - mean
        std = np.array([0.3057, 0.2146, 0.1775])
        image = image / std
        # convert from HWC to CHW format
        image = np.transpose(image, (2, 0, 1)).astype('float32')
        mask = np.transpose(mask, (2, 0, 1)).astype('float32')
        return image, mask

    def __len__(self):
        return len(self.image_paths)
def visualize(**images):
    """Plot images in one row."""
    n_images = len(images)
    plt.figure(figsize=(20, 8))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        # derive the title from the keyword-argument name
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()
def get_training_augmentation():
    train_transform = [
        album.Resize(height=512, width=512, interpolation=cv2.INTER_CUBIC, always_apply=True),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(train_transform)

def get_validation_augmentation():
    # resize so that image dimensions are divisible by 32, as the encoder requires
    test_transform = [
        album.Resize(height=512, width=512, interpolation=cv2.INTER_CUBIC, always_apply=True),
    ]
    return album.Compose(test_transform)
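get_training_augmentation returns an albumentations Compose object; calling it with image= and mask= keywords yields a dict holding the transformed pair. A quick sketch with dummy data (the array shapes are arbitrary):

import numpy as np

aug = get_training_augmentation()
img = np.random.randint(0, 256, (480, 600, 3), dtype=np.uint8)  # dummy RGB image
msk = np.random.randint(0, 2, (480, 600, 2)).astype('float32')  # dummy one-hot mask
sample = aug(image=img, mask=msk)
print(sample['image'].shape, sample['mask'].shape)  # (512, 512, 3) (512, 512, 2)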
def load_data(data_path):
    metadata_df = pd.read_csv(os.path.join(data_path, 'metadata.csv'))
    metadata_df = metadata_df[['image_id', 'image_path', 'mask_path']]
    metadata_df['image_path'] = metadata_df['image_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    # shuffle the rows
    metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
    # DataFrame.sample draws a random subset of rows; with random_state=None the
    # selection differs on every call, while a fixed random_state (here 42) makes
    # the train/validation split reproducible.
    valid_df = metadata_df.sample(frac=0.1, random_state=42)
    train_df = metadata_df.drop(valid_df.index)
    return train_df, valid_df
if __name__ =="__main__":
TRAINING = False
# Get class RGB values
select_class_rgb_values = [[255, 255, 255], [0, 0, 0]]
# STD = [0.1767, 0.2144, 0.3059]
# MEAN = [0.2341, 0.3202, 0.554]
# MEAN = np.array(MEAN)
# STD = np.array(STD)
ENCODER = 'vgg16'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['polyp', 'background']
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
# load best saved model checkpoint from the current run
if os.path.exists('./my_best_model.pth'):
model = torch.load('./my_best_model.pth')
print('Loaded UNet model from previous run.')
else:
# create segmentation model with pretrained encoder
model = smp.Unet(
encoder_name=ENCODER,
encoder_weights=ENCODER_WEIGHTS,
classes=len(CLASSES),
activation=ACTIVATION,
)
print('Create UNet model.')
model.to(DEVICE)
# hl_graph = hl.build_graph(model, torch.zeros([1, 3, 512, 512]).to(DEVICE))
# hl_graph.save('./model.png', format='png')
# print(model)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(), smp.utils.metrics.Accuracy()]
# loss = smp.utils.losses.DiceLoss(ignore_channels=[1])
# metrics = [smp.utils.metrics.IoU(ignore_channels=[1])]
# define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-4)
# Set num of epochs
EPOCHS = 20
    if TRAINING:
        x = list(range(1, EPOCHS + 1))  # epoch numbers for plotting
        y_acc_valid = []
        y_acc_train = []
        y_loss_valid = []
        y_loss_train = []
        train_epoch = smp.utils.train.TrainEpoch(
            model,
            loss=loss,
            metrics=metrics,
            optimizer=optimizer,
            device=DEVICE,
            verbose=True,
        )
        valid_epoch = smp.utils.train.ValidEpoch(
            model,
            loss=loss,
            metrics=metrics,
            device=DEVICE,
            verbose=True,
        )
        best_iou_score = 0.0
        metric_name = 'iou_score'
        loss_name = 'dice_loss'
        train_logs_list, valid_logs_list = [], []
        train_df, valid_df = load_data(DATA_DIR)
        for i in range(EPOCHS):
            train_dataset = EndoscopyDataset(
                train_df,
                augmentation=get_training_augmentation(),
                class_rgb_values=select_class_rgb_values,
            )
            valid_dataset = EndoscopyDataset(
                valid_df,
                augmentation=get_validation_augmentation(),
                class_rgb_values=select_class_rgb_values,
            )
            # get train and validation data loaders
            train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=2)
            valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=2)
            # perform training & validation
            print('\nEpoch: {}'.format(i))
            train_logs = train_epoch.run(train_loader)
            valid_logs = valid_epoch.run(valid_loader)
            train_logs_list.append(train_logs)
            valid_logs_list.append(valid_logs)
            y_acc_train.append(train_logs[metric_name])
            y_acc_valid.append(valid_logs[metric_name])
            y_loss_train.append(train_logs[loss_name])
            y_loss_valid.append(valid_logs[loss_name])
            # save the model if a better validation IoU score is obtained
            if best_iou_score < valid_logs[metric_name]:
                best_iou_score = valid_logs[metric_name]
                torch.save(model, './my_best_model.pth')
                print('Model saved!')
        plt.figure()
        plt.plot(x, y_acc_train, '-ro', label='Train IoU', linestyle='--')
        plt.plot(x, y_acc_valid, '-bo', label='Validation IoU')
        plt.xticks(np.arange(1, EPOCHS + 1, 1))
        plt.yticks(np.arange(0, 1.1, 0.1))
        plt.xlabel('Epoch')
        plt.ylabel('IoU')
        # plt.ylim((0, 1))
        # for a, b in zip(x, y_acc_train):
        #     plt.text(a, b, round(b, 3), ha='center', va='bottom', fontsize=10)
        # for a, b in zip(x, y_acc_valid):
        #     plt.text(a, b, round(b, 3), ha='center', va='bottom', fontsize=10)
        # add the legend
        plt.legend(loc='lower right')
        plt.show()
        plt.figure()
        plt.plot(x, y_loss_train, '-ro', label='Train loss', linestyle='--')
        plt.plot(x, y_loss_valid, '-bo', label='Validation loss')
        plt.xticks(np.arange(1, EPOCHS + 1, 1))
        plt.yticks(np.arange(0, 1.1, 0.1))
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        # for a, b in zip(x, y_loss_train):
        #     plt.text(a, b, round(b, 3), ha='center', va='bottom', fontsize=10)
        # for a, b in zip(x, y_loss_valid):
        #     plt.text(a, b, round(b, 3), ha='center', va='bottom', fontsize=10)
        # add the legend
        plt.legend(loc='upper right')
        # plt.ylim((0, 1))
        plt.show()
    else:
        # create the test datasets: a normalized one for the model,
        # and an unaugmented one kept for visualization
        train_df, valid_df = load_data(DATA_DIR)
        test_dataset = EndoscopyDataset(
            valid_df,
            augmentation=get_validation_augmentation(),
            class_rgb_values=select_class_rgb_values,
        )
        test_dataset_vis = EndoscopyDataset(
            valid_df,
            class_rgb_values=select_class_rgb_values,
        )
        for idx in range(len(test_dataset)):
            image, gt_mask = test_dataset[idx]
            STD = [0.3057, 0.2146, 0.1775]
            MEAN = [0.5588, 0.3219, 0.2358]
            # undo the normalization so the image can be displayed
            image_vis = test_dataset_vis[idx][0]
            image_vis = np.transpose(image_vis, (1, 2, 0))
            image_vis = ((image_vis * STD + MEAN) * 255.0).astype('uint8')
            x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
            # predict the test image
            pred_mask = model(x_tensor)
            pred_mask = pred_mask.detach().squeeze().cpu().numpy()
            # convert pred_mask from CHW format to HWC format
            pred_mask = np.transpose(pred_mask, (1, 2, 0))
            # map each pixel's predicted class back to its RGB colour
            pred_mask = colour_code_segmentation(reverse_one_hot(pred_mask), select_class_rgb_values)
            # convert gt_mask from CHW format to HWC format
            gt_mask = np.transpose(gt_mask, (1, 2, 0))
            gt_mask = colour_code_segmentation(reverse_one_hot(gt_mask), select_class_rgb_values)
            visualize(
                original_image=image_vis,
                ground_truth_mask=gt_mask,
                predicted_mask=pred_mask,
            )
utils.py
import torch
from torch.utils.data import DataLoader
import albumentations as album
import numpy as np
import cv2, os
import pandas as pd

np.set_printoptions(threshold=np.inf)
# With **, keyword arguments are collected into a dict; calling getImg(pred_mask=pred_mask)
# makes `images` a dict inside the function.
def getImg(**images):
    imglist = []
    for idx, (name, image) in enumerate(images.items()):
        imglist.append(image)
        # plt.subplot(1, n_images, idx + 1)
        # print("name:", name)
        # print(type(image))
        # plt.imshow(image)
    return imglist
# Perform one-hot encoding on a label
def one_hot_encode(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes.
    # Arguments
        label: The 2D array segmentation image label
        label_values
    # Returns
        A 2D array with the same width and height as the input, but
        with a depth size of num_classes
    """
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map
# Perform reverse one-hot encoding on labels / predictions
def reverse_one_hot(image):
    """
    Transform a 2D array in one-hot format (depth is num_classes)
    to a 2D array with only 1 channel, where each pixel value is
    the classified class key.
    # Arguments
        image: The one-hot format image
    # Returns
        A 2D array with the same width and height as the input, but
        with a depth size of 1, where each pixel value is the classified
        class key.
    """
    x = np.argmax(image, axis=-1)
    return x
# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.
    # Arguments
        image: single-channel array where each value represents the class key
        label_values
    # Returns
        Colour-coded image for segmentation visualization
    """
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x
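Together, reverse_one_hot and colour_code_segmentation turn a one-hot (or per-class probability) map back into a viewable RGB mask. A quick round trip on a 1x2 toy prediction:

import numpy as np

pred = np.array([[[0.9, 0.1], [0.2, 0.8]]])  # H x W x num_classes = 1 x 2 x 2
class_keys = reverse_one_hot(pred)           # argmax over channels -> [[0 1]]
rgb = colour_code_segmentation(class_keys, [[255, 255, 255], [0, 0, 0]])
print(rgb)  # [[[255 255 255] [  0   0   0]]]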
def get_training_augmentation():
    train_transform = [
        album.Resize(height=528, width=624, interpolation=cv2.INTER_CUBIC, always_apply=True),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(train_transform)

def get_validation_augmentation():
    # resize validation images to a fixed size
    test_transform = [
        album.Resize(height=528, width=624, interpolation=cv2.INTER_CUBIC, always_apply=True),
    ]
    return album.Compose(test_transform)
def to_tensor(x, **kwargs):
    # convert from HWC to CHW format
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn=None):
    """Construct a preprocessing transform.
    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific to each pretrained neural network)
    Returns:
        transform: albumentations.Compose
    """
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(_transform)
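get_preprocessing is meant to be paired with an encoder-specific normalization function. An illustrative usage with smp's preprocessing function for the vgg16 encoder (the dummy arrays are for demonstration only):

import numpy as np
import segmentation_models_pytorch as smp

preprocessing_fn = smp.encoders.get_preprocessing_fn('vgg16', 'imagenet')
preprocess = get_preprocessing(preprocessing_fn)
img = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # dummy RGB image
msk = np.random.randint(0, 2, (512, 512, 2)).astype('float32')  # dummy one-hot mask
sample = preprocess(image=img, mask=msk)
print(sample['image'].shape, sample['mask'].shape)  # (3, 512, 512) (2, 512, 512)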
class EndoscopyDataset(torch.utils.data.Dataset):
    def __init__(self, df, class_rgb_values=None, augmentation=None, preprocessing=None):
        self.image_paths = df['image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # read image and mask
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        return len(self.image_paths)
def load_data(data_path):
    metadata_df = pd.read_csv(os.path.join(data_path, 'metadata.csv'))
    metadata_df = metadata_df[['image_id', 'image_path', 'mask_path']]
    metadata_df['image_path'] = metadata_df['image_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(data_path, img_pth))
    # shuffle the rows
    metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
    # DataFrame.sample draws a random subset of rows; a fixed random_state makes
    # the train/validation split reproducible.
    valid_df = metadata_df.sample(frac=0.1, random_state=42)
    train_df = metadata_df.drop(valid_df.index)
    return train_df, valid_df