【deep learning学习笔记】注释yusugomori的DA代码 --- dA.h

DA就是“Denoising Autoencoders”的缩写。继续给yusugomori做注释,边注释边学习。看了一些DA的材料,基本上都在前面“转载”了。学习中间总有个疑问:DA和RBM到底啥区别?(别笑,我不是“学院派”的看Deep Learning理论,如果“顺次”看下来,可能不会有这个问题),现在了解的差不多了,详情见:【deep learning学习笔记】Autoencoder。之后,又有个疑问,DA具体的权重更新公式是怎么推导出来的?我知道是BP算法,不过具体公示的推导、偏导数的求解,没有看到哪个材料有具体的公式,所以姑且认为yusugomori的代码写的是正确的。


注释后的头文件:

 

// The Class of denoising auto-encoder
class dA 
{
public:
	int N;			// the number of training samples
	int n_visible;	// the number of visible nodes
	int n_hidden;	// the number of hidden nodes
	double **W;		// the weight connecting visible node and hidden node
	double *hbias;	// the bias of hidden nodes
	double *vbias;	// the bias of visible nodes

public:
	// initialize the parameters
	dA ( int,		// N
		 int,		// n_visible
		 int ,		// n_hidden
		 double**,	// W
		 double*,	// hbias
		 double*	// vbias
		 );
	~dA();

	// make the input noised
	void get_corrupted_input (
				int*,		// the original input 0-1 vector			-- input
				int*,		// the resulted 0-1 vector gotten noised	-- output
				double		// the p probability of noise, binomial test -- input
				);
	// encode process: calculate the probability output from hidden node
	// p(hi|v) = sigmod ( sum_j(vj * wij) + bi), it's same with RBM
	// but different from RBM, it dose not generate 0-1 state from Bernoulli distribution
	void get_hidden_values (
				int*,		// the input from visible nodes
				double*		// the output of hidden nodes
				);
	// decode process: calculate the probability output from visiable node
	// p(vi|h) = sigmod ( sum_j(hj * wij) + ci), it's same with RBM
	// but different from RBM, it dose not generate 0-1 state from Bernoulli distribution 
	void get_reconstructed_input (
				double*,	// the input from hidden nodes
				double*		// the output reconstructed of visible nodes
				);
	// train the model by a single sample
	void train (
				int*,		// the input sample from visiable node
				double,		// the learning rate
				double		// corruption_level is the probability of noise
				);
	// reconstruct the input sample
	void reconstruct (
				int*,		// the input sample		-- input
				double*		// the reconstructed value -- output
				);
};


 

 

你可能感兴趣的:(学习笔记)