C++元编程——单向深度RNN实现

书接上文，话说两头。通过节点建立 RNN 有点麻烦，现在又搞了一个深度 RNN，就是把多个单层 RNN 堆叠起来。废话不多说，直接上代码：

// One RNN cell packaged as a calculation-chain node.
// forward:  S_t = f(b + U*X + W*S_{t-1});  output = g(V*S_t)
// backward: combines the gradient arriving through the recurrent path (dSt)
//           with the one from the output path, then propagates through b/W/U.
// NOTE(review): the template argument lists (`<...>`) below appear to have
// been stripped by the web/HTML extraction — this block does not compile as
// shown; recover the parameter lists from the original source.
template
struct rnn_node:public cal_chain_node
{
	using inp_type = mat;   // input/hidden matrix type (template args lost in extraction)
	using ret_type = mat;   // output matrix type
	inp_type St;            // hidden state S_t, carried across time steps
	inp_type dSt;           // gradient flowing back into the hidden state
	cal_chain_node_mult W,U;   // W: hidden->hidden weights, U: input->hidden weights
	cal_chain_node_bias b;     // hidden-layer bias
	cal_chain_node_act f;      // hidden-layer activation
	
	cal_chain_node_mult V;     // V: hidden->output weights
	cal_chain_node_act g;      // output activation

	static constexpr int inp_num = inp_num;   // presumably re-exports a template parameter; original name lost in extraction — TODO confirm

	rnn_node()
	{
	}

	// One time step: updates the stored hidden state St in place, returns g(V*St).
	ret_type forward(const inp_type& X)
	{
		St.assign<0,0>(f.forward(b.forward(U.forward(X) + W.forward(St))));
		return g.forward(V.forward(St));
	}

	// Backprop through one time step. The two gradient sources (recurrent dSt
	// and the output-path delta) are averaged before passing through f and b.
	inp_type backward(const ret_type& delta)
	{
		auto delta_before_b = (f.backward(dSt) + f.backward(V.backward(g.backward(delta)))) / 2.;
		auto WU = b.backward(delta_before_b);
		dSt.assign<0,0>(W.backward(WU));   // store recurrent gradient for the previous time step
		return U.backward(WU);             // gradient w.r.t. the input X
	}

	// Apply accumulated gradients on every sub-node.
	void update() 
	{
		W.update();
		U.update();
		b.update();
		f.update();
		V.update();
		g.update();
	}
};

// Recursive case of the chain-type computation: appends one node type to the
// chain built so far (via cal_chain's add_tool metafunction), then recurses
// over the remaining layer sizes.
// NOTE(review): template parameter lists (`<...>`) were stripped by the web
// extraction; parameters are presumably the current chain type, val_t and a
// pack of layer sizes — confirm against the original source.
template
struct rnn_type_cal 
{
	using added_t = typename cur_t::template add_tool::type;        // chain with one more rnn_node appended
	using chain_type = typename rnn_type_cal::chain_type;           // recurse over the remaining sizes
};

// Recursion terminator: only one layer size remains, so after appending this
// last rnn_node the chain type is complete.
// NOTE(review): `<...>` specialization arguments lost in the web extraction.
template
struct rnn_type_cal
{
	using added_t = typename cur_t::template add_tool::type;
	using chain_type = added_t;   // no more layers — this is the finished chain type
};

// Entry point of the type computation: starts a one-node cal_chain and folds
// the remaining layer sizes into it via rnn_type_cal.
// NOTE(review): template parameter lists (`<...>`) lost in the web extraction.
template
struct rnn_type_def 
{
	using first_chain = cal_chain;   // chain holding the first rnn_node
	using chain_type = typename rnn_type_cal< first_chain, val_t, nodes_num...>::chain_type;
};

