pytorch实现基于resnet的Unet

  1. resnet可以作为Unet的编码模块,只要把最后的全连接层去掉即可,其它无需改变。
  2. Unet网络的整体结构为编码模块,和解码模块。解码模块要将每个stage模块的中间输出保存下来,以便与对应的解码模块的stage相连接。
  3. Unet的创新之处在于它的解码模块,所以代码实现的时候尤其要关注每个模块的输入输出信道数,解码模块中间几个模块的处理过程相同,可以用for循环处理,而作为解码模块的输入和输出,需要单独处理。
  4. 基于resnet的Unet,就是将resnet模块嵌入到编码模块,Unet解码模块的处理都一样。
    `import torch
    import torch.nn as nn
    from torch.hub import load_state_dict_from_url

class Unet(nn.Module):
#初始化参数:Encoder,Decoder,bridge
#bridge默认值为无,如果有参数传入,则用该参数替换None
def init(self,Encoder,Decoder,bridge = None):
super(Unet,self).init()
self.encoder = Encoder(encoder_blocks)
self.decoder = Decoder(decoder_blocks)
self.bridge = bridge
def forward(self,x):
res = self.encoder(x)
out,skip = res[0],res[1,:]
if bridge is not None:
out = bridge(out)

你可能感兴趣的:(pytorch,深度学习)