SegNet 网络的 PyTorch 实现

import torch
import torch.nn as nn

class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> optional BatchNorm2d -> in-place ReLU, packaged as one module.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias, dilation:
            forwarded unchanged to ``nn.Conv2d``.
        is_batchnorm: when truthy (default), a ``BatchNorm2d(out_channels)`` is
            inserted between the convolution and the ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, dilation=1, is_batchnorm=True):
        super(conv2DBatchNormRelu, self).__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
                dilation=dilation,
            )
        ]
        if is_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # A single Sequential named ``cbr_unit`` keeps the parameter names
        # (cbr_unit.0.*, cbr_unit.1.*) identical for both configurations.
        self.cbr_unit = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run ``inputs`` through the conv / (BN) / ReLU stack."""
        return self.cbr_unit(inputs)

class segnetDown2(nn.Module):
    """SegNet encoder stage: two 3x3 conv-BN-ReLU layers, then 2x2 max-pooling.

    Besides the pooled features, the pooling indices and the pre-pool size are
    returned so a decoder stage can invert the pooling with ``nn.MaxUnpool2d``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(segnetDown2, self).__init__()
        # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **same_size)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **same_size)
        # return_indices=True: the argmax positions are needed for unpooling.
        self.maxpool_with_argmax = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, inputs):
        """Return ``(pooled_features, pooling_indices, pre_pool_size)``."""
        features = inputs
        for layer in (self.conv1, self.conv2):
            features = layer(features)
        pre_pool_size = features.size()
        pooled, argmax = self.maxpool_with_argmax(features)
        return pooled, argmax, pre_pool_size

class segnetDown3(nn.Module):
    """SegNet encoder stage: three 3x3 conv-BN-ReLU layers, then 2x2 max-pooling.

    Besides the pooled features, the pooling indices and the pre-pool size are
    returned so a decoder stage can invert the pooling with ``nn.MaxUnpool2d``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by all three convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(segnetDown3, self).__init__()
        # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **same_size)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **same_size)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, **same_size)
        # return_indices=True: the argmax positions are needed for unpooling.
        self.maxpool_with_argmax = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, inputs):
        """Return ``(pooled_features, pooling_indices, pre_pool_size)``."""
        features = inputs
        for layer in (self.conv1, self.conv2, self.conv3):
            features = layer(features)
        pre_pool_size = features.size()
        pooled, argmax = self.maxpool_with_argmax(features)
        return pooled, argmax, pre_pool_size

class segnetUp2(nn.Module):
    """SegNet decoder stage: 2x2 max-unpooling, then two 3x3 conv-BN-ReLU layers.

    Args:
        in_channels: channels of the incoming (pooled) feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(segnetUp2, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **same_size)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **same_size)

    def forward(self, inputs, indices, output_shape):
        """Unpool ``inputs`` to ``output_shape`` using ``indices``, then convolve.

        ``indices`` and ``output_shape`` are the pooling indices and pre-pool
        size produced by the corresponding encoder stage.
        """
        restored = self.unpool(inputs, indices=indices, output_size=output_shape)
        for layer in (self.conv1, self.conv2):
            restored = layer(restored)
        return restored

class segnetUp3(nn.Module):
    """SegNet decoder stage: 2x2 max-unpooling, then three 3x3 conv-BN-ReLU layers.

    Args:
        in_channels: channels of the incoming (pooled) feature map.
        out_channels: channels produced by the three convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(segnetUp3, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, **same_size)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, **same_size)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, **same_size)

    def forward(self, inputs, indices, output_shape):
        """Unpool ``inputs`` to ``output_shape`` using ``indices``, then convolve.

        ``indices`` and ``output_shape`` are the pooling indices and pre-pool
        size produced by the corresponding encoder stage.
        """
        restored = self.unpool(inputs, indices=indices, output_size=output_shape)
        for layer in (self.conv1, self.conv2, self.conv3):
            restored = layer(restored)
        return restored

class segnet(nn.Module):
    """SegNet-style encoder/decoder producing per-pixel class scores.

    Five encoder stages (64, 128, 256, 512, 512 channels) each halve the
    spatial size with a 2x2 max-pool. Each encoder stage records its pooling
    indices and pre-pool size. The mirrored decoder stage unpools with those
    values, so the output has the same height and width as the input.

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of output score maps (one per class).

    For an input of shape ``(N, in_channels, H, W)`` the forward pass returns
    ``(N, num_classes, H, W)`` raw, unnormalised class scores (logits).
    H and W must each be at least 32 so that five 2x2 poolings leave a
    non-empty feature map.
    """

    def __init__(self, in_channels=3, num_classes=21):
        super(segnet, self).__init__()
        # Encoder.
        self.down1 = segnetDown2(in_channels=in_channels, out_channels=64)
        self.down2 = segnetDown2(64, 128)
        self.down3 = segnetDown3(128, 256)
        self.down4 = segnetDown3(256, 512)
        self.down5 = segnetDown3(512, 512)

        # Decoder: mirror image of the encoder.
        self.up5 = segnetUp3(512, 512)
        self.up4 = segnetUp3(512, 256)
        self.up3 = segnetUp3(256, 128)
        self.up2 = segnetUp2(128, 64)
        self.up1 = segnetUp2(64, 64)

        # Class-score head: a plain convolution without BatchNorm/ReLU.
        # A ReLU here would clamp every class score to >= 0 and zero the
        # gradient of negative scores. That is wrong for softmax /
        # nn.CrossEntropyLoss, which expect unbounded logits.
        # NOTE: the head's state_dict keys are now ``finconv.weight`` /
        # ``finconv.bias`` instead of ``finconv.cbr_unit.*``. Checkpoints saved
        # with the old head need their keys remapped (and BN entries dropped).
        self.finconv = nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)

    def forward(self, inputs):
        """Return per-pixel class logits for ``inputs``."""
        # Encoder: keep each stage's pooling indices and pre-pool size.
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        down3, indices_3, unpool_shape3 = self.down3(down2)
        down4, indices_4, unpool_shape4 = self.down4(down3)
        down5, indices_5, unpool_shape5 = self.down5(down4)

        # Decoder: undo the poolings in reverse order.
        up5 = self.up5(down5, indices=indices_5, output_shape=unpool_shape5)
        up4 = self.up4(up5, indices=indices_4, output_shape=unpool_shape4)
        up3 = self.up3(up4, indices=indices_3, output_shape=unpool_shape3)
        up2 = self.up2(up3, indices=indices_2, output_shape=unpool_shape2)
        up1 = self.up1(up2, indices=indices_1, output_shape=unpool_shape1)

        return self.finconv(up1)

if __name__ == "__main__":
    # Smoke test: push one 224x224 RGB image through the network, print the
    # output size, then the module tree.
    model = segnet()
    # eval() + no_grad(): a shape check should neither update BatchNorm
    # running statistics nor record an autograd graph.
    model.eval()
    with torch.no_grad():
        outputs = model(torch.ones(1, 3, 224, 224))
    print(outputs.size())
    print(model)

你可能感兴趣的:(pytorch实现segnet)