图像分割——基于pytorch的牙齿分割

图像分割是计算机视觉的基础任务之一,也是大多数深度学习入门者进阶学习的重要内容。本文以牙齿分割数据集为例,分享图像分割模型的完整训练流程。

一、引入库

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, utils

二、定义超参数

1.数据预处理和增强相关的超参数

# Preprocessing pipelines: both convert to a single-channel grayscale image
# and then to a tensor.  Augmentations (flips, color jitter, crops, blur,
# normalization) were experimented with and are intentionally disabled here.
train_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

# Validation uses the same deterministic pipeline (no augmentation).
# Note: the name `val_trasforms` (sic) is kept because later code refers to it.
val_trasforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

2.训练过程相关的超参数

# Training hyperparameters.
batch_size = 4
num_epochs = 10
# NOTE(review): `learning_rate` is never used — the optimizer below is
# constructed with lr=1e-4; unify these before tuning.
learning_rate = 0.001
# The original code also built `optim.Adam(unet.parameters(), ...)` here, but
# `unet` is not defined until section six, which raises a NameError at import
# time.  The optimizer is (re-)created right after the model is instantiated,
# so the premature line is removed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

三、定义数据集加载和预处理方式

class ycDataset(Dataset):
    """Paired image/mask dataset laid out as ``root_dir/image/*`` and ``root_dir/mask/*``.

    Each image and its mask are matched by position in the (sorted) directory
    listings, so the two folders must contain correspondingly named files.
    The same transform is applied to both the image and the mask, so the
    transform must be deterministic — random augmentations would desynchronize
    the pair.
    """

    def __init__(self, root_dir, transform=None):
        self.path1 = os.path.join(root_dir, "image")
        self.path2 = os.path.join(root_dir, "mask")
        # BUG FIX: os.listdir() order is arbitrary (filesystem-dependent), so
        # the original unsorted listings could pair an image with the wrong
        # mask.  Sorting both listings makes the pairing deterministic.
        self.img_path = sorted(os.listdir(self.path1))
        self.mask_path = sorted(os.listdir(self.path2))
        self.transform = transform

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        mask_name = self.mask_path[idx]
        img_item_path = os.path.join(self.path1, img_name)
        mask_item_path = os.path.join(self.path2, mask_name)
        # Both are loaded as RGB; the Grayscale transform collapses them to
        # one channel downstream.
        img = Image.open(img_item_path).convert('RGB')
        mask = Image.open(mask_item_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        return img, mask

    def __len__(self):
        # Number of samples is the number of files in the image folder.
        return len(self.img_path)

四、载入数据集

# Build the full dataset, then split it 80/20 into training and validation
# subsets and wrap each in a DataLoader.
dataset = ycDataset(r"./yc-train", train_transforms)
validation_size = int(len(dataset) * 0.2)
training_size = len(dataset) - validation_size
training_dataset, validation_dataset = random_split(
    dataset, [training_size, validation_size]
)
# Shuffle only the training data; validation order is kept fixed.
train_loader = DataLoader(training_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(validation_dataset, shuffle=False, batch_size=batch_size)

五、检验载入的数据

1.检查数据维度

# Sanity-check a single batch from the training loader: print the image
# tensor's shape and the raw mask tensor.
images, maskes = next(iter(train_loader))
print(f"Batch shape: {images.size()}")
print(f"Maskes: {maskes}")

图像分割——基于pytorch的牙齿分割_第1张图片

2.可视化数据样本

def imshow(img):
    """Un-normalize a tensor image and display it with matplotlib.

    NOTE(review): `plt` (matplotlib.pyplot) is never imported in this file —
    add `import matplotlib.pyplot as plt` before running this cell.
    """
    img = img / 2 + 0.5  # undo a (0.5, 0.5) normalization; harmless for [0, 1] inputs
    npimg = img.numpy()
    # matplotlib expects HWC ordering, tensors are CHW.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# BUG FIX: DataLoader iterators have no `.next()` method in Python 3 /
# current PyTorch — use the builtin next() instead.
dataiter = iter(train_loader)
images, masks = next(dataiter)
imshow(utils.make_grid(images))
print(' '.join(f'{masks[j]}' for j in range(batch_size)))

图像分割——基于pytorch的牙齿分割_第2张图片

六、定义模型结构并实例化

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.img_shape = (3,640,320)
        self.df = 64
        self.uf = 64

        self.conv1 = self.conv_block(1, self.df)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self.conv_block(self.df, self.df * 2, bn=True)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self.conv_block(self.df * 2, self.df * 4, bn=True)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv4 = self.conv_block(self.df * 4, self.df * 8, dropout_rate=0.5, bn=True)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.conv5 = self.conv_block(self.df * 8, self.df * 16, dropout_rate=0.5, bn=True)

        self.up6 = self.deconv_block(self.df * 16, self.uf * 8, bn=True)
        self.conv6 = self.conv_block(self.uf * 8 * 2, self.uf * 8)

        self.up7 = self.deconv_block(self.uf * 8, self.uf * 4, bn=True)
        self.conv7 = self.conv_block(self.uf * 4 * 2, self.uf * 4)

        self.up8 = self.deconv_block(self.uf * 4, self.uf * 2, bn=True)
        self.conv8 = self.conv_block(self.uf * 2 * 2, self.uf * 2)

        self.up9 = self.deconv_block(self.uf * 2, self.uf, bn=True)
        self.conv9 = self.conv_block(self.uf * 2, self.uf)

        self.output = nn.Conv2d(self.uf, 1, kernel_size=1, stride=1)

    def conv_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dropout_rate=0, bn=False):
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding))
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        if dropout_rate:
            layers.append(nn.Dropout(dropout_rate))
        return nn.Sequential(*layers)

    def deconv_block(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, bn=False):
        layers = []
        layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding))
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 1, 2,3)
        x = x.double()
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        up6 = self.up6(conv5)
        up6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(up6)
        up7 = self.up7(conv6)
        up7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(up7)
        up8 = self.up8(conv7)
        up8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(up8)
        up9 = self.up9(conv8)
        up9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(up9)
        output = self.output(conv9)
        output = nn.Sigmoid()(output)
        return output
