torch里面有连个build网络的基本概念,一个是modules,另一个则是containers。
modules是一些可以简单执行前传和后传的层
而containers则是将modules结合在一起的一个模块
1、nngraph建立一个新的网络将重点放在了网络的结构图上面,nngraph重载了call 运算符,也即(),这个重载的运算符可以应用在任何的module上面,避免了container的使用。通过执行call操作,会返回一个nngraph包含一个nn.Module这样的一个模块的node
call操作接受父节点作为输入参数,指定哪一个作为前传
h1 = nn.Linear(20, 10)()
h2 = nn.Linear(10, 1)(h1)
mlp = nn.gModule({h1},{h2})
上面的结构h1是第一层,h2是第二层,最后用nn.gModule模块接受{h1},{h2}两个参数,{h1}代表的是nn的输入node的一个table,{h2}代表的是nn的输入node的一个table
2、分析一个复杂点的结构
m = nn.Sequential()
m:add(nn.SplitTable(1))
m:add(nn.ParallelTable():add(nn.Linear(10, 20)):add(nn.Linear(10, 30)))
input = nn.Identity()()
input1, input2 = m(input):split(2) --sequential模块的输出是input1,input2,这里的split是对m的输出进行操作
m3 = nn.JoinTable(1)({input1, input2})
g = nn.gModule({input}, {m3})
indata = torch.rand(2, 10)
gdata = torch.rand(50)
g:forward(indata)
g:backward(indata, gdata)
mapindex是父节点,module是模块的名字,input是输入,output是输出
3、
nngraph还有一些别的性质,可以利用graphviz的性质为整个输出的图加上特殊的可视化的性质
4、
此外,nngraph还能在debug模式下测试网络是否是有错误的。参考官网的例子即可
-- draw graph (the forward graph, '.fg'), use it with itorch notebook
graph.dot(model.fg, 'MLP')
-- or save graph to file MLP.svg and MLP.dot
graph.dot(model.fg, 'MLP', 'MLP')
graph.dot两行的区别在于第一行需要显示,如果qtlua不成功就会失败,第二行只需要保存生成的svg和dot文件即可
5、既然nngraph代表的是container,那么他就可以调用container的get(),size()等函数
self.model:get(3).output
获得网络中第三个模块的输出