The nngraph package is extremely useful for building more complex networks; in a sense it gives you something like a "static graph". Put simply, where you used to keep calling add() to stack layers, with nngraph you just keep chaining nodes with "-".
require 'nngraph'

h1 = - nn.Linear(20,10)
h2 = h1
     - nn.Tanh()
     - nn.Linear(10,10)
     - nn.Tanh()
     - nn.Linear(10, 1)
mlp = nn.gModule({h1}, {h2})
Notes:
1. At the very start, a node is created with a leading "-".
2. nn.gModule takes two tables: the first holds the input nodes, the second the output nodes. Both tables can of course contain multiple entries, but note that every entry must be a node (nngraph.Node), nothing else.
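For completeness, here is a minimal forward pass through the mlp just defined (the random 20-dimensional input is only an assumed example):
local x = torch.rand(20)    -- nn.Linear(20,10) expects a 20-dimensional input
local y = mlp:forward(x)    -- a 1-element output tensor
print(y)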
Take the U-Net architecture as an example:
function defineG_unet(input_nc, output_nc, ngf)
local netG = nil
-- input is (nc) x 256 x 256
-- the first node is created with a leading "-"
local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
-- input is (ngf) x 128 x 128
local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 8 x 8
local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 4 x 4
local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 2 x 2
local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 1 x 1
local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 2 x 2
local d1 = {d1_,e7} - nn.JoinTable(2)
local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 4 x 4
local d2 = {d2_,e6} - nn.JoinTable(2)
local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 8 x 8
local d3 = {d3_,e5} - nn.JoinTable(2)
local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local d4 = {d4_,e4} - nn.JoinTable(2)
local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local d5 = {d5_,e3} - nn.JoinTable(2)
local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local d6 = {d6_,e2} - nn.JoinTable(2)
local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf)
-- input is (ngf) x128 x 128
local d7 = {d7_,e1} - nn.JoinTable(2)
local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
-- input is (nc) x 256 x 256
local o1 = d8 - nn.Tanh()
-- there is exactly one input node: e1
-- and exactly one output node: o1
netG = nn.gModule({e1},{o1})
return netG
end
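As a quick sanity check, the generator can be instantiated and run like this (the channel counts, ngf = 64, and the single 3 x 256 x 256 input are assumed example values, roughly following pix2pix):
local netG = defineG_unet(3, 3, 64)
local x = torch.rand(1, 3, 256, 256)   -- a batch of one 3-channel 256x256 image
local y = netG:forward(x)
print(y:size())                        -- 1 x 3 x 256 x 256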
Clean and concise. If you would rather build a U-Net the ordinary container-based way, there is an example on GitHub:
https://github.com/dmarnerides/dlt/blob/master/src/models/unet.lua
Note:
Previously, networks were built out of nn.Module objects. With nngraph, the network is built out of nngraph.Node objects, and the resulting network has type nn.gModule, which inherits from nn.Container, itself a subclass of nn.Module. An nn.Module can be turned into an nngraph.Node through the "-" operator mentioned above.
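A quick sketch of this type relationship (the layer sizes here are arbitrary):
local m = nn.Linear(10, 5)
print(torch.type(m))                    -- nn.Linear: an ordinary nn.Module
local node = - m                        -- "-" wraps the module into a graph node
print(torch.type(node))                 -- nngraph.Node
local g = nn.gModule({node}, {node - nn.Tanh()})
print(torch.type(g))                    -- nn.gModule
print(torch.isTypeOf(g, 'nn.Module'))   -- true: gModule is a Container, hence a Module
The next example builds a graph with two inputs and two outputs: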
h1 = - nn.Linear(20,20)
h2 = - nn.Linear(10,10)
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )
local input1 = torch.rand(20)
local input2 = torch.rand(10)
local out = gmod:forward({input1, input2})
local out1 = out[1]
local out2 = out[2]
-- or
local unpack = unpack or table.unpack
local out1, out2 = unpack(out)
-- backward pass: note that it takes both the inputs and the output gradients
local grad1 = torch.ones(1)
local grad2 = torch.ones(1)
gmod:backward({input1, input2}, {grad1, grad2})
The input nodes, however, must be built from standard nn layers. If an input node wraps a custom layer, constructing the gModule fails, for example:
h1 = - nn.Linear(20,20)
--this is wrong: input nodes must use built-in nn.Module layers; a custom layer
-- like this (one that derives from nn.Module) makes the construction fail.
h2 = - myLayer
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )
Constructing this gModule fails with:
Expected nnop.Parameters node, found : nn.MyLayer
Intermediate nodes, on the other hand, can be custom layers.
A gModule can have multiple inputs and multiple outputs; the modules used to build it, however, must not form a cycle.
A node can have several inputs; the ordering of those inputs is stored in node.data.mapindex, which essentially holds references to the node's parent nodes. If the module inside a node expects a single input, several streams can be merged first, e.g. with nn.JoinTable(2), and the joined result fed into one node; an output, likewise, can be a table holding the various output nodes. In addition, node.data.input is a list holding all of the node's inputs; when there is only one input, only node.data.input[1] is used. Also worth noting: node.data.gradOutput sums up all gradients flowing back into the node before passing them further back.
One more point: the first and last nodes of the network are "dummy" nodes. Because there may be several inputs and several outputs, these dummy nodes are needed to handle them. For multiple outputs, the final dummy node internally uses a split operation that partitions the network into blocks, one per output node (and these blocks may well overlap). Interestingly, with multiple inputs and outputs you are not required to fill everything in during the backward pass; if you only supply the first entry, the network behaves as if it had been truncated to the part that depends on that first entry, and only that part gets updated.
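A small sketch (the modules and sizes here are only for illustration) that pokes at these node.data fields after a forward pass:
local a = - nn.Identity()
local b = - nn.Identity()
local s = {a, b} - nn.CAddTable()       -- a node with two parent nodes
local g = nn.gModule({a, b}, {s})
g:forward({torch.ones(3), torch.ones(3):mul(2)})
for _, node in ipairs(g.forwardnodes) do
  if node.data.module and torch.type(node.data.module) == 'nn.CAddTable' then
    print(#node.data.input)             -- 2: one entry per parent node
    print(#node.data.mapindex)          -- 2: the ordering of the parents
  end
end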
Each node contains one module. If all you want is to look at a module in isolation, net.modules[i] is enough; net.forwardnodes, on the other hand, gives access to the information stored on each node.
Once you have a node, the two things you can mainly read off are the node's input and its gradOutput:
node.data.input
node.data.gradOutput
You might ask: what about gradInput? A node has no gradInput of its own. Each node first sums up its incoming gradOutput contributions to obtain the node's total gradOutput, passes that into the module inside the node to compute a gradInput, and then hands this gradInput on as (part of) the gradOutput of the node's child nodes.
Since every node stores a module, node.data.module gives access to all the information about the module stored in that node.
--this grabs the output of what corresponds to net.modules[9]
--the benefit of going through the nodes is that you get the node's information,
--not just the module inside the node in isolation
local ind = 10
local latent = nil
for indexNode, node in pairs(net.forwardnodes) do
  if indexNode == ind then
    if node.data.module then
      latent = node.data.module.output:clone() -- use it to get the specific module output
    end
  end
end
Rough outline of what gModule:__init does:
1. The inputs and outputs are checked first: every entry must be of type nngraph.Node. This is exactly what the "-" mentioned above does, turning an nn.Module into an nngraph.Node.
2. The inputs are checked again: with a single input it is wired up directly via inputs[1]:add(innode,true); with multiple inputs, each input node is required to have no child nodes of its own.
3. Two graphs are then built: fg and bg. fg is used for the forward evaluation of the network, bg for backpropagation.
4. If the forward graph ends up with more than one root, every extra root must pass these asserts:
assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
assert(torch.typename(root.data.module) == 'nnop.Parameters',
       'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
This is why, in the example above, the input nodes could not be custom layers. From a quick look at the code, with a single input node it should be fine.
5. Finally, the remaining per-node information is collected and the node modules are added to the gModule container.
function gModule:__init(inputs,outputs)
  parent.__init(self)
  -- the graph is defined backwards, we have the output modules as input here
  -- we will define a dummy output node that connects all output modules
  -- into itself. This will be the output for the forward graph and
  -- input point for the backward graph
  local node
  local outnode = nngraph.Node({input={}})
  for i = 1, utils.tableMaxN(outputs) do
    node = outputs[i]
    if torch.typename(node) ~= 'nngraph.Node' then
      error(utils.expectingNodeErrorMessage(node, 'outputs', i))
    end
    outnode:add(node, true)
  end
  for i = 1, utils.tableMaxN(inputs) do
    node = inputs[i]
    if torch.typename(node) ~= 'nngraph.Node' then
      error(utils.expectingNodeErrorMessage(node, 'inputs', i))
    end
  end
  -- We add also a dummy input node.
  -- The input node will be split to feed the passed input nodes.
  local innode = nngraph.Node({input={}})
  assert(#inputs > 0, "no inputs are not supported")
  if #inputs == 1 then
    inputs[1]:add(innode,true)
  else
    local splits = {innode:split(#inputs)}
    for i = 1, #inputs do
      assert(#inputs[i].children == 0, "an input should have no inputs")
    end
    for i = 1, #inputs do
      inputs[i]:add(splits[i],true)
    end
  end
  -- the backward graph (bg) is for gradients
  -- the forward graph (fg) is for function evaluation
  self.bg = outnode:graph()
  self.fg = self.bg:reverse()
  -- the complete graph is constructed
  -- now regenerate the graphs with the additional nodes
  local roots = self.fg:roots()
  -- if there are more than one root in the forward graph, then make sure that
  -- extra roots are parameter nodes
  if #roots > 1 then
    local innodeRoot = nil
    -- first find our innode
    for _, root in ipairs(roots) do
      if root.data == innode.data then
        assert(innodeRoot == nil, 'more than one matching input node found in leaves')
        innodeRoot = root
      else
        assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
        assert(torch.typename(root.data.module) == 'nnop.Parameters',
               'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
      end
    end
    assert(innodeRoot ~= nil, 'input node not found among roots')
    self.innode = innodeRoot
  else
    assert(#self.fg:roots() == 1, "expecting only one start")
    self.innode = self.fg:roots()[1]
  end
  assert(self.innode.data == innode.data, "expecting the forward innode")
  self.outnode = outnode
  self.verbose = false
  self.nInputs = #inputs
  -- computation on the graph is done through topsort of forward and backward graphs
  self.forwardnodes = self.fg:topsort()
  self.backwardnodes = self.bg:topsort()
  -- iterate over all nodes: check, tag and add to container
  for i,node in ipairs(self.forwardnodes) do
    -- check for unused inputs or unused split() outputs
    if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
      local nUnused = node.data.nSplitOutputs - #node.children
      local debugLabel = node.data.annotations._debugLabel
      local errStr =
        "%s of split(%s) outputs from the node declared at %s are unused"
      error(string.format(errStr, nUnused, node.data.nSplitOutputs,
                          debugLabel))
    end
    -- Check whether any nodes were defined as taking this node as an input,
    -- but then left dangling and don't connect to the output. If this is
    -- the case, then they won't be present in forwardnodes, so error out.
    for successor, _ in pairs(node.data.reverseMap) do
      local successorIsInGraph = false
      -- Only need to check the part of forwardnodes from i onwards, topological
      -- sort guarantees it cannot be in the first part.
      for j = i+1, #self.forwardnodes do
        -- Compare equality of data tables, as new Node objects have been
        -- created by processes such as topological sort, but the
        -- underlying .data table is shared.
        if self.forwardnodes[j].data == successor.data then
          successorIsInGraph = true
          break
        end
      end
      local errStr =
        "node declared on %s does not connect to gmodule output"
      assert(successorIsInGraph,
             string.format(errStr, successor.data.annotations._debugLabel))
    end
    -- set data.forwardNodeId for node:label() output
    node.data.forwardNodeId = node.id
    -- add module to container
    if node.data.module then
      self:add(node.data.module)
    end
  end
  self.output = nil
  self.gradInput = nil
  if #self.outnode.children > 1 then
    self.output = self.outnode.data.input
  end
end
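Continuing the two-input, two-output gmod from earlier, the dummy innode/outnode created by this constructor can be inspected directly (a minimal sketch, assuming gmod is still in scope):
print(torch.type(gmod.innode))                  -- the dummy input node created by __init
print(#gmod.outnode.children)                   -- 2: one child per declared output (oA, oB)
print(#gmod.forwardnodes, #gmod.backwardnodes)  -- the topologically sorted node lists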
During the backward pass, the helper mainly called to gather a node's gradients is the following function:
local function getTotalGradOutput(node)
  local gradOutput = node.data.gradOutput
  assert(istable(gradOutput), "expecting gradients to sum")
  if #gradOutput > 1 then
    -- Check if we can bypass the allocation, for the special case where all
    -- gradOutputs but one are zero tensors with an underlying one-element
    -- storage. Note that for the case that we
    -- cannot bypass it, this check will only be performed once
    if not node.data.gradOutputBuffer then
      local count = 0
      local idx = 1
      -- Count how many gradOutput are tensors of 1 element filled with zero
      for i=1,#gradOutput do
        local zero = torch.isTensor(gradOutput[i]) and
                     gradOutput[i]:storage() ~= nil and
                     gradOutput[i]:storage():size() == 1 and
                     gradOutput[i]:storage()[1] == 0
        if not zero then
          idx = i
          count = count + 1
        end
      end
      if count < 2 then
        -- Return the only non-zero one, or the first one
        -- if they are all zero
        return gradOutput[idx]
      end
    end
    node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
    local gobuff = node.data.gradOutputBuffer
    nesting.resizeNestedAs(gobuff, gradOutput[1])
    nesting.copyNested(gobuff, gradOutput[1])
    -- accumulate the remaining gradients into the buffer
    for i=2,#gradOutput do
      nesting.addNestedTo(gobuff, gradOutput[i])
    end
    gradOutput = gobuff
  else
    gradOutput = gradOutput[1]
  end
  return gradOutput
end
Note: the first gradient, node.data.gradOutput[1], is copied into gobuff first, and each of the remaining gradients is then accumulated into it with addNestedTo.
-- Adds the input to the output.
-- The input can contain nested tables.
-- The output will contain the same nesting of tables.
function nesting.addNestedTo(output, input)
  if torch.isTensor(output) then
    output:add(input) -- keep accumulating into the output tensor
  else
    for key, child in pairs(input) do
      assert(output[key] ~= nil, "missing key")
      nesting.addNestedTo(output[key], child)
    end
  end
end
So the gradients are accumulated one by one. Once the total gradOutput has been obtained, it is passed into the module stored in the Node, whose updateGradInput is called to do the actual computation.
The main backpropagation code:
function gModule:updateGradInput(input,gradOutput)
  local function neteval(node)
    if node.data.selectindex then
      assert(not node.data.module, "the selectindex-handling nodes should have no module")
      assert(#node.children == 1, "only the splitted node should be the input")
      local child = node.children[1]
      local go = getTotalGradOutput(node)
      child.data.gradOutput = child.data.gradOutput or {}
      assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
      -- The data.gradOutput holds the to-be-summed gradients.
      child.data.gradOutput[1] = child.data.gradOutput[1] or {}
      assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
      child.data.gradOutput[1][node.data.selectindex] = go
    else
      -- ******** compute the node's total gradOutput ********
      local gradOutput = getTotalGradOutput(node)
      -- updateGradInput through this node
      -- If no module is present, the node behaves like nn.Identity.
      local gradInput
      if not node.data.module then
        gradInput = gradOutput
      else
        local input = node.data.input
        -- a parameter node is captured
        if input == nil and node.data.module ~= nil then
          input = {}
        end
        if #input == 1 then
          input = input[1]
        end
        -- ******** fetch the module stored in the node and call its
        -- ******** updateGradInput to compute the gradient ********
        local module = node.data.module
        gradInput = module:updateGradInput(input,gradOutput)
      end
      -- during backprop, hand this node's gradient to each child's data.gradOutput
      -- propagate the output to children
      for i,child in ipairs(node.children) do
        child.data.gradOutput = child.data.gradOutput or {}
        local mapindex = node.data.mapindex[child.data]
        local gi
        if #node.children == 1 then
          gi = gradInput
        else
          gi = gradInput[mapindex]
        end
        table.insert(child.data.gradOutput,gi)
      end
    end
    if self.verbose then
      print(' V : ' .. node:label())
    end
  end
  local outnode = self.outnode
  if #outnode.children > 1 and #gradOutput ~= #outnode.children then
    error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
  end
  for _,node in ipairs(self.backwardnodes) do
    local gradOutput = node.data.gradOutput
    while gradOutput and #gradOutput > 0 do
      table.remove(gradOutput)
    end
  end
  -- Set the starting gradOutput.
  outnode.data.gradOutput = outnode.data.gradOutput or {}
  outnode.data.gradOutput[1] = gradOutput
  for i,node in ipairs(self.backwardnodes) do
    neteval(node)
  end
  assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
  self.gradInput = self.innode.data.gradOutput[1]
  return self.gradInput
end
As you can see, backpropagation amounts to: for each Node, first aggregate the incoming gradients, push them through the node's module, and hand the resulting gradient to each of the node's child nodes. Doing this for every node, the final gradient of the whole network ends up in self.innode.data.gradOutput[1].
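To watch this gradient summation happen, here is a tiny sketch (the modules chosen are arbitrary): a single input node fans out into two branches, so during the backward pass its gradOutput is the sum of both branch gradients:
local x  = - nn.Identity()
local b1 = x - nn.MulConstant(2)
local b2 = x - nn.MulConstant(3)
local y  = {b1, b2} - nn.CAddTable()
local g  = nn.gModule({x}, {y})
local input = torch.ones(4)
g:forward(input)
local gi = g:backward(input, torch.ones(4))
print(gi)   -- every entry is 2 + 3 = 5: the two branch gradients were summed at x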