基于三层BP神经网络的人脸识别

实验四、基于三层BP神经网络的人脸识别
一、 实验要求
采用三层前馈BP神经网络实现标准人脸YALE数据库的识别,编程语言为C系列语言。
二、BP神经网络的结构和学习算法
实验中建议采用如下最简单的三层BP神经网络,输入层为 ,有n个神经元节点,输出层具有m个神经元,网络输出为 ,隐含层具有k个神经元,采用BP学习算法训练神经网络。

BP神经网络的结构
BP网络在本质上是一种输入到输出的映射,它能够学习大量的输入与输出之间的映射关系,而不需要任何输入和输出之间的精确的数学表达式,只要用已知的模式对BP网络加以训练,网络就具有输入输出对之间的映射能力。
BP网络执行的是有教师训练,其样本集是由形如(输入向量,期望输出向量)的向量对构成的。在开始训练前,所有的权值和阈值都应该用一些不同的小随机数进行初始化。

基于三层BP神经网络的人脸识别_第1张图片

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Scanner;

import javax.imageio.ImageIO;

public class ANN {
	
	static int B = 16;
//	输入层神经元个数
	static int N = 2*B*B*2;
//	隐含层神经元个数
	static int L = B*B;
//	输出层神经元个数
	static int M = 16;
//	进度控制产生
	static double yiphxir = 0.001;
//	学习率
	static double arph = 0.2;
	static double a = -1.0;
	static double max = 255;
//	输入层到隐含层的权值
	static double[][] V = new double[N][L];
//	隐含层到输出层的权值
	static double[][] W = new double[L][M];
//	隐含层的阈值
	static double[] fai = new double[L];
//	输出层的阈值
	static double[] sita = new double[M];
	
	static double[][] deltaW = new double[L][M];
	static double[] deltaSita = new double[M];
	
	static double[][] deltaV = new double[N][L];
	static double[] deltaFai = new double[L];
	static double[] FE = new double[M];
	static double[] getImagePixel(String image) {
        File file = new File(image);
        BufferedImage bi = null;
        try {
            bi = ImageIO.read(file);
        } catch (IOException e) {
        
            e.printStackTrace();
        }
        int width = bi.getWidth()/2;
        int height = bi.getHeight()/2;
        double[] pixel = new double[N];
        int cnt = 0;
    	for(int y = height-B; y < height+B; y++) {
            for(int x = width-B; x < width+B; x++) {
            	int t = bi.getRGB(x, y);
            	t = (t & 0xff);
            	pixel[cnt] = t/max;
                cnt++;
            }
        }
        return pixel;
    }
	
	static double FS(double x){
		return 1.0/(1.0+Math.exp(a*x)); 
	}
	
	static void init(){
		for(int i=0; i<N; i++){
			for(int j=0; j<L; j++){
				V[i][j] = Math.random()*2.0-1.0;
			}
		}
		for(int j=0; j<L; j++){
			for(int k=1; k<M; k++){
				W[j][k] = Math.random()*2.0-1.0;
			}
		}
		for(int j=0; j<L; j++){
			fai[j] = Math.random()*2.0-1.0;
		}
		for(int k=1; k<M; k++){
			sita[k] = Math.random()*2.0-1.0;
		}
	}
	
