C++元编程——计算链和RNN

反向传播的时候存在一条计算链,误差传播时也是反向走过这条计算链,所以计算链的概念很重要。那么层间单向的RNN计算链可以表现为下图:

C++元编程——计算链和RNN_第1张图片

 大写字母W、U、V表示点积运算,B是偏移运算,f和g是激活运算,+是相加运算。反向传播的过程可以从链上看出来:前一个误差经过当前运算得到当前误差,同时得到当前运算参数的偏导数,并更新当前运算参数,如此往复向前进行更新。如果遇到分支,可以想象成两次输出的误差,既可以求均值,也可以分两次训练。我采用的是先求均值再更新。

下面是这个计算链的实现 cal_chain.hpp(注意:代码在网页转载过程中,模板参数的尖括号内容被当作HTML标签丢失了,例如 `template` 后面缺少 `<...>` 参数列表,阅读时请参考原始源码):

#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include "mat.hpp"
#include "base_function.hpp"

// Abstract interface for one node of the computation chain: a forward pass,
// a backward (error-propagation) pass, and a parameter-update step.
// NOTE(review): the template parameter list here (and throughout this file)
// was stripped during HTML extraction — `template` has no `<...>`.  It
// presumably carried the matrix dimensions and the value type (something like
// `template<int inp_row, int inp_col, int ret_row, int ret_col, typename val_t>`);
// confirm against the original source before compiling.
template
struct cal_chain_node
{
	using base_type = cal_chain_node;   // self type, used as a dynamic_cast target by the chain
	using ret_type = mat;               // output matrix type (template args stripped)
	using inp_type = mat;               // input matrix type (template args stripped)

	// Compute the node's output from its input; implementations may cache
	// the input for later use in backward().
	virtual ret_type forward(const inp_type& inp) = 0;
	// Propagate the error `delta` backwards, returning the error w.r.t. the input.
	virtual inp_type backward(const ret_type& delta) = 0;
	// Apply the accumulated gradient to the node's parameters (if any).
	virtual void update() = 0;
};

// Recursive chain of cal_chain_node pointers.  forward() threads the input
// through every node front-to-back; backward() threads the error back-to-front.
// NOTE(review): template arguments were stripped by extraction — the original
// presumably recursed on a count/dimension parameter, with the specialisation
// below as the single-node base case.  Confirm against the original source.
template
struct cal_chain 
{
	using type = val_t;
	using cur_type = cal_chain_node;   // node type handled at this link (args stripped)
	cur_type* sp_cur_node;             // raw pointer to current node — presumably non-owning; TODO confirm
	using inp_type = typename cur_type::inp_type;
	using nxt_type = cal_chain;        // remainder of the chain (args stripped)
	std::shared_ptr sp_next;           // owning pointer to the tail of the chain
	using ret_type = typename nxt_type::ret_type;   // chain output type comes from the last node

	cal_chain() 
	{
		// Eagerly build the tail; node pointers are filled in later via set().
		sp_next = std::make_shared();
	}

	// Feed `inp` through the current node, then through the rest of the chain.
	auto forward(const inp_type& inp)
	{
		return sp_next->forward(sp_cur_node->forward(inp));
	}

	// Error flows in the opposite order: through the tail first, then this node.
	auto backward(const ret_type& ret) 
	{
		return sp_cur_node->backward(sp_next->backward(ret));
	}

	// Update every node's parameters, front to back.
	void update() 
	{
		sp_cur_node->update();
		if (sp_next)
			sp_next->update();
	}

	// Install node `sp` at compile-time position `remain` (0 == this link),
	// walking the chain with if constexpr.  The member-template parameter
	// list and the recursive set<remain-1> call lost their args in extraction.
	template
	cal_chain& set(set_type* sp)
	{
		if constexpr (remain != 0)
		{
			// Not our slot yet — delegate to the tail (index args stripped).
			sp_next->set(sp);
		}
		if constexpr (remain == 0)
		{
			// Our slot: downcast to this link's node type (cast target stripped).
			sp_cur_node = dynamic_cast(sp);
		}
		return *this;
	}
};

// Base case of the recursive chain: a single node.  forward/backward/update
// delegate straight to the node; there is no tail pointer here.
// NOTE(review): the specialisation's template arguments were stripped by
// extraction — confirm against the original source.
template
struct cal_chain
{
	using type = val_t;
	using cur_type = cal_chain_node;   // the only node of this (sub)chain (args stripped)
	cur_type* sp_cur_node;             // raw pointer — presumably non-owning; TODO confirm
	using inp_type = typename cur_type::inp_type;
	using ret_type = typename cur_type::ret_type;

	cal_chain() 
	{
	}

	// Single-node chain: just run the node's forward pass.
	auto forward(const inp_type& inp)
	{
		return sp_cur_node->forward(inp);
	}