// Recursive case of the runtime chain builder: allocates one rnn_node per
// layer and installs it into the chain at successive positions (the position
// index is passed as a template argument to gen_rnn_chain).
// NOTE(review): template parameter lists (`<...>`) stripped by the web
// extraction. Nodes are created with `new` and handed over as raw pointers;
// nothing visible here deletes them — ownership is unclear, TODO confirm.
template
struct gen_rnn 
{
	template
	static void gen_rnn_chain(chain_type& chn)
	{
		using cur_node_t = rnn_node;
		cur_node_t* cur_node = new cur_node_t;   // heap-allocated; lifetime presumably managed by the chain — TODO confirm
		chn.set(cur_node);                        // install at the current chain position
		gen_rnn::template gen_rnn_chain(chn);     // recurse for the next layer
	}
};

// Recursion terminator of the chain builder: installs the last rnn_node and
// stops (no further recursive call).
// NOTE(review): `<...>` specialization arguments lost in the web extraction.
template
struct gen_rnn
{
	template
	static void gen_rnn_chain(chain_type& chn)
	{
		using cur_node_t = rnn_node;
		cur_node_t* cur_node = new cur_node_t;   // heap-allocated; ownership unclear from this listing
		chn.set(cur_node);
	}
};

// Public factory: computes the full chain type for the given layer sizes,
// default-constructs it, and fills every slot with a freshly built rnn_node.
// NOTE(review): template parameter lists (`<...>`) lost in the web extraction.
template
typename rnn_type_def::chain_type make_rnn() 
{
	using chain_type = typename rnn_type_def::chain_type;
	chain_type ret;
	gen_rnn::gen_rnn_chain<0>(ret);   // populate slots starting at index 0
	return ret;
}

本文给上一篇文章中的计算链 cal_chain 增加了一个算法（add_tool），用于推导增加元素之后的链类型。计算链的全部代码如下：

#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include "mat.hpp"
#include "base_function.hpp"

// Abstract interface of one computation node: maps an input matrix to an
// output matrix (forward), propagates a gradient backwards (backward), and
// applies accumulated gradients (update).
// NOTE(review): template parameter lists (`<...>`) stripped by the web
// extraction — dimensions and value type were template parameters.
// NOTE(review): this polymorphic base is used with dynamic_cast elsewhere but
// no virtual destructor is visible here; deleting through a base pointer
// would be UB — confirm whether the original source declared one.
template
struct cal_chain_node
{
	using base_type = cal_chain_node;
	using ret_type = mat;   // output matrix type
	using inp_type = mat;   // input matrix type

	virtual ret_type forward(const inp_type& inp) = 0;
	virtual inp_type backward(const ret_type& delta) = 0;
	virtual void update() = 0;
};

// Recursive case of the heterogeneous computation chain: holds the node at
// this position (raw pointer) plus a shared_ptr to the tail chain.
// forward pipes this node's output into the tail; backward runs the tail
// first and then this node.
// NOTE(review): template parameter lists (`<...>`) stripped by the web
// extraction throughout this struct — it does not compile as shown.
template
struct cal_chain 
{
	using type = val_t;
	using cur_type = cal_chain_node;   // interface type of the node stored at this slot
	cur_type* sp_cur_node;             // raw pointer installed via set(); ownership unclear from this listing
	using inp_type = typename cur_type::inp_type;
	using nxt_type = cal_chain;        // tail of the chain (template args lost in extraction)
	std::shared_ptr sp_next;           // handle owning the tail chain
	using ret_type = typename nxt_type::ret_type;   // overall output type = the tail's output type

	cal_chain() 
	{
		sp_next = std::make_shared();   // eagerly construct the tail
	}

	// Runs the whole remaining chain on inp.
	auto forward(const inp_type& inp)
	{
		return sp_next->forward(sp_cur_node->forward(inp));
	}

	// Backpropagates through the tail first, then through this node.
	auto backward(const ret_type& ret) 
	{
		return sp_cur_node->backward(sp_next->backward(ret));
	}

