C++元编程——池化层实现

池化层正向都是求区域最大值,反向有两种方法:1、记住最大值位置,将回传值赋值到该位置;2、平均值平均分配到区域中。下面就是实现啦,实现了这两种方法:

#ifndef _POOL_LAYER_HPP_
#define _POOL_LAYER_HPP_
#include "base_logic.hpp"

/* 池化层,对区域内取最大值 */

template class update_method_tpl, int input_row, int input_col, int tpl_row, int tpl_col, typename val_t = double>
struct pool_layer 
{
	using input_type = mat;
	using ret_type = mat;
	update_method_tpl	um;			// 更新算法

	ret_type forward(const input_type& mt) 
	{
		return um.forward(mt);
	}

	input_type backward(const ret_type& mt)
	{
		return um.backward(mt);
	}
};

template
struct region_max_cal
{
	template
	inline static val_t cal(const mat& mt, omat_p_t& p_omt_r, omat_p_t& p_omt_c)
	{
		int i_max_r = 0, i_max_c = 0;
		val_t ret = mt.region_max(i_max_r, i_max_c);
		(*p_omt_r).get_val() = (i_max_r);
		(*p_omt_c).get_val() = (i_max_c);
		return ret;
	}
};

template
struct pool_layer_max 
{
	using input_type = mat;
	using ret_type = mat;

	template
	struct region_max_local
	{
		template
		inline static val_t cal(const mat& mt, omat_p_t& p_omt_r, omat_p_t& p_omt_c)
		{
			return region_max_cal::cal(mt, p_omt_r, p_omt_c);
		}
	};

	mat mt_max_row, mt_max_col;

	ret_type forward(const input_type& mt)
	{
		ret_type ret;
		col_loop(ret, mt, &mt_max_row, &mt_max_col);
		return ret;
	}

	input_type backward(const ret_type& mt) 
	{
		input_type ret;
		for (int r = 0; r < ret_type::r; ++r)
		{
			for (int c = 0; c < ret_type::c; ++c) 
			{
				ret.get(mt_max_row.get(r, c), mt_max_col.get(r, c)) = mt.get(r, c);
			}
		}
		return ret;
	}
};

template
struct pool_layer_average
{
	using input_type = mat;
	using ret_type = mat;

	template
	struct region_max_local
	{
		template
		inline static val_t cal(const mat& mt, omat_p_t& p_omt_r, omat_p_t& p_omt_c)
		{
			return region_max_cal::cal(mt, p_omt_r, p_omt_c);
		}
	};

	ret_type forward(const input_type& mt)
	{
		ret_type ret;
		mat mt_max_row, mt_max_col;
		col_loop(ret, mt, &mt_max_row, &mt_max_col);
		return ret;
	}

	input_type backward(const ret_type& mt)
	{
		input_type ret;
		for (int r = 0; r < input_type::r; ++r)
		{
			for (int c = 0; c < input_type::c; ++c)
			{
				ret.get(r, c) = mt.get(r/tpl_row, c/tpl_col)/(tpl_row*tpl_col);
			}
		}
		return ret;
	}
};

#endif

你可能感兴趣的:(元编程学习实践,c++,算法,开发语言)