Torch7入门续集补充--- nngraph包的使用

构建方法

nngraph包在构建更加复杂的网络极其有用。毕竟是有点类似”静态图“了。
简单来说就是以前加网络需要不断add,现在用了nngraph,只要不断”“就行了。

h1 = - nn.Linear(20,10)
h2 = h1
     - nn.Tanh()
     - nn.Linear(10,10)
     - nn.Tanh()
     - nn.Linear(10, 1)
mlp = nn.gModule({h1}, {h2})

注意点:
1. 刚开始时需要用”-“来初始化。
2. 在nn.gModule中写入两个table,第一个table表示输入节点,第二个是输出节点。
当然,这两个table都可以有多个值。值得注意的是。这两个table必须是”node“。不能是任何其他的。

以Unet结构为例子:

function defineG_unet(input_nc, output_nc, ngf)
    local netG = nil
    -- input is (nc) x 256 x 256
    -- 初始化时先用“-”
    local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
    -- input is (ngf) x 128 x 128
    local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 8 x 8
    local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 4 x 4
    local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 2 x 2
    local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 1 x 1

    local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 2 x 2
    local d1 = {d1_,e7} - nn.JoinTable(2)
    local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 4 x 4
    local d2 = {d2_,e6} - nn.JoinTable(2)
    local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 8 x 8
    local d3 = {d3_,e5} - nn.JoinTable(2)
    local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local d4 = {d4_,e4} - nn.JoinTable(2)
    local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local d5 = {d5_,e3} - nn.JoinTable(2)
    local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local d6 = {d6_,e2} - nn.JoinTable(2)
    local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf)
    -- input is (ngf) x128 x 128
    local d7 = {d7_,e1} - nn.JoinTable(2)
    local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
    -- input is (nc) x 256 x 256

    local o1 = d8 - nn.Tanh()

    -- 输入节点只有一个,e1
    -- 输出节点也只有一个, o1
    netG = nn.gModule({e1},{o1})

    return netG
end

简洁明了,另外,你如果用普通方式构建Unet,github上也有:
https://github.com/dmarnerides/dlt/blob/master/src/models/unet.lua
注意:
以前构建都是用 nn.Module来构建网络。而nngraph包中是用nngraph.Node来构建的网络的,构建出来的网络类型是nn.gModule,nn.gModule是继承自nn.Module的子类nn.Container。nn.Module可以通过上述提到的“-”来 “变成” nngraph.Node类型。

多个输入与多个输出

h1 = - nn.Linear(20,20)
h2 = - nn.Linear(10,10)
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )

结构图:
Torch7入门续集补充--- nngraph包的使用_第1张图片
输入时,

local out = gmod:forward({input1, input2})
local out1 = out[1]
local out2 = out[2]
--或是 
local unpack = unpack or table.unpack
local out1,out2 = unpack(out)
-- 反向传播
gmod:backward({grad1, grad2})

其他

输入节点不能是自定义层

h1 = - nn.Linear(20,20)
h2 = - nn.Linear(10,10)
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )

比如下面这样:


h1 = - nn.Linear(20,20)

 --这样是错误的,必须要用内置的nn.Module的层,这种自定义的层,重载了
 -- nn.Module,会导致出错。
h2 = - myLayer
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )

Expected nnop.Parameters node, found : nn.MyLayer
当然,中间节点可以是自定义层。

nn.gModule的基本知识

gModule可以允许多个输入,多个输出,当然构建gModule的modules不能构成环。
每个结点可以有多个输入,这些结点的输入顺序存储在node.data.mapindex中,即每个结点的父结点的指针。

每个结点的输入只能是一个,当然我们可以用类似nn.JoinTable(2)的方法将其join后,输入到一个结点中。如果输出可以是一张表,存着要输出的各个结点。

另外,node.data.input是一个list,存储着所有输入,如果输入只有一个,那么只有node.data.input[1]使用了。
值得注意的是, node.data.gradOutput是将所有的连接到该结点的反传梯度全部相加,再往前传。

还有一点就是,网络的第一个结点和最后一个结点都是“dummy”的。因为输入和输出可能是多个,所以要用这些dummy的结点来分别处理多个输入和输出。比如对于多个输出,则最后一个dummy结点,内部通过split操作,可以将网络分块,每个结点对应这些“网络块”,当然这些小块很可能会有重叠区域。有趣的是,对于多个输入和多个输出,我们反传时,输入可以不要求全部填满,比如只填第一个输入,那么整个网络相当于截取只和第一个输入有关的网络进行更新。

gModule获得某结点的信息

每个结点包含一个module。如果只是单独想看module的信息,那么直接net.modules[i]就行。而net.fowardnodes可以获得每个结点的信息。
只要获得了node,我们主要可以获得结点的输入以及gradOutput.

node.data.input
node.data.gradOutput

你可能会问,那gradInput呢?结点是没有gradInput的。每个结点首先综合gradOutput,从而得到本结点的gradOutput, 再将gradOuput传入结点内部的module,得到gradInput之后,再将gradInput作为该结点子节点的gradOuput.