# Instantiate the network, convert its parameters to float64, move it to the
# selected device, and create the Adam optimizer used for training.
unet = UNet().double().to(device)
optimizer = optim.Adam(unet.parameters(), lr=1e-4)

七、检验模型的输出情况

test_input= torch.rand((4, 1, 320, 640)).to(device)
print(test_input.shape)  
test_output = unet(test_input)
test_output

图像分割——基于pytorch的牙齿分割_第3张图片

八、加载之前训练的模型参数(可选)

unet.load_state_dict(torch.load('path/to/your/trained/model.pth'))

九、训练模型且保存训练过程

# BUG FIX: `criterion` was used below but never defined.  The network ends in
# a Sigmoid, so binary cross-entropy on probabilities is the natural choice.
criterion = nn.BCELoss()


def DiceLoss(pred, target, eps=1e-7):
    """Soft Dice loss (1 - Dice coefficient) over a batch of probability maps.

    NOTE(review): the original file called DiceLoss/IoULoss without defining
    them; these are standard implementations — replace with the project's own
    versions if they exist elsewhere.
    """
    intersection = (pred * target).sum()
    return 1 - (2 * intersection + eps) / (pred.sum() + target.sum() + eps)


def IoULoss(pred, target, eps=1e-7):
    """Soft Jaccard (IoU) loss over a batch of probability maps."""
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return 1 - (intersection + eps) / (union + eps)


