RBM 推荐系统 Java代码(质量堪忧,仅供参考,欢迎讨论)

这是一段代码段,用Java写的,关于RBM ,也关于推荐系统。效率比较低,代码也有点问题。贴出来,仅仅是为了给大家提供一点思路,也是希望大家多多指教。仅仅贴出最重要的代码块,完整的代码,大家可以给我留言,我抽出时间来再来给大家发一份。抛砖引玉,希望能有同学或工程师来一起交流,大家共同进步,共同提高。


package rbm_1th;
import java.awt.List;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import static rbm_1th.utils.*;
import static rbm_1th.Get_data.*;
public class RBM
{
    public int N;
    public int n_visible;
    public int n_hidden;
    public double[][] W;
    public double[] hbias;
    public double[] vbias;
    public Random rng;
    //RBM的构造函数
    public RBM(int N,int n_visiable,int n_hidden,double[][] W,double[] hbias,double[] vbias,Random rng)
    {
        this.N = N;
        this.n_visible = n_visiable;
        this.n_hidden = n_hidden;
        
        if (rng == null)
        {
            this.rng = new Random(1234);
        }
        else{
            this.rng = rng;
        }

        if (W == null){
            this.W =new double[this.n_hidden][this.n_visible];
            double a = 1.0/this.n_visible;

            for (int i = 0; i < this.n_hidden; i++)
            {
                for (int j = 0; j< this.n_visible ; j++ )
                {
                    this.W[i][j] = uniform(-a,a,rng);
                }
            }
        }else
        {
            this.W = W;
        }
        
        if(hbias == null){
            this.hbias = new double[this.n_hidden];
            for (int i = 0;i input,double lr,int k)
    {
        
        double[] ph_mean = new double[n_hidden];
        int[] ph_sample = new int[n_hidden];
        double[] nv_means = new double[n_visible];
        Map nv_samples = new HashMap<>();
        double[] nh_means = new double[n_hidden];
        int[] nh_samples = new int[n_hidden];
        
        sample_h_given_v(input,ph_mean,ph_sample);
        for (int step = 0; step < k;step++)
        {
            if (step == 0)
            {
                gibbs_hvh(ph_sample,nv_means,nv_samples,nh_means,nh_samples);

            }else {
                gibbs_hvh(nh_samples,nv_means,nv_samples,nh_means,nh_samples);
            }
        }
        /**这块儿代码可能有点问题,即在推荐系统里面是应该对所有的权重都改变还是只改变有过用户行为的显层和隐层的连接权重,这会儿脑子太乱,不想想了,欢迎大家拍砖指正**/
        for (int i=0; i v,double[] w,double b)
    {
        double pre_sigmoid_activation = 0.0;
        for (int j :v.keySet() )
        {
            pre_sigmoid_activation += w[j] * v.get(j);
        }
        pre_sigmoid_activation += b;
        return sigmoid(pre_sigmoid_activation);
    }
    
    //隐藏层到可见层
    public double propdown(int[] h,int j,double b)
    {
        double pre_sigmoid_activation = 0.0;
        for (int i =0; i < h.length ; i++ )
        {
            pre_sigmoid_activation += W[i][j] * h[i];
        }
        pre_sigmoid_activation += b;
        
        return sigmoid(pre_sigmoid_activation);
    }
    
    //吉布斯
    public void gibbs_hvh(int[] h0_sample,double[] nv_means,Map nv_samples,double[] nh_means,int[] nh_samples)
    {
        sample_v_given_h(h0_sample,nv_means,nv_samples);
        sample_h_given_v(nv_samples,nh_means,nh_samples);
    }
    
    //sample given hidden get visible
    public void sample_v_given_h(int[] h0_sample,double[] mean,Map sample)
    {
        for (int j:sample.keySet() )
        {
            mean[j] = propdown(h0_sample, j,vbias[j]);
            int oz = binomial(1,mean[j],rng);
            sample.put(j, oz);
        }
    }

    //given visible get hidden
    public void sample_h_given_v(Map v0_sample,double[] mean , int[] sample)
    {
        for (int i = 0;i v,double[] reconstructed_v)
    {
        double[] h = new double[n_hidden];
        double pre_sigmoid_activation;

        for (int i = 0;i> u_p_s = loadData(inpath);
        Map pid2num = get_pid2num(u_p_s);
        Map u_means = getU_mean(u_p_s);


        int train_N = u_p_s.size();
        
        int n_visible = pid2num.size();
        int n_hidden = 2;
        RBM rbm = new RBM(train_N,n_visible,n_hidden,null,null,null,rng);
        ArrayList uid_list = new ArrayList(u_p_s.keySet());

    
        for(int epoch = 0;epoch < training_epochs;epoch++)
        {
            for (int i =0 ; i  p_s = u_p_s.get(uid);
                
                double u_m = u_means.get(uid);

                Map pid_oz = new HashMap();

                for (String pid:p_s.keySet())
                {
                    Double score = u_p_s.get(uid).get(pid);
                    int x = pid2num.get(pid);
                    if (score > u_m)
                    {
                        pid_oz.put(x,1);
                    }
                    else
                    {
                        pid_oz.put(x,0);
                    }

                }
                rbm.contrastive_divergence(pid_oz,learning_rate,k);//注意此处
            }
        }
        
        return rbm.W;
        
        
    }
    
    

    public static void main(String[] args)
    {
        double[][] w =train_rbm();
        
    }
    
}



你可能感兴趣的:(机器学习)