	static boolean trainBP(double[] X, int[] D){
		double[] H = new double[L];
		for(int j=0; j<L; j++){
			double derta = 0.0;
			for(int i=0; i<N; i++){
				derta += X[i]*V[i][j];
			}
//			System.out.print("derta"+Double.toString(derta));
//			System.out.println("    derta-fai="+Double.toString(derta-fai[j]));
			H[j] = FS(derta+fai[j]);
//			System.out.println(Double.toString(H[j])+" ");
		}
		double[] Y = new double[M];
		for(int k=1; k<M; k++){
			double derta = 0.0;
			for(int j=0; j<L; j++){
				derta += H[j]*W[j][k];
			}
//			System.out.print("derta"+Double.toString(derta));
//			System.out.println("    derta-sita="+Double.toString(derta-sita[k]));
			Y[k] = FS(derta+sita[k]);
//			System.out.print(Double.toString(Y[k])+" ");
		}
//		System.out.println();
		
		double[] deltaK = new double[M];
		double E = 0.0;
		for(int k=1; k<M; k++){
			E += (D[k] - Y[k])*(D[k] - Y[k]);
			deltaK[k] = (D[k]-Y[k])*Y[k]*(1-Y[k]);
		}
//		System.out.print("E=");
//		System.out.println(E/2);
		if(E/2 < yiphxir) return true;
		
		double[] deltaJ = new double[L];
		for(int j=0; j<L; j++){
			double beta = 0.0;
			for(int k=1; k<M; k++){
				beta += deltaK[k]*W[j][k];
			}
			deltaJ[j] = H[j]*(1-H[j])*beta;
		}
		
		for(int j=0; j<L; j++){
			for(int k=1; k<M; k++){
				deltaW[j][k] = arph*deltaK[k]*H[j];
//				deltaW[j][k] = (arph/(1+L))*(deltaW[j][k]+1)*deltaK[k]*H[j];
				W[j][k] += deltaW[j][k];
			}
		}
		for(int k=1; k<M; k++){
			deltaSita[k] = arph*deltaK[k];
//			deltaSita[k] = (arph/(1+L))*(detaSitla[k]+1)*deltaK[k];
			sita[k] += deltaSita[k];
		}
		
		for(int i=0; i<N; i++){
			for(int j=0; j<L; j++){
				deltaV[i][j] = arph*deltaJ[j]*X[i];
//				deltaV[i][j] = (arph/(1+N))*(deltaV[i][j]+1)*deltaJ[j]*X[i];
				V[i][j] += deltaV[i][j];
			}
		}
		for(int j=0; j<L; j++){
			deltaFai[j] = arph*deltaJ[j];
//			deltaFai[j] = (arph/(1+N))*(deltaFai[j]+1)*deltaJ[j];
			fai[j] += deltaFai[j];
		}
		
		return false;
	}
	
	static int BP(double[] X){
		int ans = 0;
		double[] H = new double[L];
		for(int j=0; j<L; j++){
			double derta = 0.0;
			for(int i=0; i<N; i++){
				derta += X[i]*V[i][j];
			}
			H[j] = FS(derta-fai[j]);
		}
		double[] Y = new double[M];
		for(int k=1; k<M; k++){
			double derta = 0.0;
			for(int j=0; j<L; j++){
				derta += H[j]*W[j][k];
			}
			Y[k] = FS(derta-sita[k]);
		}
		
		double min = Double.MAX_VALUE;
		
		for(int n=1; n<M; n++){
			double E = 0.0;
			int[] D = new int[M];
			D[n] = 1;
			for(int k=1; k<M; k++){
				E += Math.abs((D[k] - Y[k])*(D[k] - Y[k]));
			}
			FE[n] = 1.0/Math.exp(E/2.0);
			if(E/2 < min){
				min = E/2;
				ans = n;
			}
		}
		return ans;
	}
	
	@SuppressWarnings("resource")
	public static void main(String[] args) {
		init();
		int temp = 0;
		while(temp < 100){
			int t = 1;
			while(t <= 11){
				int ca = 1;
				while(ca < M){
//					System.out.println(t);
//					System.out.println(ca);
					String imagePath;
					if(ca >= 10){
//						imagePath = "C:\\Users\\DELL\\Downloads\\AAII\\src\\image\\subject01_10.bmp";
						imagePath = "C:\\Users\\DELL\\Downloads\\AAII\\src\\image\\"+"subject"+Integer.toString(ca)+"_"+Integer.toString(t)+".bmp";
					}
					else imagePath = "C:\\Users\\DELL\\Downloads\\AAII\\src\\image\\"+"subject0"+Integer.toString(ca)+"_"+Integer.toString(t)+".bmp";
					double[] X = getImagePixel(imagePath);
					int[] D = new int[M];
					D[ca] = 1;
					boolean flag = trainBP(X, D);
					if(flag) break;
					ca++;
				}
				t++;
			}
			temp++;
		}
		Scanner in = new Scanner(System.in);
		while(true){
			System.out.println("请输入subject(t1)和_(t2):");
			String t1 = in.next();
			String t2 = in.next();
			String imagePath = "C:\\Users\\DELL\\Downloads\\AAII\\src\\image\\"+"subject"+t1+"_"+t2+".bmp";


			int ans = BP(X);
			System.out.println("(subject01~subject15)的匹配率如下:");
			for(int i=1; i<M; i++){
				System.out.println("subject"+Integer.toString(i)+":"+String.format("%g", FE[i]*100)+"%");
			}
			System.out.println();
			System.out.print("综合结果为subject");
			System.out.println(ans);
			System.out.println();
		}
	}
}

图片地址:https://download.csdn.net/download/qq_43179428/19880190

你可能感兴趣的:(随笔)