反向传播的时候存在一条计算链,误差传播时也是反向走过这条计算链上的各个运算,所以计算链的概念很重要。那么层间单向的RNN计算链可以表示为下图:
大写字母W、U、V表示点积运算,B是偏移运算,f和g是激活运算,+是相加运算。反向传播就可以从链上看出来结果,前一个误差经过当前运算得到当前误差,同时得到当前运算参数的偏导数,并更新当前运算参数,如此往复向前进行更新。如果遇到分支,可以想像成两次输出的误差,可以求均值也可以分两次训练。我采用的是求均值然后再更新。
下面是这个计算链的实现cal_chain.hpp:
#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include "mat.hpp"
#include "base_function.hpp"
// Abstract node of the computation chain: maps an (inp_row x inp_col)
// matrix to a (ret_row x ret_col) matrix.
// NOTE(review): every template argument list in this file was stripped by
// HTML rendering (anything starting "<letter" was parsed as a tag); the
// parameters below are reconstructed from how the derived nodes
// (mult/bias/act) use this base — confirm against the original source.
template<int inp_row, int inp_col, int ret_row, int ret_col, typename val_t>
struct cal_chain_node
{
	using base_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	using ret_type = mat<ret_row, ret_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	// Forward pass: consume the input, caching whatever backward() needs.
	virtual ret_type forward(const inp_type& inp) = 0;
	// Backward pass: take the error from the next node, accumulate the local
	// gradient, return the error to hand to the previous node.
	virtual inp_type backward(const ret_type& delta) = 0;
	// Apply the accumulated gradient to the parameters and reset accumulators.
	virtual void update() = 0;
	// Polymorphic base (dynamic_cast is applied to it elsewhere in this
	// file) — deleting through the base requires a virtual destructor.
	virtual ~cal_chain_node() = default;
};
template
struct cal_chain
{
using type = val_t;
using cur_type = cal_chain_node;
cur_type* sp_cur_node;
using inp_type = typename cur_type::inp_type;
using nxt_type = cal_chain;
std::shared_ptr sp_next;
using ret_type = typename nxt_type::ret_type;
cal_chain()
{
sp_next = std::make_shared();
}
auto forward(const inp_type& inp)
{
return sp_next->forward(sp_cur_node->forward(inp));
}
auto backward(const ret_type& ret)
{
return sp_cur_node->backward(sp_next->backward(ret));
}
void update()
{
sp_cur_node->update();
if (sp_next)
sp_next->update();
}
template
cal_chain& set(set_type* sp)
{
if constexpr (remain != 0)
{
sp_next->set(sp);
}
if constexpr (remain == 0)
{
sp_cur_node = dynamic_cast(sp);
}
return *this;
}
};
template
struct cal_chain
{
using type = val_t;
using cur_type = cal_chain_node;
cur_type* sp_cur_node;
using inp_type = typename cur_type::inp_type;
using ret_type = typename cur_type::ret_type;
cal_chain()
{
}
auto forward(const inp_type& inp)
{
return sp_cur_node->forward(inp);
}
auto backward(const ret_type& ret)
{
return sp_cur_node->backward(ret);
}
void update()
{
sp_cur_node->update();
}
template
cal_chain& set(set_type* sp)
{
static_assert(remain == 0, "over over");
if constexpr (remain == 0)
sp_cur_node = dynamic_cast(sp);
return *this;
}
};
template
struct cal_chain_container :public cal_chain_node
{
using inp_type = typename chain::inp_type;
using ret_type = typename chain::ret_type;
chain* p;
cal_chain_container(chain* pv):p(pv)
{}
ret_type forward(const inp_type& inp)
{
return p->forward(inp);
}
inp_type backward(const ret_type& ret)
{
return p->backward(ret);
}
void update()
{
p->update();
}
};
template
auto make_chain_node(cal_chain* p)
{
using chain_type = cal_chain;
cal_chain_container cc(p);
return cc;
}
#include "weight_initilizer.hpp"
#include "update_methods.hpp"
template class update_method_templ, int inp_row, int inp_col, int ret_row, typename val_t>
struct cal_chain_node_mult:public cal_chain_node
{
using ret_type = mat;
using inp_type = mat;
using weight_type = mat;
weight_type W;
weight_type deltaW;
inp_type pre_inp;
update_method_templ um;
double d_num;
cal_chain_node_mult():d_num(0.)
{
weight_initilizer::cal(W);
}
virtual ret_type forward(const inp_type& inp)
{
pre_inp = inp;
return W.dot(inp);
}
virtual inp_type backward(const ret_type& delta)
{
auto ret = W.t().dot(delta);
deltaW = deltaW * d_num + delta.dot(pre_inp.t());
d_num = d_num + 1.;
if (d_num > 1e-7)
{
deltaW = deltaW / d_num;
}
return ret;
}
virtual void update()
{
W.assign<0, 0>(um.update(W, deltaW));
deltaW = 0.;
d_num = 0.;
}
};
template class update_method_templ, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_bias :public cal_chain_node
{
using ret_type = mat;
using inp_type = mat;
update_method_templ um;
inp_type b;
inp_type deltab;
double d_num;
cal_chain_node_bias():d_num(0.)
{}
virtual ret_type forward(const inp_type& inp)
{
return b + (inp);
}
virtual inp_type backward(const ret_type& delta)
{
deltab = deltab * d_num + delta;
d_num = d_num + 1.;
if (d_num > 1e-7)
{
deltab = deltab / d_num;
}
return delta;
}
virtual void update()
{
b.assign<0, 0>(um.update(b, deltab));
deltab = 0.;
d_num = 0.;
}
};
#include "activate_function.hpp"
template class activate_func, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_act :public cal_chain_node
{
using ret_type = mat;
using inp_type = mat;
activate_func act_fun;
virtual ret_type forward(const inp_type& inp)
{
return act_fun.forward(inp);
}
virtual inp_type backward(const ret_type& delta)
{
return act_fun.backward() * delta;
}
virtual void update()
{
}
};
#endif
下面看一看用这个计算链实现RNN:
#include
#include "cal_chain.hpp"
template
struct rnn_node
{
using inp_type = mat;
using ret_type = mat;
inp_type St;
inp_type dSt;
cal_chain_node_mult W,U;
cal_chain_node_bias b;
cal_chain_node_act f;
cal_chain_node_mult V;
cal_chain_node_act g;
rnn_node()
{
}
ret_type forward(const inp_type& X)
{
St.assign<0,0>(f.forward(b.forward(U.forward(X) + W.forward(St))));
return g.forward(V.forward(St));
}
inp_type backward(const ret_type& delta)
{
auto delta_before_b = f.backward(dSt) + f.backward(V.backward(g.backward(delta)));
auto WU = b.backward(delta_before_b);
dSt.assign<0,0>(W.backward(WU));
return U.backward(WU);
}
void update()
{
W.update();
U.update();
b.update();
f.update();
V.update();
g.update();
}
};
// Demo: overfit a single (input, target) pair on one RNN cell, printing the
// prediction and pausing for a keypress every 10000 steps.
// NOTE(review): _getch() presumably comes from <conio.h> (Windows-only);
// the #include at the top of this snippet lost its argument to the markup
// — confirm.  The training loop is intentionally endless.
int main(int argc, char** argv)
{
// Fixed 3x1 input vector.
mat<3, 1, double> mm1{.1,.2,.3};
rnn_node<3, 2> r;
for (int i = 0; ; ++i)
{
auto ret = r.forward(mm1);
if (i % 10000 == 0)
{
ret.print();
// Pause so the printed prediction can be inspected.
_getch();
}
// 2x1 target vector; train on the prediction error.
mat<2, 1, double> mm2{.4,.6};
r.backward(ret - mm2);
r.update();
}
return 0;
}
这个用法当然不是RNN的常规用法。真正的RNN训练应该用一段固定长度的序列进行计算。最好有层内正反向两个传播节点。