Torch中的RNN底层代码实现

  • 理论篇
  • 代码篇

Torch中的RNN【1】这个package包括了RNN,RL,通过这个package可以很容易构建RNN,RL的模型。

安装:

luarocks install torch
luarocks install nn
luarocks install torchx
luarocks install dataload

如果有CUDA:
luarocks install cutorch
luarocks install cunn

记得安装:
luarocks install rnn

但是如果要使用nn.Reccurent,需要安装:【4】

理论篇

这一次主要是讲最简单的RNN,也就是Simple RNN。实现的话是根据这两篇论文:【6】,【7】

首先介绍一下Simple RNN的整个网络结构,再说一下 ρ step 的BPTT。

整个网络可以用下图来表示:(这种网络的输入一部分是当前的输入,另外一部分来自于hidden layer的上一个输出,这种叫做Elman Network。另外一种网络是一部分来自于当前输入,另外一部分来自于整个网络的上一个输出)
Torch中的RNN底层代码实现_第1张图片

  • 当前输入 wt 与上一个hidden layer的输出 st1 两个vector相加,得到真正输入到网络里面的东西。
    这里写图片描述
  • 接着是把输入送进一个logistic regression里面,得到hidden layer: st . st 一方面往输出那条路径走,另外一方面往缓存或者叫做Context里面存起来,称为下一个输入需要的一部分,替换 st1
    这里写图片描述
    这里写图片描述
  • st 输出该时刻的output: yt 一个Linear加上softmax,非常简单。
    这里写图片描述
    这里写图片描述

这样呢就把整个网络结构描述完了,接下来就是如何训练得到参数了。(其实RNN,LSTM还有很多小的trick,同样的算法,trick不一样,结果都会千差万别)

在另外的论文里面把这幅图给完整画了出来,显得更加清晰:
Torch中的RNN底层代码实现_第2张图片

了解了整个网络的以后,需要定义loss,在进行BP的时候,首先定义loss function,一般采用的是SSE: dpk 就是第p个sample,输出的feature第k个的label。 y 是prediction。
Torch中的RNN底层代码实现_第3张图片

对于w的更新,都是采用梯度下降:
这里写图片描述

对于输出output部分进行求导:
这里写图片描述

再进一步输出output的linear regression部分的w进行求导:
这里写图片描述

接着是hidden layer进行求导:
这里写图片描述

对hidden layer的输入的input部分参数进行求导:
这里写图片描述

对hidden layer的上一个hidden layer的作为input的部分进行求导:
这里写图片描述

目前的loss为SSE的时候,一般采用logistic function作为输出的函数:
这里写图片描述
这里写图片描述

当然,也可以有别的loss function,对应的output function g也会做乡相应改变。

比如对于Gaussian Distribution:
这里写图片描述
这里写图片描述
这里写图片描述

使用cross-entropy作为loss:(g为logistic function)
这里写图片描述
这里写图片描述

对于分类问题,采用multinomial dostribution:
采用的是softmax作为g,然后loss function为:
这里写图片描述

这里写图片描述

这里写图片描述

在RNN中经常听到BPTT,就是让RNN在进行后向传递的时候不仅仅是当前这个时期,还有的是更多时刻:τ。

比如τ = 3,展开的图如下:展开了三次,那么进行BP的时候,就把各个参数往后相乘来更新w,这里需要注意vanishing gradient effect和explode gradient effect的东西,一个梯度衰减比如为0,一个梯度爆炸。
Torch中的RNN底层代码实现_第4张图片

还有一种图可以表示梯度变化,红色的表示梯度的方向:
Torch中的RNN底层代码实现_第5张图片

如果用公式来表示这个整个过程就是前向:
Torch中的RNN底层代码实现_第6张图片

后向更新梯度:(每个时刻的梯度都会进行叠加到最后的w更新)
Torch中的RNN底层代码实现_第7张图片

代码篇

这次描述的是Simple RNN,函数为nn.Recurrent 。在nn中有两个抽象类,一个是nn,用来构建网络,一个是Criterion【3】,用来提供比如cross entropy,reward。具体介绍可以看【2】,还有对应的论文。

在【3】中提出了一个简单的例子:目前下面的nn.Recurrent已经不在Torch的库中,所以要使用的话,就去安装这个人写的【4】

这里面的实现的RNN是最简单的,连hidden layer都没有,直接transfer就是输出了。

nn.Recurrent(start, input, feedback, [transfer, rho, merge])
-- start:对初始t=0的input进行处理
-- input:对t~=0的时候input进行处理
-- feedback:对s(t)进行处理缓存到了s(t-1)
-- transfer:对输出进行处理的函数
-- rho:进行BPTT的steps的数目
-- merge:对input x(t)和上一个时刻的输出s(t-1)进行融合
-- generate some dummy inputs and gradOutputs sequences
-- 生成dummy input
inputs, gradOutputs = {}, {} 
for step=1,rho do
    inputs[step] = torch.randn(batchSize,inputSize)
    gradOutputs[step] = torch.randn(batchSize,inputSize) 
