Xiao Hei's Python Diary: Simple Crack Segmentation with U-Net

Hi everyone, it's me, Xiao Hei (meow).

Crack dataset

Dataset: https://github.com/cuilimeng/CrackForest-dataset
Directory layout:

  --project
    main.py
     --image
        --train
           --data
           --groundTruth
        --val
           --data
           --groundTruth

I manually reorganized the dataset into this layout, with 84 images for train and 34 for val, all saved as jpg images.
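
In case it helps, here is a minimal sketch of how such a split could be scripted. The destination folder names and the 84/34 split follow the layout above; the source-folder names and the random split are my own assumptions, not part of the original post, and the masks are assumed to have already been exported as jpg files with the same names as the images.

# split_crackforest.py (illustrative only)
import os
import random
import shutil


def split_dataset(src_img_dir, src_gt_dir, dst_root, n_train=84, seed=0):
    # copy image/mask pairs into the train/val layout shown above
    names = sorted(os.listdir(src_img_dir))
    random.Random(seed).shuffle(names)
    splits = {"train": names[:n_train], "val": names[n_train:]}
    for split, files in splits.items():
        img_dst = os.path.join(dst_root, split, "data")
        gt_dst = os.path.join(dst_root, split, "groundTruth")
        os.makedirs(img_dst, exist_ok=True)
        os.makedirs(gt_dst, exist_ok=True)
        for name in files:
            shutil.copy(os.path.join(src_img_dir, name), img_dst)
            shutil.copy(os.path.join(src_gt_dir, name), gt_dst)


# hypothetical source folders containing the jpg images and jpg masks
# split_dataset("CrackForest/image_jpg", "CrackForest/groundTruth_jpg", "image")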

Unet

Paper: http://www.arxiv.org/pdf/1505.04597.pdf
Code source: https://github.com/JavisPeng/u_net_liver
In that repository the author applies U-Net to liver segmentation. Like cracks, each image has only a single mask, so we can use the code above almost as-is.

U-Net architecture

Only dataset.py needs to be rewritten for our own dataset; the rest requires just a few small tweaks.

#dataset.py
import torch.utils.data as data
import PIL.Image as Image
import os


def make_dataset(rootdata, roottarget):  # collect the paths of each image and its mask
    imgs = []
    filename_data = os.listdir(rootdata)
    for name in filename_data:
        img = os.path.join(rootdata, name)
        mask = os.path.join(roottarget, name)
        imgs.append((img, mask))  # store each pair as an (image path, mask path) tuple
    return imgs


class MyDataset(data.Dataset):
    def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
        imgs = make_dataset(rootdata,roottarget)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path).convert('L')  # read and convert to single-channel grayscale
        img_y = Image.open(y_path).convert('L')
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
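
A quick sanity check (illustrative only, not one of the project files) to confirm that MyDataset yields matching image/mask tensor pairs; the transforms mirror the ones defined in main.py below:

# quick check (illustrative only)
from torchvision.transforms import transforms
from dataset import MyDataset

x_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
y_tf = transforms.ToTensor()

ds = MyDataset("image/train/data", "image/train/groundTruth",
               transform=x_tf, target_transform=y_tf)
img, mask = ds[0]
print(len(ds), img.shape, mask.shape)  # expect 84 samples, both tensors shaped [1, H, W]
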
#main.py
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from unet import Unet
from dataset import MyDataset

# use CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # single-channel input, so mean/std each contain one value (this fixed the earlier error)
])

# the mask only needs to be converted to a tensor
y_transforms = transforms.ToTensor()


def train_model(model, criterion, optimizer, dataload, num_epochs=10):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" %
                  (step,
                   (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
    torch.save(model.cpu().state_dict(), 'weights_%d.pth' % epoch)  # save the final weights (the model is moved back to CPU here)
    return model


# train the model
def train():
    batch_size = 1
    liver_dataset = MyDataset(
        "image/train/data", "image/train/groundTruth", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(
        liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)


# visualize the model's output on the validation set
def test():
    liver_dataset = MyDataset(
        "image/val/data", "image/val/groundTruth", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    import matplotlib.pyplot as plt
    plt.ion()
    model.eval()  # put BatchNorm/Dropout layers into inference mode
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x)
            img_y = torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()


if __name__ == '__main__':
    pretrained = False
    model = Unet(1, 1).to(device)
    if pretrained:
        model.load_state_dict(torch.load('./weights_4.pth'))
    criterion = torch.nn.BCELoss()  # assumes the Unet output has already been passed through a sigmoid
    optimizer = optim.Adam(model.parameters())
    train()
    test()
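
Once training has finished, the saved weights can also be reused for inference without calling train() again. A small sketch: the weight file name follows from the 10-epoch run above, and the image path is a made-up example.

# predict a single image with the saved weights (illustrative only)
import torch
from PIL import Image
from torchvision.transforms import transforms
from unet import Unet

model = Unet(1, 1)
model.load_state_dict(torch.load('weights_9.pth', map_location='cpu'))
model.eval()

tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
img = tf(Image.open('image/val/data/001.jpg').convert('L')).unsqueeze(0)  # add batch dimension
with torch.no_grad():
    pred = model(img)                      # shape [1, 1, H, W], values in (0, 1)
    mask = (pred > 0.5).float().squeeze()  # binarize to get the crack mask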

unet.py does not need any changes.
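
For completeness, here is a minimal sketch of a U-Net compatible with the Unet(1, 1) call in main.py. It is not the exact unet.py from the linked repository, just a plausible stand-in: it ends with a sigmoid (which the BCELoss above relies on) and assumes the input height and width are divisible by 16.

# a possible unet.py (sketch, not the repository's file)
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    # two 3x3 conv + BN + ReLU blocks, the basic U-Net building block
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv1 = DoubleConv(in_ch, 64)
        self.conv2 = DoubleConv(64, 128)
        self.conv3 = DoubleConv(128, 256)
        self.conv4 = DoubleConv(256, 512)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # encoder
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool(c1))
        c3 = self.conv3(self.pool(c2))
        c4 = self.conv4(self.pool(c3))
        c5 = self.conv5(self.pool(c4))
        # decoder: upsample, concatenate the skip connection, then double conv
        c6 = self.conv6(torch.cat([self.up6(c5), c4], dim=1))
        c7 = self.conv7(torch.cat([self.up7(c6), c3], dim=1))
        c8 = self.conv8(torch.cat([self.up8(c7), c2], dim=1))
        c9 = self.conv9(torch.cat([self.up9(c8), c1], dim=1))
        # sigmoid so the output can be fed to BCELoss directly
        return torch.sigmoid(self.conv10(c9))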

Results

After training for 10 epochs, the accumulated epoch loss comes down to roughly 3.
The first few predictions:


Top: prediction; bottom: ground truth.

For a dataset of just over 100 images, this result is acceptable.
It also ticks off something that had been sitting on my to-do list for a while.
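
To put a number on "acceptable", here is a hedged sketch of how a mean Dice score could be computed on the validation set; the threshold and the helper name are my own choices, not from the original post.

# mean Dice on the validation set (sketch)
import torch


def dice_score(pred, target, thresh=0.5, eps=1e-6):
    # Dice coefficient between a thresholded prediction and a binary mask
    pred_bin = (pred > thresh).float()
    target_bin = (target > thresh).float()
    inter = (pred_bin * target_bin).sum()
    return (2 * inter + eps) / (pred_bin.sum() + target_bin.sum() + eps)


# usage, reusing `model` and the val dataloaders from test() above:
# model.eval()
# scores = []
# with torch.no_grad():
#     for x, y in dataloaders:
#         scores.append(dice_score(model(x), y).item())
# print("mean Dice:", sum(scores) / len(scores))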

