nn.ModuleList

​ ModuleList:顾名思义,专门用于存储module的list。

参数

​ nn.ModuleList接受的必须是subModule类型,即不管ModuleList包裹了多少个列表,内嵌的所有列表的内部都要是可迭代的Module的子类 ,如:

nn.ModuleList([nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in range(nstack)])

​ 在这个例子中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

形状

​ ModuleList不是创建前后连接的网络,而是创建上下并列的网络。

实现

​ nn.ModuleList可以像python里的list一样对模型的各个层进行索引。

​ 当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 将自动被添加作为我们的网络的 parameter。这些层的参数只有被正确注册,优化器才能发现和训练这些参数!

​ nn.ModuleList没有实现forward()方法

你可能感兴趣的:(推荐,python,pytorch)