end

-- 调用RNN
-- an AbstractRecurrent instance
rnn = nn.Recurrent(
    hiddenSize, -- size of the input layer(隐层的size)
    nn.Linear(inputSize,outputSize), -- input layer(输入层进行linear regression) 
    nn.Linear(outputSize, outputSize), -- recurrent layer 输出层的linear regression
    nn.Sigmoid(), -- transfer function,把输入通过linear regression之后的结果送到这个函数得到s(t),这个函数也可以改成ReLU别的激活函数
    rho -- maximum number of time-steps for BPTT,进行BPTT时候的steps
)

-- feed-forward and backpropagate through time like this :
for step=1,rho 
do
    rnn:forward(inputs[step])
    rnn:backward(inputs[step], gradOutputs[step])
end

rnn:backwardThroughTime() -- call backward on the internal modules 
gradInputs = rnn.gradInputs
rnn:updateParameters(0.1)
rnn:forget() -- resets the time-step counter

对完整的nn.Reccurent的理解:【10】

assert(not nn.Recurrent, "update nnx package : luarocks install nnx")
local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent')

-- 把各个module放到RNN的对应位置
-- start是对最开始t=0输入inut做的处理
-- input是对t~=0的时刻进行input的处理
-- feedback是对s(t)进行处理缓存到s(t-1)的函数
-- transfer是对最后的输出的activation function
-- rho:是进行BPTT的时间
-- merge:对于输入x(t)和上一个时刻的hidden layer的输出s(t-1)的融合方法
function Recurrent:__init(start, input, feedback, transfer, rho, merge)
   parent.__init(self, rho)

   local ts = torch.type(start)
   if ts == 'torch.LongStorage' or ts == 'number' then
      start = nn.Add(start)
   elseif ts == 'table' then
      start = nn.Add(torch.LongStorage(start))
   elseif not torch.isTypeOf(start, 'nn.Module') then
      error"Recurrent : expecting arg 1 of type nn.Module, torch.LongStorage, number or table"
   end

   self.startModule = start
   self.inputModule = input
   self.feedbackModule = feedback
   self.transferModule = transfer or nn.Sigmoid()
   self.mergeModule = merge or nn.CAddTable()

   self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule}

   self:buildInitialModule()
   self:buildRecurrentModule()
   self.sharedClones[2] = self.recurrentModule
end

-- 对最开始t=0的时候构建模型
-- build module used for the first step (steps == 1)
function Recurrent:buildInitialModule()
   self.initialModule = nn.Sequential()
   self.initialModule:add(self.inputModule:sharedClone())
   self.initialModule:add(self.startModule)
   self.initialModule:add(self.transferModule:sharedClone())
end

-- build module used for the other steps (steps > 1)
-- 构建整个模型
function Recurrent:buildRecurrentModule()
   local parallelModule = nn.ParallelTable()
   parallelModule:add(self.inputModule)
   parallelModule:add(self.feedbackModule)
   self.recurrentModule = nn.Sequential()
   self.recurrentModule:add(parallelModule)
   self.recurrentModule:add(self.mergeModule)
   self.recurrentModule:add(self.transferModule)
end

-- 更新输出
function Recurrent:updateOutput(input)
   -- output(t) = transfer(feedback(output_(t-1)) + input(input_(t)))
   local output
   if self.step == 1 then
      output = self.initialModule:updateOutput(input)
   else
      if self.train ~= false then
         -- set/save the output states
         self:recycle()
         local recurrentModule = self:getStepModule(self.step)
          -- self.output is the previous output of this module
         output = recurrentModule:updateOutput{input, self.outputs[self.step-1]}
      else
         -- self.output is the previous output of this module
         output = self.recurrentModule:updateOutput{input, self.outputs[self.step-1]}
      end
   end

   self.outputs[self.step] = output
   self.output = output
   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   return self.output
end

-- 求解梯度,没有累加
function Recurrent:_updateGradInput(input, gradOutput)
   assert(self.step > 1, "expecting at least one updateOutput")
   local step = self.updateGradInputStep - 1

   local gradInput

   if self.gradPrevOutput then
      self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
      nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
      gradOutput = self._gradOutputs[step]
   end

   local output = self.outputs[step-1]
   if step > 1 then
      local recurrentModule = self:getStepModule(step)
      gradInput, self.gradPrevOutput = unpack(recurrentModule:updateGradInput({input, output}, gradOutput))
   elseif step == 1 then
      gradInput = self.initialModule:updateGradInput(input, gradOutput)
   else
      error"non-positive time-step"
   end

   return gradInput
end

-- 求解梯度,但是会把t steps的梯度相加
function Recurrent:_accGradParameters(input, gradOutput, scale)
   local step = self.accGradParametersStep - 1

   local gradOutput = (step == self.step-1) and gradOutput or self._gradOutputs[step]
   local output = self.outputs[step-1]

   if step > 1 then
      local recurrentModule = self:getStepModule(step)
      recurrentModule:accGradParameters({input, output}, gradOutput, scale)
   elseif step == 1 then
      self.initialModule:accGradParameters(input, gradOutput, scale)
   else
      error"non-positive time-step"
   end
