如何调用
点击进入Unet
encoder_name = "resnet101"
in_channels = 3
depth = encoder_depth # 默认值5
weights = "imagenet"
get_encoder
怎么实现的Encoder = encoders[name]["encoder"]
name
, 根据上一步传入的参数可以知道name = encoder_name = "restnet101"
encoders
是什么呢?可以猜测得到他是一个dict
,根据encoder_name
取值Encoder
到底是什么encoders
import
中导入resnet101
, 于是大胆推测我需要的就是第一行的resnet_encoders
resnet_encoders
到底是什么?dict
resnet101
get_encoder
里面的代码Encoder = encoders[name][encoder]
得到这个Encoder = ResNetEncoder
ResNetEncoder
源码forward
然后逐行执行
x
,这里其实就是输入的特征stages = self.get_stages()
for
循环,循环次数_depth
默认为5,也可自动传入,先做个记号A
,待会会用到这里stages
的实际长度为6,如果你输入的depth > 5
就会出错哦stages
nn.Identity()
实际就是一个输入层nn.Sequential(self.conv1, self.bn1, self.relu)
,这一层需要执行三个操作,继续看着三个操作是什么nn.Sequential(self.maxpool, self.layer1)
, 这里一看就是一个最大池化层,那self.layer1
是什么呢?还有下面的self.layer2
,self.layer3
, self.layer4
self._make_layer
,只是传入的参数不同self.layer1...self.layer4
看出来,其实主要就是block
和blocks
,根据形参和实参位置对应,得到blocks = layers[*]
, 也就是说主要的就是block
和layers
Encoder = encoders[name]["encoder"]
, 只是初始化了类,还并未实例化,更不用说调用了,那么,问题就容易解决了,去看看在哪里实例化的。直接回到初始位置get_encoder
可以发现下图params
是什么?这里也可以看见传入的之前记号Adepth
block = Bottleneck
以及layers=[3, 4, 23, 3]
Bottleneck
self._make_layer(...)
block
主要是对传入的特征进行如下操作:省略参数
Unet.encoder
是如何组成的,其实可以画个图更加直观,但是由于时间优先,这里暂时先不加,后续画完了我再添加上来,另外写这篇文章,其实主要目的是帮我自己梳理思路,边分析编写可以记得更真切。同时如果能帮助到跟我一样有迷惑的人,那就更好了。——<未完待续>