	// Single-node chain: just run the node's backward pass.
	auto backward(const ret_type& ret)
	{
		return sp_cur_node->backward(ret);
	}

	// Update the single node's parameters.
	void update()
	{
		sp_cur_node->update();
	}


	// Terminal set(): only index 0 is valid here; any larger index means the
	// caller walked past the end of the chain ("over over" compile error).
	template
	cal_chain& set(set_type* sp)
	{
		static_assert(remain == 0, "over over");
		if constexpr (remain == 0)
			sp_cur_node = dynamic_cast(sp);
		return *this;
	}
};

// Adapter that exposes a whole cal_chain through the cal_chain_node
// interface, so a complete chain can itself be plugged into another chain
// as a single node.
// NOTE(review): template args stripped by extraction (presumably
// `template<typename chain, ...>`).
template
struct cal_chain_container :public cal_chain_node
{
	using inp_type = typename chain::inp_type;
	using ret_type = typename chain::ret_type;
	chain*		p;   // non-owning pointer to the wrapped chain — the caller keeps it alive
	cal_chain_container(chain* pv):p(pv)
	{}

	// Delegate the full forward pass to the wrapped chain.
	ret_type forward(const inp_type& inp)
	{
		return p->forward(inp);
	}

	// Delegate the full backward pass to the wrapped chain.
	inp_type backward(const ret_type& ret)
	{
		return p->backward(ret);
	}

	// Update every node of the wrapped chain.
	void update()
	{
		p->update();
	}
};

// Convenience factory: wrap an existing cal_chain in a cal_chain_container
// so it can be used as one node of a larger chain.
// NOTE(review): template args stripped by extraction.  The container is
// returned by value but holds only a raw pointer to `p` — the caller must
// keep the chain alive for as long as the container is used.
template
auto make_chain_node(cal_chain* p)
{
	using chain_type = cal_chain;
	cal_chain_container cc(p);
	return cc;
}

#include "weight_initilizer.hpp"
#include "update_methods.hpp"

// Dot-product (weight-matrix) node: forward computes W.dot(inp).  The
// backward pass maintains a *running average* of the weight gradient so that
// several backward() calls between two update()s are averaged (see d_num) —
// this is the "branch error averaging" strategy described in the article.
// NOTE(review): parts of the template parameter list were stripped by
// extraction; confirm the full signature against the original source.
template class update_method_templ, int inp_row, int inp_col, int ret_row, typename val_t>
struct cal_chain_node_mult:public cal_chain_node
{
	using ret_type = mat;      // result of W.dot(inp) (args stripped)
	using inp_type = mat;      // input matrix (args stripped)
	using weight_type = mat;   // weight matrix W (args stripped)

	weight_type W;        // trainable weights
	weight_type deltaW;   // running average of dL/dW accumulated by backward()
	inp_type pre_inp;     // input cached by forward() for use in backward()

	update_method_templ um;   // optimiser applied to W in update() (args stripped)
	double d_num;             // number of gradients folded into deltaW so far

	cal_chain_node_mult():d_num(0.)
	{
		// Randomly initialise W; deltaW/pre_inp rely on mat's default ctor —
		// presumably zero-initialised; TODO confirm in mat.hpp.
		weight_initilizer::cal(W);
	}

	// Cache the input and return W.dot(inp).
	virtual ret_type forward(const inp_type& inp) 
	{
		pre_inp = inp;
		return W.dot(inp);
	}

	// Return W^T.dot(delta) (the error for the previous node) and fold the
	// gradient delta.dot(pre_inp^T) into the running mean deltaW.
	virtual inp_type backward(const ret_type& delta) 
	{
		auto ret = W.t().dot(delta);
		// Running mean: deltaW <- (deltaW * n + new_grad) / (n + 1).
		deltaW = deltaW * d_num + delta.dot(pre_inp.t());
		d_num = d_num + 1.;
		if (d_num > 1e-7)   // always true here (d_num >= 1 after the increment); defensive guard
		{
			deltaW = deltaW / d_num;
		}
		return ret;
	}

	// Apply the averaged gradient via the update method, then reset the
	// accumulator for the next batch.
	virtual void update() 
	{
		W.assign<0, 0>(um.update(W, deltaW));
		deltaW = 0.;
		d_num = 0.;
	}
};

