Collecting some frequently used but easily forgotten PyTorch knowledge points; updated continuously...
In short, children() yields only a module's direct (first-level) submodules, while modules() does a depth-first traversal and yields the module itself plus every submodule at every depth.
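As a mental model (a rough sketch of the behavior, not PyTorch's actual implementation, which additionally skips duplicate instances), modules() is essentially the module itself followed by a recursive walk over children():

def walk(module):
    # roughly what modules() does; the real implementation also
    # deduplicates module instances it has already yielded
    yield module
    for child in module.children():
        yield from walk(child)

The concrete outputs below confirm this on nested Sequentials.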
import torch.nn as nn
m = nn.Sequential(nn.Linear(2, 2),
                  nn.ReLU(),
                  nn.Sequential(nn.Sigmoid(), nn.ReLU()))
list(m.children())
[Linear(in_features=2, out_features=2, bias=True), ReLU(), Sequential(
  (0): Sigmoid()
  (1): ReLU()
)]
list(m.modules())
[Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU()
  (2): Sequential(
    (0): Sigmoid()
    (1): ReLU()
  )
), Linear(in_features=2, out_features=2, bias=True), ReLU(), Sequential(
  (0): Sigmoid()
  (1): ReLU()
), Sigmoid(), ReLU()]
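One related detail from the PyTorch docs: both children() and modules() return duplicate modules only once, so registering the same module instance twice yields it a single time. A minimal sketch (the names shared and m2 are just for illustration):

import torch.nn as nn
shared = nn.ReLU()
m2 = nn.Sequential(shared, nn.Linear(2, 2), shared)
print(len(list(m2.children())))  # 2 -- the shared ReLU appears once
print(len(list(m2.modules())))   # 3 -- the Sequential itself, ReLU, Linear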
The same pattern holds for a Sequential nested three levels deep:
import torch.nn as nn
m = nn.Sequential(nn.Linear(2, 2),
                  nn.ReLU(),
                  nn.Sequential(nn.Sequential(nn.ReLU()), nn.Sigmoid(), nn.ReLU()))
list(m.children())
[Linear(in_features=2, out_features=2, bias=True), ReLU(), Sequential(
  (0): Sequential(
    (0): ReLU()
  )
  (1): Sigmoid()
  (2): ReLU()
)]
list(m.modules())
[Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU()
  (2): Sequential(
    (0): Sequential(
      (0): ReLU()
    )
    (1): Sigmoid()
    (2): ReLU()
  )
), Linear(in_features=2, out_features=2, bias=True), ReLU(), Sequential(
  (0): Sequential(
    (0): ReLU()
  )
  (1): Sigmoid()
  (2): ReLU()
), Sequential(
  (0): ReLU()
), ReLU(), Sigmoid(), ReLU()]
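A practical rule of thumb that falls out of this: use children() to slice or rebuild a model at the block level, and modules() to match every layer of a given type at any depth. A small sketch, continuing the session above with the three-level m:

trunk = nn.Sequential(*list(m.children())[:-1])  # drop the last top-level block
n_relu = sum(1 for mod in m.modules() if isinstance(mod, nn.ReLU))
print(n_relu)  # 3: one top-level ReLU plus two inside the nested Sequential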
References:
https://discuss.pytorch.org/t/module-children-vs-module-modules/4551
https://blog.csdn.net/dss_dssssd/article/details/83958518