end

function Recurrent:recycle()
   return parent.recycle(self, 1)
end

function Recurrent:forget()
   return parent.forget(self, 1)
end

function Recurrent:includingSharedClones(f)
   local modules = self.modules
   self.modules = {}
   local sharedClones = self.sharedClones
   self.sharedClones = nil
   local initModule = self.initialModule
   self.initialModule = nil
   for i,modules in ipairs{modules, sharedClones, {initModule}} do
      for j, module in pairs(modules) do
         table.insert(self.modules, module)
      end
   end
   local r = f()
   self.modules = modules
   self.sharedClones = sharedClones
   self.initialModule = initModule
   return r
end

function Recurrent:reinforce(reward)
   if torch.type(reward) == 'table' then
      -- multiple rewards, one per time-step
      local rewards = reward
      for step, reward in ipairs(rewards) do
         if step == 1 then
            self.initialModule:reinforce(reward)
         else
            local sm = self:getStepModule(step)
            sm:reinforce(reward)
         end
      end
   else
      -- one reward broadcast to all time-steps
      return self:includingSharedClones(function()
         return parent.reinforce(self, reward)
      end)
   end
end

function Recurrent:maskZero()
   error("Recurrent doesn't support maskZero as it uses a different "..
      "module for the first time-step. Use nn.Recurrence instead.")
end

function Recurrent:trimZero()
   error("Recurrent doesn't support trimZero as it uses a different "..
      "module for the first time-step. Use nn.Recurrence instead.")
end

-- 把模型打印出来
-- 比如我调用的是:
-- nn.Recurrent(256, nn.Identity(), nn.Linear(256, 256), nn['ReLU'](), 99999)
-- [[[
{input(t), output(t-1)} -> (1) -> (2) -> (3) -> output(t)]
    (1):  {
           input(t)
                |`-> (t==0): nn.Add
                |`-> (t~=0): nn.Identity
           output(t-1)
                |`-> nn.Linear(256 -> 256)
          }
    (2): nn.CAddTable
    (3): nn.ReLU
    }
---]]
function Recurrent:__tostring__()
   local tab = '  '
   local line = '\n'
   local next = ' -> '
   local str = torch.type(self)
   str = str .. ' {' .. line .. tab .. '[{input(t), output(t-1)}'
   for i=1,3 do
      str = str .. next .. '(' .. i .. ')'
   end
   str = str .. next .. 'output(t)]'

   local tab = '  '
   local line = '\n  '
   local next = '  |`-> '
   local ext = '  |    '
   local last = '   ... -> '
   str = str .. line ..  '(1): ' .. ' {' .. line .. tab .. 'input(t)'
   str = str .. line .. tab .. next .. '(t==0): ' .. tostring(self.startModule):gsub('\n', '\n' .. tab .. ext)
   str = str .. line .. tab .. next .. '(t~=0): ' .. tostring(self.inputModule):gsub('\n', '\n' .. tab .. ext)
   str = str .. line .. tab .. 'output(t-1)'
   str = str .. line .. tab .. next .. tostring(self.feedbackModule):gsub('\n', line .. tab .. ext)
   str = str .. line .. "}"
   local tab = '  '
   local line = '\n'
   local next = ' -> '
   str = str .. line .. tab .. '(' .. 2 .. '): ' .. tostring(self.mergeModule):gsub(line, line .. tab)
   str = str .. line .. tab .. '(' .. 3 .. '): ' .. tostring(self.transferModule):gsub(line, line .. tab)
   str = str .. line .. '}'
   return str
end

转载请注明出处: http://blog.csdn.net/c602273091/article/details/78975636

参考链接:
【1】RNN地址: https://github.com/torch/rnn
【2】nn Package: https://arxiv.org/pdf/1511.07889.pdf
【3】RNN Code: https://github.com/torch/rnn/blob/master/doc/recurrent.md#rnn.Recurrence
【4】nn.Reccurent: https://github.com/Element-Research/rnn/blob/master/Recurrent.lua
【5】nn RNN: https://github.com/Element-Research/rnn
【6】Recurrent neural network based language model: http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf
【7】 A guide to recurrent neural networks and backpropagation: http://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=CDD081815C5FAC4835EF27B81EEA5F8C?doi=10.1.1.3.9311&rep=rep1&type=pdf
【8】STATISTICAL LANGUAGE MODELS BASED ON NEURAL NETWORKS: (3.2~3.3)http://www.fit.vutbr.cz/%7Eimikolov/rnnlm/thesis.pdf
【9】TRAINING RECURRENT NEURAL NETWORKS:(2.5~2.8) http://www.cs.utoronto.ca/%7Eilya/pubs/ilya_sutskever_phd_thesis.pdf
【10】nn.Reccurent: https://github.com/Element-Research/rnn/blob/master/Recurrent.lua

你可能感兴趣的:(Torch中的RNN底层代码实现)