PyTorch项目实战1-构建UNet实现道路裂纹检测

PyTorch项目实战1-构建UNet实现道路裂纹检测

  • 项目构建步骤
    • 1. 数据集下载
    • 2. 数据准备
    • 3. 代码所需库引用
    • 4. 数据读取
    • 5. Unet代码
    • 6. 训练代码
    • 7. 测试代码

项目构建步骤

本文内容参考了一篇微信公众号文章,原文地址如下:
https://mp.weixin.qq.com/s/xeUdW2l71RsHe1Zdzr5a7Q

1. 数据集下载

下载地址:
https://github.com/cuilimeng/CrackForest-dataset
PyTorch项目实战1-构建UNet实现道路裂纹检测_第1张图片
PyTorch项目实战1-构建UNet实现道路裂纹检测_第2张图片

2. 数据准备

数据集中的groundTruth标注是.mat格式的文件,训练前需要先将其转换为PNG格式的分割标签图像,转换代码如下:

#!/usr/bin/python
# coding=utf-8
from os.path import isdir
from scipy import io
import os, sys
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

if __name__ == '__main__':
    # Convert every ground-truth .mat annotation into a PNG mask image and
    # display each converted mask.
    file_path = 'CrackForest-dataset-master/CrackForest-dataset-master/groundTruth/'
    png_img_dir = 'CrackForest-dataset-master/CrackForest-dataset-master/groundTruthPngImg/'
    os.makedirs(png_img_dir, exist_ok=True)

    # Sorted so the conversion (and display) order is deterministic;
    # os.listdir() makes no ordering guarantee.
    for mat_name in sorted(os.listdir(file_path)):
        stem, ext = os.path.splitext(mat_name)
        if ext.lower() != '.mat':
            # Skip stray files (e.g. OS metadata) that loadmat cannot parse.
            continue

        image_mat = io.loadmat(os.path.join(file_path, mat_name))
        segmentation_image = image_mat['groundTruth']['Segmentation'][0]
        segmentation_image_array = np.array(segmentation_image[0])
        # NOTE(review): presumably the labels are 1 (background) and
        # 2 (crack), so (label - 1) * 255 yields a 0 / 255 mask -- confirm
        # against the dataset documentation.
        image = Image.fromarray((segmentation_image_array - 1) * 255)

        # Name the PNG after the full file stem instead of a fixed
        # three-character prefix, so names of any length stay unique.
        png_image_path = os.path.join(png_img_dir, "%s.png" % stem)
        image.save(png_image_path)

        plt.figure()
        plt.imshow(image)
        # plt.show() blocks until the window is closed; masks are shown
        # one at a time.
        plt.show()

转换后的图像分割label如下:
PyTorch项目实战1-构建UNet实现道路裂纹检测_第3张图片
PyTorch项目实战1-构建UNet实现道路裂纹检测_第4张图片

3. 代码所需库引用

#!/usr/bin/python
# coding=utf-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler, optimizer
import torchvision
import os, sys
import cv2 as cv
from torch.utils.data import DataLoader, sampler

4. 数据读取