	void update() 
	{
		sp_cur_node->update();
		if (sp_next)
			sp_next->update();
	}

	// Installs sp at chain position `remain` (counted from this link),
	// recursing into the tail until the counter reaches 0.
	template
	cal_chain& set(set_type* sp)
	{
		if constexpr (remain != 0)
		{
			sp_next->set(sp);
		}
		if constexpr (remain == 0)
		{
			sp_cur_node = dynamic_cast(sp);   // cast the node to this slot's interface type
		}
		return *this;
	}

	// Metafunction: the chain type obtained by appending one more node type.
	template
	struct add_tool 
	{
		using type = cal_chain;
	};
};

// Base case of the chain: a single node, no tail. forward/backward/update
// delegate straight to the stored node.
// NOTE(review): `<...>` specialization arguments lost in the web extraction.
template
struct cal_chain
{
	using type = val_t;
	using cur_type = cal_chain_node;   // interface type of the only node
	cur_type* sp_cur_node;             // raw pointer installed via set(); ownership unclear from this listing
	using inp_type = typename cur_type::inp_type;
	using ret_type = typename cur_type::ret_type;

	cal_chain() 
	{
	}

	auto forward(const inp_type& inp)
	{
		return sp_cur_node->forward(inp);
	}

	auto backward(const ret_type& ret)
	{
		return sp_cur_node->backward(ret);
	}

	void update()
	{
		sp_cur_node->update();
	}


	// Terminal set(): the index must have been consumed exactly here.
	template
	cal_chain& set(set_type* sp)
	{
		static_assert(remain == 0, "over over");   // setting past the end of the chain is a compile error
		if constexpr (remain == 0)
			sp_cur_node = dynamic_cast(sp);
		return *this;
	}

	// Metafunction: appending to a one-node chain yields a two-node chain.
	template
	struct add_tool
	{
		using type = cal_chain;
	};
};

// Adapter: wraps a whole chain behind the single-node interface so that a
// chain can itself serve as one node of an outer chain.
// NOTE(review): template parameter lists (`<...>`) lost in the web extraction.
// NOTE(review): stores a non-owning raw pointer — the caller must keep the
// wrapped chain alive for the container's lifetime.
template
struct cal_chain_container :public cal_chain_node
{
	using inp_type = typename chain::inp_type;
	using ret_type = typename chain::ret_type;
	chain*		p;   // borrowed chain, not owned
	cal_chain_container(chain* pv):p(pv)
	{}

	ret_type forward(const inp_type& inp)
	{
		return p->forward(inp);
	}

	inp_type backward(const ret_type& ret)
	{
		return p->backward(ret);
	}

	void update()
	{
		p->update();
	}
};

// Convenience factory: wraps an existing chain in a cal_chain_container
// (returned by value; the container merely borrows the chain pointer).
// NOTE(review): template parameter lists (`<...>`) lost in the web extraction.
template
auto make_chain_node(cal_chain* p)
{
	using chain_type = cal_chain;
	cal_chain_container cc(p);
	return cc;
}

#include "weight_initilizer.hpp"
#include "update_methods.hpp"

// Weight-multiplication node: forward computes W·inp; backward returns
// Wᵀ·delta while folding delta·pre_inpᵀ into a running average of the weight
// gradient; update applies the optimizer step and resets the accumulator.
// NOTE(review): parts of the template parameter list survive below but the
// `<...>` argument lists on the type aliases and base class were stripped by
// the web extraction — this block does not compile as shown.
template class update_method_templ, int inp_row, int inp_col, int ret_row, typename val_t>
struct cal_chain_node_mult:public cal_chain_node
{
	using ret_type = mat;      // presumably mat<ret_row, inp_col, val_t> — args lost in extraction
	using inp_type = mat;      // presumably mat<inp_row, inp_col, val_t>
	using weight_type = mat;   // presumably mat<ret_row, inp_row, val_t>

