Fully Convolutional Networks for Semantic Segmentation: FCN Code Walkthrough (Model Architecture)

Fully Convolutional Networks for Semantic Segmentation


Table of Contents

  • Fully Convolutional Networks for Semantic Segmentation
  • I. Data Preprocessing
    • 1. Label Processing
    • 2. Label Encoding
  • II. Model Construction
    • 1. Importing Libraries
    • 2. Model Architecture


I. Data Preprocessing

  1. Label processing
  2. Label encoding
  3. Visualizing the encoding process
  4. Defining the preprocessing class

1. Label Processing

An if check on file_path (a list) validates and stores the image and label directory paths, and a crop size is recorded for cropping the images.

The code is as follows:

class CamvidDataset(Dataset):
    def __init__(self, file_path=[], crop_size=None):
        # file_path must contain exactly two entries: [image folder, label folder].
        # If it does not, raise a ValueError; otherwise continue.
        if len(file_path) != 2:
            raise ValueError("Both the image folder path and the label folder path are required, with the image path first")
        self.img_path = file_path[0]
        self.label_path = file_path[1]
        self.imgs = self.read_file(self.img_path)      # list of image file paths
        self.labels = self.read_file(self.label_path)  # list of label file paths
        self.crop_size = crop_size

The dataset initialization extracts the image and label directory paths (image path first) and collects the file lists with read_file.
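The read_file helper called above is not shown in this section. A minimal sketch of what it might look like, assuming each folder contains only the image or label files and that sorting keeps images and labels aligned by index, is:

import os

    def read_file(self, path):
        """List all files in `path` and return their full paths, sorted so that
        images and labels line up by index (an assumed convention)."""
        files = os.listdir(path)
        file_path_list = [os.path.join(path, f) for f in files]
        file_path_list.sort()
        return file_path_list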

2. Label Encoding

A hash-style lookup builds a one-to-one or many-to-one mapping from colors to class labels.
Encoding function: (p[0]*256 + p[1])*256 + p[2]
Principle: the encoding function converts a pixel's RGB value into a single integer; that integer is used as the pixel's index into the hash table, where the corresponding class is looked up.
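The lookup code itself is not shown in this section. A minimal sketch of the idea, assuming a colormap list of (R, G, B) tuples ordered by class index (the tuples below are placeholders, not the real CamVid palette), could be:

import numpy as np

# Placeholder colormap: one (R, G, B) tuple per class, ordered by class index.
colormap = [(128, 128, 128), (128, 0, 0), (192, 192, 128)]

# Build the hash table: index = (R*256 + G)*256 + B  ->  class id
cm2lbl = np.zeros(256 ** 3, dtype=np.int64)
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

def image2label(label_img):
    """Convert an RGB label image of shape (H, W, 3) into a class-index map of shape (H, W)."""
    data = np.array(label_img, dtype=np.int64)
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return cm2lbl[idx]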

II. Model Construction

1. Importing Libraries

The code is as follows:

import numpy as np
import torch
from torchvision import models
from torch import nn

2. Model Architecture

(Figure 1: FCN model architecture)
The early convolution and pooling stages that process the input image reuse the VGG network as a pretrained backbone.
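The class below slices pretrained_net.features at indices 7, 14, 24, and 34, which matches the layer layout of torchvision's VGG16 with batch normalization. The original post does not show how pretrained_net is created; a plausible definition, assuming that variant, is:

# Assumed backbone: VGG16-BN; the feature indices used in the stages below match this variant
pretrained_net = models.vgg16_bn(pretrained=True)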
The code is as follows:

class FCN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        # Split the pretrained VGG feature extractor into five stages;
        # each stage ends with a max-pooling layer that halves the resolution.
        self.stage1 = pretrained_net.features[:7]     # 1/2 resolution, 64 channels
        self.stage2 = pretrained_net.features[7:14]   # 1/4 resolution, 128 channels
        self.stage3 = pretrained_net.features[14:24]  # 1/8 resolution, 256 channels
        self.stage4 = pretrained_net.features[24:34]  # 1/16 resolution, 512 channels
        self.stage5 = pretrained_net.features[34:]    # 1/32 resolution, 512 channels

        # 1x1 convolutions that map feature maps to per-class score maps
        self.scores1 = nn.Conv2d(512, num_classes, 1)
        self.scores2 = nn.Conv2d(512, num_classes, 1)
        self.scores3 = nn.Conv2d(128, num_classes, 1)  # defined but not used in the forward pass below

        self.conv_trans1 = nn.Conv2d(512, 256, 1)          # reduce channels before fusing with stage3
        self.conv_trans2 = nn.Conv2d(256, num_classes, 1)  # final per-class score map

        # Transposed convolutions for upsampling, initialized as bilinear interpolation
        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)
        self.upsample_8x.weight.data = bilinear_kernel(num_classes, num_classes, 16)

        self.upsample_2x_1 = nn.ConvTranspose2d(512, 512, 4, 2, 1, bias=False)
        self.upsample_2x_1.weight.data = bilinear_kernel(512, 512, 4)

        self.upsample_2x_2 = nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False)
        self.upsample_2x_2.weight.data = bilinear_kernel(256, 256, 4)

Building the FCN-8s network (the forward pass with skip connections):

    def forward(self, x):
        # VGG stages: progressively downsample the input
        s1 = self.stage1(x)   # 1/2
        s2 = self.stage2(s1)  # 1/4
        s3 = self.stage3(s2)  # 1/8
        s4 = self.stage4(s3)  # 1/16
        s5 = self.stage5(s4)  # 1/32

        scores1 = self.scores1(s5)       # score map at 1/32 (computed but not used in the output)
        s5 = self.upsample_2x_1(s5)      # upsample 1/32 -> 1/16
        add1 = s5 + s4                   # first skip connection: fuse with stage4

        scores2 = self.scores2(add1)     # score map at 1/16 (computed but not used in the output)

        add1 = self.conv_trans1(add1)    # 512 -> 256 channels
        add1 = self.upsample_2x_2(add1)  # upsample 1/16 -> 1/8
        add2 = add1 + s3                 # second skip connection: fuse with stage3

        output = self.conv_trans2(add2)     # 256 -> num_classes channels
        output = self.upsample_8x(output)   # upsample 1/8 -> full resolution
        return output
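The transposed convolutions in __init__ are initialized with a bilinear_kernel function that is not defined in this excerpt. A common sketch of bilinear-interpolation initialization, assuming square kernels and matching input/output channel counts as used here, is:

def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return an (in_channels, out_channels, k, k) weight tensor that performs
    bilinear interpolation when used as nn.ConvTranspose2d weights."""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight)

A quick shape check, assuming 12 CamVid classes and a 352x480 input (both illustrative values, divisible by 32):

net = FCN(num_classes=12)
x = torch.randn(1, 3, 352, 480)
print(net(x).shape)  # expected: torch.Size([1, 12, 352, 480])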
