Carvana Image Masking
- Imports
- UNet
- Setting up the multi-GPU environment
- Loading the data
- Building a custom Dataset class
- Initialization and dataset split
- Loss function and optimizer
- Evaluation metric
- Training and evaluation
Imports
import os
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
import PIL
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
UNet
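The U-Net definition itself is omitted in this section; the sketch below is a minimal stand-in (the channel widths, bilinear upsampling, and single-channel output are assumptions) so that the `unet` instance used later is defined.

# Minimal U-Net sketch: encoder-decoder with skip connections, raw logits as output.
class DoubleConv(nn.Module):
    """(Conv -> BN -> ReLU) twice."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = DoubleConv(64, 128)
        self.down2 = DoubleConv(128, 256)
        self.down3 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up3 = DoubleConv(512 + 256, 256)
        self.up2 = DoubleConv(256 + 128, 128)
        self.up1 = DoubleConv(128 + 64, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)                # 64 channels, full resolution
        x2 = self.down1(self.pool(x1))  # 128 channels, 1/2 resolution
        x3 = self.down2(self.pool(x2))  # 256 channels, 1/4 resolution
        x4 = self.down3(self.pool(x3))  # 512 channels, 1/8 resolution
        x = self.up3(torch.cat([self.upsample(x4), x3], dim=1))
        x = self.up2(torch.cat([self.upsample(x), x2], dim=1))
        x = self.up1(torch.cat([self.upsample(x), x1], dim=1))
        return self.outc(x)             # raw logits; BCEWithLogitsLoss applies the sigmoid

unet = UNet(n_channels=3, n_classes=1)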
Setting up the multi-GPU environment
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'  # expose only GPUs 4-7 to this process
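As a quick sanity check (not in the original snippet), the number of GPUs PyTorch can now see should match the list above:

print(torch.cuda.device_count())  # expected to print 4 with the setting above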
Loading the data
Building a custom Dataset class
class CarvanaDataset(Dataset):
    def __init__(self, base_dir, idx_list, mode='train', transform=None):
        self.base_dir = base_dir
        self.idx_list = idx_list
        self.images = os.listdir(os.path.join(base_dir, 'train'))
        self.masks = os.listdir(os.path.join(base_dir, 'train_masks'))
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.idx_list)

    def __getitem__(self, index):
        image_file = self.images[self.idx_list[index]]
        # the mask file shares the image name, with a '_mask.gif' suffix
        mask_file = image_file[:-4] + '_mask.gif'
        image = PIL.Image.open(os.path.join(self.base_dir, 'train', image_file))
        if self.mode == 'train':
            mask = PIL.Image.open(os.path.join(self.base_dir, 'train_masks', mask_file))
            if self.transform is not None:
                image = self.transform(image)
                mask = self.transform(mask)
                mask[mask != 0] = 1.0  # binarize the mask
            return image, mask.float()
        else:
            if self.transform is not None:
                image = self.transform(image)
            return image
Initialization and dataset split
base_dir = './Carvana/'
batch_size = 32
num_workers = 4
img_size = (256, 256)
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor()
])
train_idxs, val_idxs = train_test_split(range(len(os.listdir(base_dir + 'train_masks'))), test_size=0.3)
train_data = CarvanaDataset(base_dir=base_dir, idx_list=train_idxs, transform=transform)
val_data = CarvanaDataset(base_dir=base_dir, idx_list=val_idxs, transform=transform)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, num_workers=num_workers, shuffle=False)
image, mask = next(iter(train_loader))
print(image.shape, mask.shape)
plt.subplot(121)
plt.imshow(image[0, 0])
plt.subplot(122)
plt.imshow(mask[0, 0], cmap='gray')
The output shows the sample image and its binary mask side by side.
Loss function and optimizer
unet = nn.DataParallel(unet).cuda()  # wrap for multi-GPU training and move to GPU before building the optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(unet.parameters(), lr=1e-3, weight_decay=1e-8)
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        # flatten prediction and target before computing the overlap
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
newcriterion = DiceLoss()
unet.eval()
image, mask = next(iter(val_loader))
out_unet = unet(image.cuda())
loss = newcriterion(out_unet, mask.cuda())
print(loss)
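BCE and Dice are often combined into a single objective. The training loop below sticks to BCE alone; the weighting here is arbitrary and only illustrates how the two criteria defined above could be summed:

def bce_dice_loss(output, target, dice_weight=0.5):
    # weighted sum of the BCEWithLogitsLoss and DiceLoss instances defined above
    return criterion(output, target) + dice_weight * newcriterion(output, target)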
Evaluation metric
def dice_coeff(pred, target):
eps = 0.0001
num = pred.size(0)
m1 = pred.view(num, -1)
m2 = target.view(num, -1)
intersection = (m1 * m2).sum()
return (2. * intersection + eps) / (m1.sum() + m2.sum() + eps)
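A quick check on toy tensors (values chosen arbitrarily) confirms the behaviour: identical masks score 1, partial overlap scores less:

pred = torch.tensor([[1., 1., 0., 0.]])
target = torch.tensor([[1., 0., 0., 0.]])
print(dice_coeff(pred, target))    # (2*1 + eps) / (2 + 1 + eps) ≈ 0.6667
print(dice_coeff(target, target))  # ≈ 1.0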
Training and evaluation
def train(epoch):
unet.train()
train_loss = 0
for data, mask in train_loader:
data, mask = data.cuda(), mask.cuda()
optimizer.zero_grad()
output = unet(data)
loss = criterion(output, mask)
loss.backward()
optimizer.step()
train_loss += loss.item() * data.size(0)
train_loss = train_loss / len(train_loader.dataset)
print('Epoch:{} \t Training Loss:{:.6f}'.format(epoch, train_loss))
def val(epoch):
print('current learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
unet.eval()
val_loss = 0
dice_score = 0
with torch.no_grad():
for data, mask in val_loader:
data, mask = data.cuda(), mask.cuda()
output = unet(data)
loss = criterion(output, mask)
val_loss += loss.item() * data.size(0)
dice_score += dice_coeff(torch.sigmoid(output).cpu(), mask.cpu()) * data.size(0)
val_loss = val_loss / len(val_loader.dataset)
dice_score = dice_score / len(val_loader.dataset)
print('Epoch:{} \t Validation Loss:{:.6f}, dice score:{:.6f}'.format(epoch, val_loss, dice_score))
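The learning-rate printout in val() hints at a scheduler; one possible hookup (an assumption, not shown in this section) is ReduceLROnPlateau driven by the validation loss:

# hypothetical: halve the LR when the validation loss stops improving
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
# this would require val() to return val_loss so the epoch loop can call scheduler.step(val_loss)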
epochs = 100
for epoch in range(1, epochs + 1):
train(epoch)
val(epoch)
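After training, the weights can be saved; with DataParallel the underlying model lives in unet.module (the file name here is just an example):

torch.save(unet.module.state_dict(), './unet_carvana.pth')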