class SegmentationDataset(object):
    """Paired grayscale image / binary mask dataset.

    Images and masks are matched by their position in the *sorted*
    directory listings, so the two directories are expected to use file
    names that sort in the same order (e.g. ``001.jpg`` / ``001.png``).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the mask images.

    Raises:
        IndexError: If ``image_dir`` holds fewer entries than ``mask_dir``.
    """

    def __init__(self, image_dir, mask_dir):
        self.images = []
        self.masks = []
        # os.listdir() returns entries in arbitrary order; without sorting,
        # pairing the two listings by index can attach the wrong mask to an
        # image.
        files = sorted(os.listdir(image_dir))
        sfiles = sorted(os.listdir(mask_dir))
        for i, mask_name in enumerate(sfiles):
            self.images.append(os.path.join(image_dir, files[i]))
            self.masks.append(os.path.join(mask_dir, mask_name))

    def __len__(self):
        """Return the number of image / mask pairs."""
        return len(self.images)

    def num_of_samples(self):
        """Return the number of image / mask pairs (same as ``len()``)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Integer index, or a tensor convertible to one.

        Returns:
            A dict with ``'image'``, a float32 tensor of shape (1, H, W)
            scaled to [0, 1], and ``'mask'``, a tensor of shape (1, H, W)
            holding 0 where the mask pixel is <= 128 and 1 where it is > 128.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = self.images[idx]
        mask_path = self.masks[idx]

        img = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        # cv.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside NumPy.
        if img is None:
            raise FileNotFoundError('cannot read image: %s' % image_path)
        if mask is None:
            raise FileNotFoundError('cannot read mask: %s' % mask_path)

        # Input image: scale to [0, 1] and add a leading channel axis.
        img = np.expand_dims(np.float32(img) / 255.0, 0)

        # Target labels: binarise to 0 / 1 class indices, keeping the dtype
        # OpenCV loaded, and add a leading channel axis.
        mask = np.expand_dims((mask > 128).astype(mask.dtype), 0)

        return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}

5. Unet代码

PyTorch项目实战1-构建UNet实现道路裂纹检测_第5张图片

class UNetModel(torch.nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Every stage is two 3x3 convolutions, each followed by batch
    normalisation and ReLU. Encoder stages are separated by 2x2 max
    pooling; decoder stages are preceded by a stride-2 transposed
    convolution and receive the matching encoder output concatenated along
    the channel axis. A final 1x1 convolution maps to ``out_features``
    channels (no activation is applied).

    Because the input is halved four times and then doubled four times, its
    height and width should be multiples of 16; otherwise the skip
    concatenations see mismatched spatial sizes.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        init_features: Channel width of the first stage; each deeper stage
            doubles it.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build a (3x3 conv -> batch norm -> ReLU) x 2 stage."""
        nn = torch.nn
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        )

    @staticmethod
    def _downsample():
        """Build the 2x2, stride-2 max pooling used between encoder stages."""
        return torch.nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _upsample(in_channels, out_channels):
        """Build the 2x2, stride-2 transposed convolution used by the decoder."""
        return torch.nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )

    def __init__(self, in_features=1, out_features=2, init_features=32):
        super(UNetModel, self).__init__()
        width = init_features

        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialised.

        # Contracting path.
        self.encode_layer1 = self._double_conv(in_features, width)
        self.pool1 = self._downsample()
        self.encode_layer2 = self._double_conv(width, width * 2)
        self.pool2 = self._downsample()
        self.encode_layer3 = self._double_conv(width * 2, width * 4)
        self.pool3 = self._downsample()
        self.encode_layer4 = self._double_conv(width * 4, width * 8)
        self.pool4 = self._downsample()

        # Bottleneck.
        self.encode_decode_layer = self._double_conv(width * 8, width * 16)

        # Expanding path. Each decode stage takes twice the channels its
        # upconv produces because the skip connection is concatenated in.
        self.upconv4 = self._upsample(width * 16, width * 8)
        self.decode_layer4 = self._double_conv(width * 16, width * 8)
        self.upconv3 = self._upsample(width * 8, width * 4)
        self.decode_layer3 = self._double_conv(width * 8, width * 4)
        self.upconv2 = self._upsample(width * 4, width * 2)
        self.decode_layer2 = self._double_conv(width * 4, width * 2)
        self.upconv1 = self._upsample(width * 2, width)
        self.decode_layer1 = self._double_conv(width * 2, width)

        # 1x1 projection to the output channels.
        self.out_layer = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=width, out_channels=out_features,
                            kernel_size=1, padding=0, stride=1),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the ``out_features``-channel map."""
        # Encoder: keep every stage output for the skip connections.
        skip1 = self.encode_layer1(x)
        skip2 = self.encode_layer2(self.pool1(skip1))
        skip3 = self.encode_layer3(self.pool2(skip2))
        skip4 = self.encode_layer4(self.pool3(skip3))

        y = self.encode_decode_layer(self.pool4(skip4))

        # Decoder: upsample, concatenate the skip on the channel axis, refine.
        y = self.decode_layer4(torch.cat((self.upconv4(y), skip4), dim=1))
        y = self.decode_layer3(torch.cat((self.upconv3(y), skip3), dim=1))
        y = self.decode_layer2(torch.cat((self.upconv2(y), skip2), dim=1))
        y = self.decode_layer1(torch.cat((self.upconv1(y), skip1), dim=1))

        return self.out_layer(y)

6. 训练代码

if __name__ == '__main__':
    # Train the U-Net on the image / mask pairs and save the final weights.
    num_epochs = 50
    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    train_on_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if train_on_gpu else 'cpu')
    unet = UNetModel().to(device)
    # To resume from a checkpoint, load its weights here, e.g.
    # unet.load_state_dict(torch.load('unet_road_model-100.pt'))

    image_dir = 'CrackForest-dataset-master/CrackForest-dataset-master/image/'
    mask_dir = 'CrackForest-dataset-master/CrackForest-dataset-master/groundTruthPngImg/'
    dataset = SegmentationDataset(image_dir, mask_dir)
    optimizer = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.9)
    # Reshuffle every epoch so SGD does not see the samples in one fixed order.
    train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    global_step = 0
    unet.train()
    for epoch in range(num_epochs):
        train_loss = 0.0
        for sample_batched in train_loader:
            images_batch = sample_batched['image'].to(device)
            target_labels = sample_batched['mask'].to(device)

            optimizer.zero_grad()

            # Forward pass: per-pixel class logits of shape (N, C, H, W).
            logits = unet(images_batch)

            # cross_entropy accepts (N, C, H, W) logits with (N, H, W) class
            # indices directly, so no manual flattening is needed and the
            # class count is not hard-coded. The mean is still taken over
            # every pixel of the batch.
            target_labels = target_labels.squeeze(1).long()
            loss = torch.nn.functional.cross_entropy(logits, target_labels)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            if global_step % 100 == 0:
                print('step: {} \tcurrent Loss: {:.6f} '.format(global_step, loss.item()))
            global_step += 1

        # Each loss value is already a per-batch mean, so the epoch average
        # divides by the number of batches (equal to the number of samples
        # only when batch_size is 1). max() guards against an empty dataset.
        train_loss = train_loss / max(len(train_loader), 1)
        print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))

    # Save the trained weights.
    unet.eval()
    torch.save(unet.state_dict(), 'unet_road_model.pt')

7. 测试代码

def test(unet):
    """Load the saved weights into ``unet`` and display its predictions.

    Reads every image in the test directory as grayscale, runs it through
    the network and shows the predicted mask (0 for class 0, 255 for any
    other class) in an OpenCV window, waiting for a key press after each
    image.

    Args:
        unet: Model instance compatible with the weights stored in
            ``unet_road_model.pt``. Its weights are overwritten by the
            checkpoint; its training / eval mode is restored on return.
    """
    # Run on whatever device the model already lives on rather than
    # assuming CUDA.
    device = next(unet.parameters()).device
    unet.load_state_dict(torch.load('unet_road_model.pt', map_location=device))
    root_dir = 'CrackForest-dataset-master/CrackForest-dataset-master/test/'

    # Inference must run in eval mode so BatchNorm uses its running
    # statistics instead of the statistics of the single test image.
    was_training = unet.training
    unet.eval()
    try:
        # no_grad avoids building an autograd graph for every test image.
        with torch.no_grad():
            for f in sorted(os.listdir(root_dir)):
                image = cv.imread(os.path.join(root_dir, f), cv.IMREAD_GRAYSCALE)
                if image is None:
                    # cv.imread returns None for files it cannot decode.
                    continue
                h, w = image.shape
                img = np.float32(image) / 255.0
                x_input = torch.from_numpy(img).view(1, 1, h, w)

                logits = unet(x_input.to(device))
                # Per-pixel class index: argmax over the channel axis.
                labels = logits.argmax(dim=1).view(h, w).cpu().numpy()
                # Map every non-background class to white for display.
                result = np.where(labels > 0, 255, 0).astype(np.uint8)

                cv.imshow("unet-segmentation-demo", result)
                cv.waitKey(0)
    finally:
        cv.destroyAllWindows()
        # Leave the model in the mode the caller had it in, so calling
        # test() in the middle of training does not disable training mode.
        unet.train(was_training)

测试结果:
PyTorch项目实战1-构建UNet实现道路裂纹检测_第6张图片
PyTorch项目实战1-构建UNet实现道路裂纹检测_第7张图片
PyTorch项目实战1-构建UNet实现道路裂纹检测_第8张图片
PyTorch项目实战1-构建UNet实现道路裂纹检测_第9张图片
备注:当训练时使用的batch大于1时,测试输出的图像全黑。检查发现,test函数中最终的分割预测结果全部为0,怀疑是将网络输出转换为分割图像的过程存在问题。后续会参考其他示例确认具体原因;如果有朋友了解这种现象,欢迎告知。

你可能感兴趣的:(PyTorch学习,PyTorch,Unet)