Image segmentation is a fundamental computer-vision task and a natural next step for most deep-learning beginners. This post uses a tooth-segmentation dataset to walk through training an image-segmentation model.
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, utils
from torch.utils.data import DataLoader, Dataset, random_split
# NOTE: the same transform is applied to both image and mask below, so random
# augmentations (flips, crops, jitter) would desynchronize the pair if enabled.
# Add transforms.Resize((320, 640)) here if your images vary in size.
train_transforms = transforms.Compose([
    # transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(brightness=(2, 2)),
    # transforms.ColorJitter(saturation=(2, 2)),
    # transforms.ColorJitter(contrast=(2, 2)),
    # transforms.CenterCrop((620, 300)),
    # transforms.GaussianBlur(21, 10),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
    # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

val_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
batch_size = 4
num_epochs = 10
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The optimizer is created after the model is defined below; creating it here
# would fail because `unet` does not exist yet.
class ycDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.path1 = os.path.join(root_dir, "image")
        self.path2 = os.path.join(root_dir, "mask")
        # Sort both listings so image/mask pairs line up deterministically;
        # os.listdir() alone returns files in arbitrary order.
        self.img_path = sorted(os.listdir(self.path1))
        self.mask_path = sorted(os.listdir(self.path2))
        self.transform = transform

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        mask_name = self.mask_path[idx]
        img_item_path = os.path.join(self.path1, img_name)
        mask_item_path = os.path.join(self.path2, mask_name)
        img = Image.open(img_item_path).convert('RGB')
        mask = Image.open(mask_item_path).convert('L')  # masks are single-channel
        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        return img, mask

    def __len__(self):
        return len(self.img_path)
dataset = ycDataset(r"./yc-train", train_transforms)
validation_size = int(0.2 * len(dataset))
training_size = len(dataset) - validation_size
training_dataset, validation_dataset = random_split(dataset, [training_size, validation_size])
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
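Note that random_split draws a fresh random permutation on every run, so the train/validation split changes each time. If you want a repeatable split, you can pass a seeded generator, as in this sketch (the seed value is arbitrary):

generator = torch.Generator().manual_seed(42)  # arbitrary seed for a repeatable split
training_dataset, validation_dataset = random_split(
    dataset, [training_size, validation_size], generator=generator)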
# Peek at one batch to verify shapes.
for images, masks in train_loader:
    print(f"Batch shape: {images.size()}")
    print(f"Mask batch shape: {masks.size()}")
    break

def imshow(img):
    # If Normalize were enabled in the transforms, un-normalize here first.
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(train_loader)
images, masks = next(dataiter)  # dataiter.next() was removed in recent PyTorch
imshow(utils.make_grid(images))
imshow(utils.make_grid(masks))
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.img_shape = (1, 640, 320)  # grayscale input
        self.df = 64  # encoder base channel count
        self.uf = 64  # decoder base channel count
        # Encoder
        self.conv1 = self.conv_block(1, self.df)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self.conv_block(self.df, self.df * 2, bn=True)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self.conv_block(self.df * 2, self.df * 4, bn=True)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv4 = self.conv_block(self.df * 4, self.df * 8, dropout_rate=0.5, bn=True)
        self.pool4 = nn.MaxPool2d(2, 2)
        # Bottleneck
        self.conv5 = self.conv_block(self.df * 8, self.df * 16, dropout_rate=0.5, bn=True)
        # Decoder; each conv block takes 2x channels because of the skip concatenation
        self.up6 = self.deconv_block(self.df * 16, self.uf * 8, bn=True)
        self.conv6 = self.conv_block(self.uf * 8 * 2, self.uf * 8)
        self.up7 = self.deconv_block(self.uf * 8, self.uf * 4, bn=True)
        self.conv7 = self.conv_block(self.uf * 4 * 2, self.uf * 4)
        self.up8 = self.deconv_block(self.uf * 4, self.uf * 2, bn=True)
        self.conv8 = self.conv_block(self.uf * 2 * 2, self.uf * 2)
        self.up9 = self.deconv_block(self.uf * 2, self.uf, bn=True)
        self.conv9 = self.conv_block(self.uf * 2, self.uf)
        self.output = nn.Conv2d(self.uf, 1, kernel_size=1, stride=1)

    def conv_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dropout_rate=0, bn=False):
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding))
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        if dropout_rate:
            layers.append(nn.Dropout(dropout_rate))
        return nn.Sequential(*layers)

    def deconv_block(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, bn=False):
        layers = []
        layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding))
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.double()  # cast input to match the model's double-precision weights
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        up6 = self.up6(conv5)
        up6 = torch.cat([up6, conv4], dim=1)  # skip connection from conv4
        conv6 = self.conv6(up6)
        up7 = self.up7(conv6)
        up7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(up7)
        up8 = self.up8(conv7)
        up8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(up8)
        up9 = self.up9(conv8)
        up9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(up9)
        output = self.output(conv9)
        output = torch.sigmoid(output)
        return output
unet = UNet().to(device).double()
optimizer = optim.Adam(unet.parameters(), lr=learning_rate)

# Sanity-check the forward pass with a dummy grayscale batch.
test_input = torch.rand((4, 1, 320, 640)).to(device)
print(test_input.shape)
test_output = unet(test_input)
print(test_output.shape)

# Optionally resume from a previously saved checkpoint.
# unet.load_state_dict(torch.load('path/to/your/trained/model.pth'))
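The training loop below calls criterion, DiceLoss, and IoULoss, which the original listing does not define. Here is a minimal sketch of what they might look like, assuming binary cross-entropy as the training loss and soft Dice/IoU losses computed directly on the sigmoid outputs:

criterion = nn.BCELoss()  # the network output is already sigmoid-activated

def DiceLoss(pred, target, smooth=1e-6):
    # Soft Dice loss: 1 minus the Dice coefficient on flattened tensors.
    pred, target = pred.reshape(-1), target.reshape(-1)
    intersection = (pred * target).sum()
    return 1 - (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def IoULoss(pred, target, smooth=1e-6):
    # Soft IoU (Jaccard) loss: 1 minus intersection over union.
    pred, target = pred.reshape(-1), target.reshape(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return 1 - (intersection + smooth) / (union + smooth)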
train_losses, train_dice_scores, train_iou_scores = [], [], []
val_losses, val_dice_scores, val_iou_scores = [], [], []
for epoch in range(num_epochs):
    unet.train()
    total_loss, total_dice, total_iou = 0, 0, 0
    for img, mask in train_loader:
        img, mask = img.double().to(device), mask.double().to(device)
        optimizer.zero_grad()
        outputs = unet(img)
        loss = criterion(outputs, mask)
        # Metrics only; detach so they don't build an extra graph.
        dice = DiceLoss(outputs.detach(), mask)
        iou = IoULoss(outputs.detach(), mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_dice += 1 - dice.item()  # convert Dice loss back to a Dice score
        total_iou += 1 - iou.item()    # likewise for IoU
    average_loss = total_loss / len(train_loader)
    average_dice = total_dice / len(train_loader)
    average_iou = total_iou / len(train_loader)
    train_losses.append(average_loss)
    train_dice_scores.append(average_dice)
    train_iou_scores.append(average_iou)

    unet.eval()
    total_val_loss, total_val_dice, total_val_iou = 0, 0, 0
    with torch.no_grad():
        for img, mask in val_loader:
            img, mask = img.double().to(device), mask.double().to(device)
            val_outputs = unet(img)
            val_loss = criterion(val_outputs, mask)
            val_dice = DiceLoss(val_outputs, mask)
            val_iou = IoULoss(val_outputs, mask)
            total_val_loss += val_loss.item()
            total_val_dice += 1 - val_dice.item()
            total_val_iou += 1 - val_iou.item()
    average_val_loss = total_val_loss / len(val_loader)
    average_val_dice = total_val_dice / len(val_loader)
    average_val_iou = total_val_iou / len(val_loader)
    val_losses.append(average_val_loss)
    val_dice_scores.append(average_val_dice)
    val_iou_scores.append(average_val_iou)

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {average_loss:.4f}, Train Dice: {average_dice:.4f}, Train IoU: {average_iou:.4f}, "
          f"Val Loss: {average_val_loss:.4f}, Val Dice: {average_val_dice:.4f}, Val IoU: {average_val_iou:.4f}")

path = r'path/to/your/trained/model.pth'
torch.save(unet.state_dict(), path)
My machine is fairly low-spec, so the training and evaluation screenshots are not shown in full, but the code runs end to end.
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Val Loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 3, 2)
plt.plot(epochs, train_dice_scores, label='Train Dice Score')
plt.plot(epochs, val_dice_scores, label='Val Dice Score')
plt.title('Dice Score')
plt.xlabel('Epochs')
plt.ylabel('Dice Score')
plt.legend()
plt.subplot(1, 3, 3)
plt.plot(epochs, train_iou_scores, label='Train IoU Score')
plt.plot(epochs, val_iou_scores, label='Val IoU Score')
plt.title('IoU Score')
plt.xlabel('Epochs')
plt.ylabel('IoU Score')
plt.legend()
plt.tight_layout()
plt.show()
def predict_image(model, image, device):
    image = val_transforms(image).unsqueeze(0)  # add batch dimension
    image = image.double().to(device)  # match the model's double precision
    model.eval()
    with torch.no_grad():
        output = model(image)
    # forward() already applies sigmoid, so threshold the output directly.
    pred_mask = output.squeeze().cpu().numpy()
    threshold = 0.5
    pred_mask = (pred_mask > threshold).astype(np.uint8)
    return pred_mask
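A quick usage sketch for inference; the image path here is hypothetical, so point it at one of your own test images:

test_img = Image.open(r'./yc-train/image/sample.png').convert('RGB')  # hypothetical path
pred = predict_image(unet, test_img, device)
plt.imshow(pred, cmap='gray')
plt.title('Predicted mask')
plt.show()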