Notes: UNet and SegNet models in PyTorch


Semantic segmentation models built with PyTorch: UNet and SegNet. Full source:

https://github.com/piglaker/SHcrack/tree/master/Desktop/pycharm/crack/net

UNet

UNet is an encoder-decoder with skip connections: feature maps from the contracting path are center-cropped and concatenated into the expanding path. The convolutions here use no padding ("valid"), so the output is smaller than the input: a 572×572 image yields a 388×388 segmentation map, as in the original paper.


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class UNet(nn.Module):
    def __init__(self, in_channels, output_channels):
        super(UNet, self).__init__()

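        # Contracting path: each down block is two unpadded 3x3 convs (conv -> ReLU -> BN);
        # a 2x2 max pool halves the resolution between blocks.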
        self.down1 = self.down(in_channels, 64, kernel_size = 3)

        self.mxp1 = nn.MaxPool2d(kernel_size = 2)

        self.down2 = self.down(64, 128, kernel_size = 3)

        self.mxp2 = nn.MaxPool2d(kernel_size = 2)

        self.down3 = self.down(128, 256, kernel_size = 3)

        self.mxp3 = nn.MaxPool2d(kernel_size = 2)

        self.down4 = self.down(256, 512, kernel_size = 3)

        self.mxp4 = nn.MaxPool2d(kernel_size = 2)

        self.bottom = nn.Sequential(
            nn.Conv2d(in_channels = 512, out_channels = 1024, kernel_size = 3),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels = 1024, out_channels = 1024, kernel_size = 3),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.ConvTranspose2d(in_channels = 1024, out_channels = 512, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
        )

        self.up1 = self.up(1024, 512, 256)

        self.up2 = self.up(512, 256, 128)

        self.up3 = self.up(256, 128, 64)

        self.final_layer = self.final(128, 64, out_channels = output_channels)




    def down(self, in_channels, out_channels, kernel_size = 3):
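        # One encoder stage: conv -> ReLU -> BN twice. Each unpadded 3x3 conv
        # shrinks H and W by 2, so the stage shrinks them by 4 in total.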
        stage = nn.Sequential(
            nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return stage


    def up(self, in_channels, mid_channels, out_channels, kernel_size = 3):
        stage = nn.Sequential(
            nn.Conv2d(in_channels = in_channels, out_channels = mid_channels, kernel_size = kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(in_channels = mid_channels, out_channels = mid_channels, kernel_size = kernel_size),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
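            # ConvTranspose2d upsamples 2x: output = (H - 1)*stride - 2*padding + kernel_size + output_padding = 2H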
            nn.ConvTranspose2d(in_channels = mid_channels, out_channels = out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
        )
        return stage


    def final(self, in_channels, mid_channels, out_channels, kernel_size=3):
        layers = nn.Sequential(
            nn.Conv2d(kernel_size = kernel_size, in_channels = in_channels, out_channels = mid_channels),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size = kernel_size, in_channels = mid_channels, out_channels = mid_channels),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size = kernel_size, in_channels = mid_channels, out_channels = out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return layers


    def crop_and_concat(self, upsampled, bypass, crop = False):
        # Center-crop the larger encoder map to match the upsampled decoder
        # map (negative values in F.pad crop instead of padding), then
        # concatenate along the channel dimension. Assumes an even size gap.
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, [-c, -c, -c, -c])
        return torch.cat((upsampled, bypass), 1)


    def forward(self, input):

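        # Encoder: save each pre-pooling feature map for the skip connections.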
        x = self.down1(input)
        feature_map1 = x
        x = self.mxp1(x)

        x = self.down2(x)
        feature_map2 = x
        x = self.mxp2(x)

        x = self.down3(x)
        feature_map3 = x
        x = self.mxp3(x)

        x = self.down4(x)
        feature_map4 = x
        x = self.mxp4(x)

        x = self.bottom(x)

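        # Decoder: crop each saved feature map to the current size and
        # concatenate along channels before each up block.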
        x = self.crop_and_concat(x, feature_map4, True)


        x = self.up1(x)

        x = self.crop_and_concat(x, feature_map3, True)

        x = self.up2(x)

        x = self.crop_and_concat(x, feature_map2, True)

        x = self.up3(x)

        x = self.crop_and_concat(x, feature_map1, True)

        x = self.final_layer(x)

        return x



if __name__ == "__main__":
    """
    testing
    """


    model = UNet(1, 2)
    x = torch.rand(1, 1, 572, 572)
    out = model(x)
    loss = torch.sum(out)
    loss.backward()
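
For context, here is a minimal training sketch (a hypothetical setup, not part of the original notes), assuming two classes and a per-pixel cross-entropy loss. Because every convolution in the contracting path is unpadded, a 572×572 input yields 388×388 logits, so the target masks must be 388×388:

import torch
import torch.nn as nn
import torch.optim as optim

model = UNet(1, 2)
optimizer = optim.Adam(model.parameters(), lr = 1e-3)
criterion = nn.CrossEntropyLoss()

images = torch.rand(2, 1, 572, 572)            # fake grayscale batch
targets = torch.randint(0, 2, (2, 388, 388))   # per-pixel class labels sized to the 388x388 output

logits = model(images)                         # shape (2, 2, 388, 388)
loss = criterion(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()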






SegNet

SegNet pairs a VGG16-style encoder with a mirrored decoder. Instead of transposed convolutions, it upsamples with max-unpooling, reusing the pooling indices saved during the encoder's max-pooling steps.


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class SegNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SegNet, self).__init__()


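        # Encoder: VGG16-style blocks of conv -> BN -> ReLU; a 2x2 max pool follows each block.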
        self.conv11 = nn.Conv2d(in_channels, 64, kernel_size = 3, padding = 1)
        self.bn11 = nn.BatchNorm2d(64)
        self.conv12 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1)
        self.bn12 = nn.BatchNorm2d(64)
        #maxpool1
        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128)
        #maxpool2
        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256)
        #maxpooling3
        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512)
        #maxpooling4
        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512)
        #maxpooling5

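        # Decoder: mirrors the encoder, shrinking channels back from 512 to out_channels.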
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512)
        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256)
        self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64)
        self.conv11d = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, input):

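        # Encoder: return_indices=True also returns the argmax locations,
        # which the decoder reuses for unpooling.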
        x11 = F.relu(self.bn11(self.conv11(input)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)

        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)

        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)

        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)

        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)

        # Decoder: unpool with the saved indices, then conv - bn - activation
        # blocks that mirror the encoder in reverse.

        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        return x11d


if __name__ == "__main__":
    """
    testing
    """
    model = SegNet(1, 2)
    x = torch.rand(1, 1, 320, 320)
    out = model(x)
    loss = torch.sum(out)
    loss.backward()
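
Because the encoder pools five times with stride 2, the input height and width must be divisible by 32 (hence the 320×320 test input); otherwise the unpooling sizes no longer round-trip cleanly. A minimal inference sketch (hypothetical usage, not part of the original notes), where argmax over the channel dimension gives the per-pixel class:

model = SegNet(1, 2)
model.eval()
with torch.no_grad():
    image = torch.rand(1, 1, 320, 320)    # H and W divisible by 32
    logits = model(image)                 # shape (1, 2, 320, 320)
    mask = logits.argmax(dim = 1)         # shape (1, 320, 320): predicted class per pixel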