# Per-epoch metric history, used by the plotting section below.
train_losses, train_dice_scores, train_iou_scores = [], [], []
val_losses, val_dice_scores, val_iou_scores = [], [], []
for epoch in range(num_epochs):
    # --- Training phase ---
    unet.train()
    total_loss, total_dice, total_iou = 0, 0, 0
    for img, Mask in train_loader:
        img, Mask = img.double().to(device), Mask.double().to(device)
        optimizer.zero_grad()
        outputs = unet(img)
        loss = criterion(outputs, Mask)
        # Dice / IoU are tracked as metrics only; gradients flow through
        # `loss` alone.
        dice = DiceLoss(outputs, Mask)
        IoU = IoULoss(outputs, Mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_dice += dice.item()
        total_iou += IoU.item()

    average_loss = total_loss / len(train_loader)
    average_dice = total_dice / len(train_loader)
    average_iou = total_iou / len(train_loader)

    train_losses.append(average_loss)
    train_dice_scores.append(average_dice)
    train_iou_scores.append(average_iou)

    # --- Validation phase (no gradients) ---
    unet.eval()
    total_val_loss, total_val_dice, total_val_iou = 0, 0, 0
    with torch.no_grad():
        for img, Mask in val_loader:
            img, Mask = img.double().to(device), Mask.double().to(device)
            val_outputs = unet(img)
            val_loss = criterion(val_outputs, Mask)
            val_dice = DiceLoss(val_outputs, Mask)
            val_IoU = IoULoss(val_outputs, Mask)
            total_val_loss += val_loss.item()
            total_val_dice += val_dice.item()
            total_val_iou += val_IoU.item()

    average_val_loss = total_val_loss / len(val_loader)
    average_val_dice = total_val_dice / len(val_loader)
    average_val_iou = total_val_iou / len(val_loader)

    val_losses.append(average_val_loss)
    val_dice_scores.append(average_val_dice)
    val_iou_scores.append(average_val_iou)

    # BUG FIX: the epoch count was hard-coded as "/10"; use num_epochs.
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {average_loss:.4f}, Train Dice: {average_dice:.4f}, Train IoU: {average_iou:.4f}, Val Loss: {average_val_loss:.4f}, Val Dice: {average_val_dice:.4f}, Val IoU: {average_val_iou:.4f}")
path = r'path/to/your/trained/model.pth'
# BUG FIX: the original saved `model.state_dict()`, but the network variable
# is named `unet` — `model` was never defined and raised a NameError.
torch.save(unet.state_dict(), path)

由于电脑配置较低所以训练及评估过程截图并未展示完全,但是代码是可以正常跑通的。

十、可视化训练过程并评估模型性能

# Plot the loss / Dice / IoU curves recorded during training.
# NOTE(review): `plt` (matplotlib.pyplot) is never imported in this file —
# add `import matplotlib.pyplot as plt` before running this cell.
# BUG FIX: the x-axis was hard-coded as range(1, 11); derive it from the
# actual number of recorded epochs so it stays correct when num_epochs changes.
epochs = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Val Loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(epochs, train_dice_scores, label='Train Dice Score')
plt.plot(epochs, val_dice_scores, label='Val Dice Score')
plt.title('Dice Score')
plt.xlabel('Epochs')
plt.ylabel('Dice Score')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(epochs, train_iou_scores, label='Train IoU Score')
plt.plot(epochs, val_iou_scores, label='Val IoU Score')
plt.title('IoU Score')
plt.xlabel('Epochs')
plt.ylabel('IoU Score')
plt.legend()

plt.tight_layout()
plt.show()

十一、定义预测函数进行预测

def predict_image(model, image, device):
    """Segment a single PIL image and return a binary uint8 mask.

    Applies the validation preprocessing, runs the model without gradients,
    and thresholds the predicted probability map at 0.5.
    """
    image = val_trasforms(image).unsqueeze(0)  # add batch dimension
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        output = model(image)
    # BUG FIX: the UNet's forward() already ends with a Sigmoid, so the output
    # is a probability in (0, 1).  The original applied .sigmoid() a second
    # time, squashing everything into (0.5, 0.73) and making the 0.5 threshold
    # classify nearly every pixel as foreground.
    pred_mask = output.squeeze().cpu().numpy()
    threshold = 0.5
    pred_mask = (pred_mask > threshold).astype(np.uint8)
    return pred_mask

你可能感兴趣的:(计算机视觉实战,pytorch,人工智能,python)