// Bias node: forward adds a learned offset b to the input.  Like the mult
// node, backward keeps a running average of the gradient between update()s.
// NOTE(review): parts of the template parameter list were stripped by
// extraction; confirm against the original source.
template class update_method_templ, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_bias :public cal_chain_node
{
	using ret_type = mat;   // output (same shape as input; args stripped)
	using inp_type = mat;
	update_method_templ um;   // optimiser applied to b in update()

	inp_type b;        // bias — NOTE(review): never explicitly initialised; relies on mat's default ctor (presumably zeros) — TODO confirm
	inp_type deltab;   // running average of dL/db
	double d_num;      // number of gradients folded into deltab so far
	cal_chain_node_bias():d_num(0.)
	{}

	// y = b + x.
	virtual ret_type forward(const inp_type& inp)
	{
		return b + (inp);
	}

	// d(b+x)/dx = I, so the incoming delta passes through unchanged; the same
	// delta is also the bias gradient and is folded into the running mean.
	virtual inp_type backward(const ret_type& delta)
	{
		deltab = deltab * d_num + delta;
		d_num = d_num + 1.;
		if (d_num > 1e-7)   // always true here (d_num >= 1 after the increment); defensive guard
		{
			deltab = deltab / d_num;
		}
		return delta;
	}

	// Apply the averaged gradient and reset the accumulator.
	virtual void update()
	{
		b.assign<0, 0>(um.update(b, deltab));
		deltab = 0.;
		d_num = 0.;
	}
};

#include "activate_function.hpp"
// Activation node: applies an activation function in forward() and multiplies
// the incoming error by its derivative in backward().  It has no trainable
// parameters, so update() is a no-op.
// NOTE(review): parts of the template parameter list were stripped by
// extraction; confirm against the original source.
template class activate_func, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_act :public cal_chain_node
{
	using ret_type = mat;   // same shape in and out (args stripped)
	using inp_type = mat;
	activate_func act_fun;  // activation functor (args stripped)

	virtual ret_type forward(const inp_type& inp)
	{
		return act_fun.forward(inp);
	}

	// act_fun.backward() takes no argument, so it must derive the derivative
	// from state cached during forward(); `*` is presumably element-wise —
	// TODO confirm both points in activate_function.hpp.
	virtual inp_type backward(const ret_type& delta)
	{
		return act_fun.backward() * delta;
	}

	// No trainable parameters to update.
	virtual void update()
	{
	}
};

#endif

下面看一看用这个计算链实现RNN:

#include 
#include "cal_chain.hpp"

// One RNN cell assembled from chain nodes:
//   S_t = f(b + U*X + W*S_{t-1})   (hidden state, stored in St)
//   out = g(V*S_t)
// NOTE(review): template args stripped by extraction; it is instantiated
// below as rnn_node<3, 2>, so presumably `template<int inp_dim, int ret_dim>`
// over column-vector mats, with the hidden state sharing the input dimension
// (St is declared inp_type) — confirm against the original source.
template
struct rnn_node
{
	using inp_type = mat;   // input / hidden-state column vector (args stripped)
	using ret_type = mat;   // output column vector (args stripped)
	inp_type St;    // hidden state S_{t-1}, carried across forward() calls
	inp_type dSt;   // error flowing back through the recurrent (W) branch
	cal_chain_node_mult W,U;   // recurrent and input weight nodes (args stripped)
	cal_chain_node_bias b;     // hidden-layer bias
	cal_chain_node_act f;      // hidden-layer activation
	
	cal_chain_node_mult V;     // output weights
	cal_chain_node_act g;      // output activation

	rnn_node()
	{
	}

	// S_t = f(b(U*X + W*S_{t-1})); return g(V*S_t).
	ret_type forward(const inp_type& X)
	{
		St.assign<0,0>(f.forward(b.forward(U.forward(X) + W.forward(St))));
		return g.forward(V.forward(St));
	}

	// Sum the recurrent error (dSt) and the output-branch error in front of
	// the bias, then split the result into the W (recurrent) and U (input)
	// branches.  NOTE(review): f.backward is invoked twice and the results
	// summed — this assumes f.backward does not mutate its cached state, and
	// that f'(a)+f'(b) was intended rather than f'(a+b); confirm.
	inp_type backward(const ret_type& delta)
	{
		auto delta_before_b = f.backward(dSt) + f.backward(V.backward(g.backward(delta)));
		auto WU = b.backward(delta_before_b);
		dSt.assign<0,0>(W.backward(WU));
		return U.backward(WU);
	}

	// Update every trainable node (f and g are no-ops).
	void update() 
	{
		W.update();
		U.update();
		b.update();
		f.update();
		V.update();
		g.update();
	}
};

int main(int argc, char** argv) 
{
	mat<3, 1, double> mm1{.1,.2,.3};
	rnn_node<3, 2> r;
	for (int i = 0; ; ++i)
	{
		auto ret = r.forward(mm1);
		if (i % 10000 == 0)
		{
			ret.print();
			_getch();
		}

		mat<2, 1, double> mm2{.4,.6};
		r.backward(ret - mm2);
		r.update();
	}
	return 0;
}

这个用法当然不是RNN的常规用法。真正的RNN训练应该用一段固定长度的序列进行计算。最好有层内正反向两个传播节点。

你可能感兴趣的:(元编程学习实践,c++,rnn,开发语言)