由于每个结点存储了一个module
node.data.module以得到存储某个结点的module的各种信息。

--这是获得net.modules[9]的。
--这种写法的好处就是,可以获得结点的信息,而不是单独结点内部的module的信息。
local ind = 10
local latent = nil
 for indexNode, node in pairs(net.forwardnodes) do
     if indexNode == ind then
         if node.data.module then
             latent = node.data.module.output:clone()  -- use it to get the specific module output
         end          
     end
 end

简略看看gModule

gModule的初始化

大概步骤:
首先进行检查输入和输出。要求必须是nngraph.Node类型的。就是上面提到的-可以把nn.Module变成nngraph.Node类型。
然后对输入再次进行检查,如果输入只有一个,直接inputs[1]:add(innode,true)
否则多个输入必须检查每个输入结点不能有子结点。
然后再构建2张图:fg和bg图。fg图用于网络前向传播评估用,bg用于反向传播。
如果输入结点有多个,还要对每个结点进行assert

            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
            assert(torch.typename(root.data.module) == 'nnop.Parameters',
                  'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))

这也是为什么在刚才的例子中,不能输入结点不能是自定义的层。从代码上简要的看,如果输入结点只有一个的话,应该可以。
最后把每个结点的其他一些信息,加入到gModule中。

function gModule:__init(inputs,outputs)
   parent.__init(self)
   -- the graph is defined backwards, we have the output modules as input here
   -- we will define a dummy output node that connects all output modules
   -- into itself. This will be the output for the forward graph and
   -- input point for the backward graph
   local node
   local outnode = nngraph.Node({input={}})
   for i = 1, utils.tableMaxN(outputs) do
      node = outputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'outputs', i))
      end
      outnode:add(node, true)
   end
   for i = 1, utils.tableMaxN(inputs) do
      node = inputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'inputs', i))
      end
   end
   -- We add also a dummy input node.
   -- The input node will be split to feed the passed input nodes.
   local innode = nngraph.Node({input={}})
   assert(#inputs > 0, "no inputs are not supported")
   if #inputs == 1 then
      inputs[1]:add(innode,true)
   else
      local splits = {innode:split(#inputs)}
      for i = 1, #inputs do
         assert(#inputs[i].children == 0, "an input should have no inputs")
      end
      for i = 1, #inputs do
         inputs[i]:add(splits[i],true)
      end
   end

   -- the backward graph (bg) is for gradients
   -- the forward graph (fg) is for function evaluation
   self.bg = outnode:graph()
   self.fg = self.bg:reverse()

   -- the complete graph is constructed
   -- now regenerate the graphs with the additional nodes

   local roots = self.fg:roots()
   -- if there are more than one root in the forward graph, then make sure that
   -- extra roots are parameter nodes
   if #roots > 1 then
      local innodeRoot = nil
      -- first find our innode
      for _, root in ipairs(roots) do
         if root.data == innode.data then
            assert(innodeRoot == nil, 'more than one matching input node found in leaves')
            innodeRoot = root
         else
            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
            assert(torch.typename(root.data.module) == 'nnop.Parameters',
                  'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
         end
      end
      assert(innodeRoot ~= nil, 'input node not found among roots')
      self.innode = innodeRoot
   else
      assert(#self.fg:roots() == 1, "expecting only one start")
      self.innode = self.fg:roots()[1]
   end

   assert(self.innode.data == innode.data, "expecting the forward innode")
   self.outnode = outnode
   self.verbose = false
   self.nInputs = #inputs

   -- computation on the graph is done through topsort of forward and backward graphs
   self.forwardnodes = self.fg:topsort()
   self.backwardnodes = self.bg:topsort()

   -- iteratare over all nodes: check, tag and add to container
   for i,node in ipairs(self.forwardnodes) do
      -- check for unused inputs or unused split() outputs
      if node.data.nSplitOutputs and node.data.nSplitOutputs ~=  #node.children then
         local nUnused = node.data.nSplitOutputs - #node.children
         local debugLabel = node.data.annotations._debugLabel
         local errStr =
            "%s of split(%s) outputs from the node declared at %s are unused"
         error(string.format(errStr, nUnused, node.data.nSplitOutputs,
                             debugLabel))
      end

      -- Check whether any nodes were defined as taking this node as an input,
      -- but then left dangling and don't connect to the output. If this is
      -- the case, then they won't be present in forwardnodes, so error out.
      for successor, _ in pairs(node.data.reverseMap) do
         local successorIsInGraph = false

         -- Only need to the part of forwardnodes from i onwards, topological
         -- sort guarantees it cannot be in the first part.
         for j = i+1, #self.forwardnodes do
            -- Compare equality of data tables, as new Node objects have been
            -- created by processes such as topoological sort, but the
            -- underlying .data table is shared.
            if self.forwardnodes[j].data == successor.data then
               successorIsInGraph = true
               break
            end
         end
         local errStr =
            "node declared on %s does not connect to gmodule output"
         assert(successorIsInGraph,
                string.format(errStr, successor.data.annotations._debugLabel))
      end

      -- set data.forwardNodeId for node:label() output
      node.data.forwardNodeId = node.id

      -- add module to container
      if node.data.module then
         self:add(node.data.module)
      end
   end

   self.output = nil
   self.gradInput = nil
   if #self.outnode.children > 1 then
      self.output = self.outnode.data.input
   end
end

gModule的反向传播

主要调用这个函数,可以看到

local function getTotalGradOutput(node)
   local gradOutput = node.data.gradOutput
   assert(istable(gradOutput), "expecting gradients to sum")
   if #gradOutput > 1 then
      -- Check if we can bypass the allocation, for the special case where all
      -- gradOutputs but one are zero tensors with an underlying one-element
      -- storage. Note that for the case that we
      -- cannot bypass it, this check will only be performed once
      if not node.data.gradOutputBuffer then
         local count = 0
         local idx = 1
         -- Count how many gradOutput are tensors of 1 element filled with zero
         for i=1,#gradOutput do
            local zero = torch.isTensor(gradOutput[i]) and
                         gradOutput[i]:storage() ~= nil and
                         gradOutput[i]:storage():size() == 1 and
                         gradOutput[i]:storage()[1] == 0
            if not zero then
               idx = i
               count = count + 1
            end
         end
         if count < 2 then
            -- Return the only non-zero one, or the first one
            -- if they are all zero
            return gradOutput[idx]
         end
      end
      node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
      local gobuff = node.data.gradOutputBuffer
      nesting.resizeNestedAs(gobuff, gradOutput[1])
      nesting.copyNested(gobuff, gradOutput[1])
      -- 注释:
      for i=2,#gradOutput do
         nesting.addNestedTo(gobuff, gradOutput[i])
      end
      gradOutput = gobuff
   else
      gradOutput = gradOutput[1]
   end
   return gradOutput
end

注释:可以看到这里首先将第一个节点的第一个梯度node.data.gradOutput[1]作为gobuff, 然后对于其他的梯度进行addNestedTo操作。

-- Adds the input to the output.
-- The input can contain nested tables.
-- The output will contain the same nesting of tables.
function nesting.addNestedTo(output, input)
   if torch.isTensor(output) then
      output:add(input) --不断累加
   else
      for key, child in pairs(input) do
         assert(output[key] ~= nil, "missing key")
         nesting.addNestedTo(output[key], child)
      end
   end
end

可以看到是不断累加的。得到总的gradOuput后,再传入Node里面的module内,调用
module的updateGradInput来计算。

反向传播的主要代码:

function gModule:updateGradInput(input,gradOutput)
   local function neteval(node)
      if node.data.selectindex then
         assert(not node.data.module, "the selectindex-handling nodes should have no module")
         assert(#node.children == 1, "only the splitted node should be the input")
         local child = node.children[1]
         local go = getTotalGradOutput(node)
         child.data.gradOutput = child.data.gradOutput or {}
         assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
         -- The data.gradOutput holds the to-be-summed gradients.
         child.data.gradOutput[1] = child.data.gradOutput[1] or {}
         assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
         child.data.gradOutput[1][node.data.selectindex] = go
      else
      --********得到结点的总的gradOutput**************
         local gradOutput = getTotalGradOutput(node)
         -- updateGradInput through this node
         -- If no module is present, the node behaves like nn.Identity.
         local gradInput
         if not node.data.module then
            gradInput = gradOutput
         else
            local input = node.data.input
            -- a parameter node is captured
            if input == nil and node.data.module ~= nil then
               input = {}
            end
            if #input == 1 then
               input = input[1]
            end
            --********得到结点存储的module,调用module的updateGradInput
            --********进行更新。
            local module = node.data.module
            gradInput = module:updateGradInput(input,gradOutput)
         end
         -- 反传时,将该结点的梯度传给每个子结点的data.gradOutput
         -- propagate the output to children
         for i,child in ipairs(node.children) do
            child.data.gradOutput = child.data.gradOutput or {}
            local mapindex = node.data.mapindex[child.data]
            local gi
            if #node.children == 1 then
               gi = gradInput
            else
               gi = gradInput[mapindex]
            end
            table.insert(child.data.gradOutput,gi)
         end
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local outnode = self.outnode
   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
   end
   for _,node in ipairs(self.backwardnodes) do
      local gradOutput = node.data.gradOutput
      while gradOutput and #gradOutput >0 do
         table.remove(gradOutput)
      end
   end
   -- Set the starting gradOutput.
   outnode.data.gradOutput = outnode.data.gradOutput or {}
   outnode.data.gradOutput[1] = gradOutput

   for i,node in ipairs(self.backwardnodes) do
      neteval(node)
   end

   assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
   self.gradInput = self.innode.data.gradOutput[1]
   return self.gradInput
end

可以看到,反传时,就是先对每个Node综合一下梯度,然后将该梯度传给该结点的每个子节点。然后对每个结点这样做,最终的梯度就是self.innode.data.gradOutput[1]

你可能感兴趣的:(Lua,Torch7入门教程)