	weight_type W;        // weight matrix
	weight_type deltaW;   // running-average gradient of W
	inp_type pre_inp;     // input cached by forward() for the gradient product in backward()

	update_method_templ um;   // optimizer, supplied as a template template parameter
	double d_num;             // number of gradient samples folded into deltaW

	cal_chain_node_mult():d_num(0.)
	{
		weight_initilizer::cal(W);   // randomize W (see weight_initilizer.hpp)
	}

	virtual ret_type forward(const inp_type& inp) 
	{
		pre_inp = inp;        // remember input for backward()
		return W.dot(inp);
	}

	// Incremental mean: deltaW <- (deltaW * n + delta·pre_inpᵀ) / (n + 1).
	virtual inp_type backward(const ret_type& delta) 
	{
		auto ret = W.t().dot(delta);
		deltaW = deltaW * d_num + delta.dot(pre_inp.t());
		d_num = d_num + 1.;
		if (d_num > 1e-7) 
		{
			deltaW = deltaW / d_num;
		}
		return ret;
	}

	// Optimizer step on W, then clear the gradient accumulator.
	virtual void update() 
	{
		W.assign<0, 0>(um.update(W, deltaW));
		deltaW = 0.;
		d_num = 0.;
	}
};

// Bias node: forward adds the bias vector b; backward passes delta through
// unchanged while folding it into a running average of the bias gradient;
// update applies the optimizer step and resets the accumulator.
// NOTE(review): `<...>` argument lists on the aliases and base class were
// stripped by the web extraction — this block does not compile as shown.
template class update_method_templ, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_bias :public cal_chain_node
{
	using ret_type = mat;   // presumably mat<inp_row, inp_col, val_t> — args lost in extraction
	using inp_type = mat;
	update_method_templ um;   // optimizer, supplied as a template template parameter

	inp_type b;        // bias
	inp_type deltab;   // running-average gradient of b
	double d_num;      // number of gradient samples folded into deltab
	cal_chain_node_bias():d_num(0.)
	{}

	virtual ret_type forward(const inp_type& inp)
	{
		return b + (inp);
	}

	// Incremental mean: deltab <- (deltab * n + delta) / (n + 1); delta is
	// forwarded unchanged because d(b + x)/dx = I.
	virtual inp_type backward(const ret_type& delta)
	{
		deltab = deltab * d_num + delta;
		d_num = d_num + 1.;
		if (d_num > 1e-7)
		{
			deltab = deltab / d_num;
		}
		return delta;
	}

	// Optimizer step on b, then clear the gradient accumulator.
	virtual void update()
	{
		b.assign<0, 0>(um.update(b, deltab));
		deltab = 0.;
		d_num = 0.;
	}
};

#include "activate_function.hpp"
// Activation node: forward applies the activation function; backward scales
// the incoming delta by the activation derivative (act_fun.backward()
// presumably returns the derivative at the last forward input — TODO
// confirm against activate_function.hpp). No trainable parameters, so
// update() is a no-op.
// NOTE(review): `<...>` argument lists on the aliases and base class were
// stripped by the web extraction — this block does not compile as shown.
template class activate_func, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_act :public cal_chain_node
{
	using ret_type = mat;   // presumably mat<inp_row, inp_col, val_t> — args lost in extraction
	using inp_type = mat;
	activate_func act_fun;  // stateful activation functor

	virtual ret_type forward(const inp_type& inp)
	{
		return act_fun.forward(inp);
	}

	virtual inp_type backward(const ret_type& delta)
	{
		return act_fun.backward() * delta;   // chain rule: derivative times upstream gradient
	}

	virtual void update()
	{
	}
};

#endif

最后运算结果显示,这个多层的RNN也是有学习能力的。由于没有足够多的有关联数据(也比较懒,不想搞),所以就没正经试验过。有兴趣的小朋友们可以用自己的数据试一试。

你可能感兴趣的:(元编程学习实践,c++,rnn,开发语言)