基于多层感知机的手写数字识别,迭代10次后对训练集的正确率为97%
Main函数,在绘制完数字后,要点下确定按钮再去识别,重绘按钮自然是再次绘图
训练自己的网络结构会替换之前训练的网络结构,没有写保存或者另存新网络模型的功能。结果在训练集上表现很好,但对绘图的识别结果仍不是很理想。
package main;
import java.awt.Color;
import java.awt.Container;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Scanner;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextField;
import javax.swing.JTextPane;
import javax.swing.text.Style;
import javax.swing.text.StyleConstants;
import javax.swing.text.StyleContext;
import imageprocess.getimage;
import network.NetWork;
import network.traindata;
public class GUI extends JFrame {
private JFrame jFrame;
private BufferedImage img;//用于显示输入图片
private JButton sure;//手写输入确定
private JButton cancel;//手写输入确定
private JButton recognition;//识别
private JButton train;//训练自己的网络
private JButton open;
private JTextField result;
private int[][] getmatrix=new int[28][28];
private JTextPane imgtextarea;
private JLabel imglabel;
private static NetWork neunet;
private JFileChooser choose;//选择文件
//private int k=0;
public GUI() {
neunet=new NetWork(10,0.01,50,0.2,0.5);
neunet.initNodes();
jFrame=new JFrame("数字识别");
jFrame.setBounds(0, 0, 765, 800);
jFrame.setLayout(null);
recognition=new JButton("识别结果");
train=new JButton("训练");
open=new JButton("打开图片");
sure=new JButton("确定");
cancel=new JButton("重绘");
JPanel resultpanel = new JPanel();
final mypanel panel = new mypanel();//新建画板
Container contentPane = getContentPane();
contentPane.setBounds(0, 0,350,350);
contentPane.add(panel);
jFrame.add(contentPane);
JPanel draw=new JPanel();//画板桌布
draw.setBounds(0, 0, 380,420);
draw.setLayout(null);
draw.setBackground(Color.lightGray);
jFrame.add(draw);
draw.add(sure);
sure.setBounds(10, 370, 60, 30);
draw.add(cancel);
cancel.setBounds(130, 370, 60, 30);
draw.add(train);
train.setBounds(250, 370, 60, 30);
open.setBounds(420, 320, 90, 30);
recognition.setBounds(560, 320, 90, 30);
imgtextarea=new JTextPane();
Style style=new StyleContext().new NamedStyle();
StyleConstants.setLineSpacing(style,-0.1f);
StyleConstants.setFontSize(style, 7);
StyleConstants.setBold(style, true);
imgtextarea.setLogicalStyle(style);
imglabel=new JLabel();
imgtextarea.setSize(50, 70);
imgtextarea.setBounds(300, 0, 200, 60);
imglabel.setBounds(200,10,100,100);
imgtextarea.setEditable(false);
choose = new JFileChooser();
choose.setCurrentDirectory(new File("."));
resultpanel.add(imglabel);
resultpanel.add(imgtextarea);
result=new JTextField();
result.setBounds(560, 360, 140, 50);
result.setVisible(true);
resultpanel.setBounds(381, 0, 350, 310);
resultpanel.setBackground(Color.gray);
jFrame.add(resultpanel);
jFrame.add(result);
jFrame.add(recognition);
jFrame.add(open);
jFrame.setSize(761, 450);
jFrame.setVisible(true);
sure.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent actionevent) {
// TODO Auto-generated method stub
BufferedImage image=new BufferedImage(panel.getWidth(), panel.getHeight(), BufferedImage.TYPE_INT_RGB);
Graphics gs=image.getGraphics();
panel.paintAll(gs);
gs.drawImage(image, 0, 0, panel.getWidth(), panel.getHeight(), null);
try {
ImageIO.write(image, "png", new File("./save.jpg"));
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
try {
getmatrix=getimage.getMatirx(image);
if(image.getHeight()>28||image.getWidth()>28){
image=getimage.scale(image, 28, 28);}
imglabel.setIcon(new ImageIcon((Image)image));//把图片作为icon显示
for (int i = 0; i < getmatrix.length-1; i++) {
for (int j = i+1; j < getmatrix.length; j++) {
int temp=getmatrix[i][j];
getmatrix[i][j]=getmatrix[j][i];
getmatrix[j][i]=temp;
}
}
imgtextarea.setText("");
String s="";
for(int i=0;i28||img.getHeight()>28)
{img=getimage.scale(img, 28, 28);//缩放图片
}
imglabel.setIcon(new ImageIcon((Image)img));//把图片作为icon显示
int a[]=traindata.SampleMatirx(img);
for (int i = 0; i < a.length; i++) {
getmatrix[i/28][i%28]=a[i];
}
String s=" ";
for(int i=0;i
获取数据集二值化矩阵保存用于训练,使用了40000张图像数据作为训练集
package network;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import javax.imageio.ImageIO;
import imageprocess.binaryimage;
import imageprocess.trainbrainimage;
public class traindata {
private static int sampleNumber=40000;
public static void imagetomatrix() throws IOException{
int input[][]=new int[sampleNumber][784];
try {
for(int i=0;i<=9;i++) {
for(int j=0;j<=3999;j++) {
BufferedImage image;
image=ImageIO.read(new File("./mnist/mnist_data/"+i+"."+j+".jpg"));
input[j+i*4000]=SampleMatirx(image);
}
}
File f=new File("./allInput.txt");
if(!f.exists()){
f.createNewFile();}
else{
FileOutputStream opf=new FileOutputStream("./allInput.txt");
PrintStream s=new PrintStream(opf);
for(int i=0;i
图像灰度二值化处理:
package imageprocess;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.IOException;
/**
 * Converts an image to gray levels and binarises it with a fixed threshold.
 *
 * <p>After {@link #brmatrix(BufferedImage)} returns, {@link #brimage} holds the
 * 0/1 matrix (indexed {@code [x][y]}) and {@link #image} holds the matching
 * black-and-white picture.
 */
public class trainbrainimage {
    /** A pixel becomes 1 when the mean gray level of its 3x3 neighbourhood exceeds this value. */
    private static final int THRESHOLD = 125;

    private int gray[][]=null;// gray level of every pixel, indexed [x][y]
    public int brimage[][]=null;// binarised matrix (0 or 1), indexed [x][y]
    public BufferedImage image;// binarised picture produced by the last brmatrix call

    /**
     * Binarises {@code bi} and stores the result in {@link #brimage} and {@link #image}.
     *
     * @param bi source image; it is only read
     * @throws IOException never thrown by this implementation; kept in the
     *         signature so existing callers keep compiling
     */
    public void brmatrix(BufferedImage bi) throws IOException {
        int h = bi.getHeight();
        int w = bi.getWidth();
        gray = new int[w][h];
        brimage = new int[w][h];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                gray[x][y] = getGray(bi.getRGB(x, y));
            }
        }
        BufferedImage nbi = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_BINARY);
        final int white = Color.WHITE.getRGB();
        final int black = Color.BLACK.getRGB();
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                // NOTE(review): this loop header and condition were reconstructed from a
                // garbled listing; the smoothed value is compared against the threshold.
                boolean on = getAverageColor(gray, x, y, w, h) > THRESHOLD;
                nbi.setRGB(x, y, on ? white : black);
                brimage[x][y] = on ? 1 : 0;
            }
        }
        this.image = nbi;
    }

    /**
     * Returns the plain average of the red, green and blue channels of a packed pixel.
     * The channels are extracted with bit operations, so any alpha value is accepted.
     */
    private int getGray(int rgb){
        int r = (rgb >> 16) & 0xFF;
        int g = (rgb >> 8) & 0xFF;
        int b = rgb & 0xFF;
        return (r + g + b) / 3;
    }

    /**
     * Returns the sum of the 3x3 neighbourhood around {@code (x, y)} divided by 9.
     * Neighbours outside the {@code w} x {@code h} grid contribute 0, so border
     * pixels are biased towards black.
     */
    private int getAverageColor(int[][] gray, int x, int y, int w, int h)
    {
        int sum = 0;
        for (int nx = x - 1; nx <= x + 1; nx++) {
            if (nx < 0 || nx >= w) {
                continue;
            }
            for (int ny = y - 1; ny <= y + 1; ny++) {
                if (ny < 0 || ny >= h) {
                    continue;
                }
                sum += gray[nx][ny];
            }
        }
        return sum / 9;
    }
}
图像数据按数字分类存放,每4000张为一类(第1-4000张为数字0,第4001-8000张为数字1,依此类推),同类样本在数据集中连续排列,所以训练时要打乱样本。网络结构代码:
package network;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Scanner;
public class NetWork {
private int iteration; //迭代次数
public static int[][] allInput;//记录所有训练样本的所有矩阵//测试集
private double stepsize; //移动步长学习率
private double weighRange; //用于规范初始化权值
private double momentum; //动量调节因子
private int inputsize=784; //输入点值
private int hinddensize=50;//隐层节点值
private int outputsize=10; //输出节点个数
private int[] inputnode;
private node[] hiddennode;
private node[] outputnode;
//权值大小及更新时所用
private double [] hinddenDelta;
private double [] outputDelta;
private double [][]inputweight;
private double [][]oldInputeight;
private double [][]outputweight;
private double [][]oldoutputweight;
private int success=0;
private double error=0;
private double errorrate;
private double successrate;
private char []type; //保存测试数据的输出类型
//初始化构造
/**
 * Stores the training hyper-parameters and allocates the node, delta and
 * weight arrays. Nothing is initialised with values here; the {@code type}
 * array is not allocated by this constructor.
 *
 * @param iteration   number of training iterations
 * @param stepsize    step size (learning rate)
 * @param hinddensize number of hidden-layer nodes
 * @param weighRange  range used to normalise the initial weights
 * @param momentum    momentum factor
 */
public NetWork(int iteration,double stepsize,int hinddensize,double weighRange,double momentum) {
    // Hyper-parameters.
    this.momentum = momentum;
    this.weighRange = weighRange;
    this.hinddensize = hinddensize;
    this.stepsize = stepsize;
    this.iteration = iteration;

    // Layers: raw inputs, then hidden and output nodes.
    this.inputnode = new int[this.inputsize];
    this.hiddennode = new node[hinddensize];
    this.outputnode = new node[this.outputsize];

    // Per-node deltas.
    this.hinddenDelta = new double[hinddensize];
    this.outputDelta = new double[this.outputsize];

    // Current and previous weights, input->hidden and hidden->output.
    this.inputweight = new double[this.inputsize][hinddensize];
    this.oldInputeight = new double[this.inputsize][hinddensize];
    this.outputweight = new double[hinddensize][this.outputsize];
    this.oldoutputweight = new double[hinddensize][this.outputsize];
}
/** Returns the stored success rate ({@code successrate}). */
public double getsuccess() {
    return this.successrate;
}
/** Returns the stored error rate ({@code errorrate}). */
public double geterror() {
    return this.errorrate;
}
/**
 * Returns the {@code type} array itself (no copy).
 * NOTE(review): the constructor does not allocate it, so this may be null — confirm.
 */
public char[] gettype() {
    return this.type;
}
/** Prepares the network for training: initialises the nodes first, then the weights. */
public void initNetwork() {
    this.initNodes();
    this.initWeights(this.weighRange);
}
private void initWeights(double weighRange2) {
// TODO Auto-generated method stub
for(int i=0;iarrayList=new ArrayList<>();
for(int i=0;i<40000;i++)
arrayList.add(i);
Collections.shuffle(arrayList);
NetWork.getAllInput();
while(iteration>0) {
for(int i=0;i b[maxIndex]){
maxIndex = i;
}
}
return maxIndex;
}
public String recognize(int[] a) throws FileNotFoundException {
// TODO Auto-generated method stub
for (int i = 0; i < a.length; i++) {
inputnode[i]=a[i];
}
File file = new File("./height.txt");
Scanner in = new Scanner(file);
for(int i=0;i
节点:
package network;
/**
 * A single network node: a mutable holder for an activation value, a bias and
 * the previously stored bias.
 */
public class node {
    private double activation;
    private double bias;
    private double oldbais;

    /**
     * @param a initial activation
     * @param b initial bias
     */
    public node(double a,double b) {
        activation = a;
        bias = b;
    }

    /** Returns the current activation. */
    public double getActivation() {
        return activation;
    }

    /** Replaces the current activation. */
    public void setActivation(double activation) {
        this.activation = activation;
    }

    /** Returns the current bias. */
    public double getBias() {
        return bias;
    }

    /** Replaces the current bias. */
    public void setBias(double bias) {
        this.bias = bias;
    }

    /** Returns the stored old bias (0.0 until set). */
    public double getOldbais() {
        return oldbais;
    }

    /** Replaces the stored old bias. */
    public void setOldbais(double oldbais) {
        this.oldbais = oldbais;
    }

    /** Formats the node as the activation and the bias separated by one space. */
    @Override
    public String toString() {
        return Double.toString(activation) + " " + Double.toString(bias);
    }
}
绘制待识别数字界面:
package main;
import java.awt.*;
import java.awt.event.*;
import java.util.Vector;
import javax.swing.JPanel;
public class mypanel extends JPanel {
private static final long serialVersionUID = 1L;
private Vector> FreedomDatas = new Vector>();
private Color lineColor = Color.white;
private int lineWidth = 16;
/**
 * Creates the drawing panel and installs the mouse handlers that record
 * free-hand strokes: a press starts a new stroke, a drag appends points to the
 * most recent stroke.
 *
 * NOTE(review): assumes FreedomDatas is a Vector<Vector<Point>>; its declaration
 * is garbled in this listing — confirm.
 */
public mypanel()
{
    addMouseListener(new MouseAdapter()
    {
        @Override
        public void mousePressed(MouseEvent e)
        {
            // Every press opens a new stroke seeded with the press position.
            Vector<Point> newLine = new Vector<Point>();
            newLine.add(new Point(e.getX(), e.getY()));
            FreedomDatas.add(newLine);
        }

        @Override
        public void mouseReleased(MouseEvent e)
        {
            repaint();
        }
    });
    addMouseMotionListener(new MouseMotionAdapter()
    {
        @Override
        public void mouseDragged(MouseEvent e)
        {
            // No stroke to extend (e.g. the panel was cleared mid-drag): ignore
            // instead of indexing element -1.
            if (FreedomDatas.isEmpty()) {
                return;
            }
            Vector<Point> lastLine = FreedomDatas.get(FreedomDatas.size() - 1);
            lastLine.add(new Point(e.getX(), e.getY()));
            // Repaint while dragging so the stroke is visible as it is drawn,
            // not only after the button is released.
            repaint();
        }
    });
}
/** Discards every recorded stroke and requests a repaint of the panel. */
public void cleanAll()
{
    this.FreedomDatas.clear();
    this.repaint();
}
public void paint(Graphics g)
{
g.fillRect(0, 0, getWidth(), getHeight());
g.setColor(lineColor);
Graphics2D g_2D = (Graphics2D)g;
BasicStroke stroke = new BasicStroke(lineWidth,BasicStroke.CAP_ROUND,BasicStroke.JOIN_ROUND);
g_2D.setStroke(stroke);
Vector v;
Point s,e;
int i,j,m;
int n = FreedomDatas.size();
for(i=0;i
抓取绘图界面的数字图像并处理二值化
package imageprocess;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.GraphicsConfiguration;
import java.awt.GraphicsDevice;
import java.awt.GraphicsEnvironment;
import java.awt.HeadlessException;
import java.awt.Image;
import java.awt.Toolkit;
import java.awt.Transparency;
import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import java.awt.image.BufferedImage;
import java.awt.image.CropImageFilter;
import java.awt.image.FilteredImageSource;
import java.awt.image.ImageFilter;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
public class getimage {
public final static int[][] getMatirx(BufferedImage bi) throws IOException{
binaryimage biimg=new binaryimage();//创建二值图类,分别保存二值化后的图片及矩阵
int h=bi.getHeight();
int w=bi.getWidth();
if(h>800||w>800){
bi=scale(bi, 800, 800);
h=bi.getHeight();
w=bi.getWidth();
}
biimg.brmatrix(bi);//二值图保;
int bi_matrix[][]=biimg.brimage;//二值矩阵
int left=0,right=0,top=0,below=0;
int row[]=new int[w];
for(int i=0;i=2){
if(left==0){
left=i;
}
if(right=2){
if(top==0){
top=i;
}
if(below=2){
new_w=(new_h/2%28)+new_h/2;
top=top-(28-(below-top)%28)/2;
left=left-(new_w-right+left)/2;
}
else{
top=top-(28-(below-top)%28)/2;
left=left-(28-(right-left)%28)/2;
}
biimg.image=cut(biimg.image, left, top, w, h, new_w, new_h);
int InputMatrix[][]=new int[28][28];
InputMatrix=cut2(biimg.image, 28, 28);
return InputMatrix;
}
/**
 * Scales an image towards the requested size. Note the parameter order:
 * height first, then width.
 *
 * <p>If the source already fits inside the target box it is resized to exactly
 * {@code width} x {@code height}. Otherwise it is scaled uniformly, using the
 * ratio of the target to the source along the source's longer side.
 *
 * @param bi     source image
 * @param height target height
 * @param width  target width
 * @return the scaled image as a {@code BufferedImage}
 */
public final static BufferedImage scale(BufferedImage bi, int height, int width) {
    final int srcHeight = bi.getHeight();
    final int srcWidth = bi.getWidth();

    boolean fitsAlready = srcHeight <= height && srcWidth <= width;
    if (fitsAlready) {
        return toBufferedImage(bi.getScaledInstance(width, height, Image.SCALE_SMOOTH));
    }

    // Uniform ratio taken along whichever source side is longer.
    double ratio;
    if (srcHeight > srcWidth) {
        ratio = (double) height / srcHeight;
    } else {
        ratio = (double) width / srcWidth;
    }
    AffineTransform shrink = AffineTransform.getScaleInstance(ratio, ratio);
    Image scaled = new AffineTransformOp(shrink, null).filter(bi, null);
    return toBufferedImage(scaled);
}
/**
 * Returns {@code image} as a {@code BufferedImage}.
 *
 * <p>A {@code BufferedImage} argument is returned as-is. Any other image is
 * loaded through {@code ImageIcon} and painted onto a new opaque image: a
 * screen-compatible one when a screen is available, otherwise a plain
 * {@code TYPE_INT_RGB} image.
 */
public final static BufferedImage toBufferedImage(Image image) {
    if (image instanceof BufferedImage) {
        return (BufferedImage) image;
    }

    // ImageIcon preloads the image, so width/height below are known.
    Image loaded = new ImageIcon(image).getImage();
    GraphicsEnvironment env = GraphicsEnvironment.getLocalGraphicsEnvironment();

    BufferedImage target;
    try {
        GraphicsDevice screen = env.getDefaultScreenDevice();
        GraphicsConfiguration config = screen.getDefaultConfiguration();
        target = config.createCompatibleImage(
                loaded.getWidth(null), loaded.getHeight(null), Transparency.OPAQUE);
    } catch (HeadlessException noScreen) {
        // No display available: fall back to a plain RGB image below.
        target = null;
    }
    if (target == null) {
        target = new BufferedImage(
                loaded.getWidth(null), loaded.getHeight(null), BufferedImage.TYPE_INT_RGB);
    }

    Graphics pen = target.createGraphics();
    pen.drawImage(loaded, 0, 0, null);
    pen.dispose();
    return target;
}
/**
 * Resizes {@code bi} to {@code w} x {@code h}, crops the rectangle that starts
 * at {@code (x, y)} with size {@code new_w} x {@code new_h}, and returns it as
 * a new RGB image. The cropped image is also written to {@code ./after.jpg}.
 *
 * @throws IOException if writing {@code ./after.jpg} fails
 */
public final static BufferedImage cut(BufferedImage bi,int x, int y,int w,int h,int new_w,int new_h) throws IOException {
    Image resized = bi.getScaledInstance(w, h, Image.SCALE_DEFAULT);

    // Crop through an image filter chained onto the resized image's producer.
    FilteredImageSource croppedSource =
            new FilteredImageSource(resized.getSource(), new CropImageFilter(x, y, new_w, new_h));
    Image cropped = Toolkit.getDefaultToolkit().createImage(croppedSource);

    BufferedImage result = new BufferedImage(new_w, new_h, BufferedImage.TYPE_INT_RGB);
    Graphics pen = result.getGraphics();
    pen.drawImage(cropped, 0, 0, new_w, new_h, null); // paint the cropped region
    pen.dispose();

    ImageIO.write(result, "jpg", new File("./after.jpg"));
    return result;
}
/**
 * Splits {@code bi} into a {@code rows} x {@code cols} grid of equally sized
 * cells and classifies each cell with {@code IsBlank}: a cell reported blank
 * yields 1, any other cell yields 0.
 *
 * <p>Cell width is the image width divided by {@code cols}, rounded up; cell
 * height is the image height divided by {@code rows}, rounded up. Earlier code
 * swapped width and height and derived the cell height from the width, which
 * distorted non-square images.
 *
 * <p>Any exception is printed and the matrix filled so far is returned
 * (untouched cells stay 0).
 *
 * @return matrix indexed {@code [row][col]}
 */
public final static int[][] cut2( BufferedImage bi,int rows, int cols) {
    int InputMatrix[][]=new int[rows][cols];
    try {
        int srcWidth = bi.getWidth();
        int srcHeight = bi.getHeight();
        if (srcWidth > 0 && srcHeight > 0) {
            Image image = bi.getScaledInstance(srcWidth, srcHeight, Image.SCALE_DEFAULT);
            // Ceiling division: cells must cover the whole image.
            int destWidth = (srcWidth + cols - 1) / cols;
            int destHeight = (srcHeight + rows - 1) / rows;
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    ImageFilter cropFilter = new CropImageFilter(j * destWidth, i * destHeight,
                            destWidth, destHeight);
                    Image img = Toolkit.getDefaultToolkit().createImage(
                            new FilteredImageSource(image.getSource(), cropFilter));
                    BufferedImage tag = new BufferedImage(destWidth,
                            destHeight, BufferedImage.TYPE_INT_RGB);
                    Graphics g = tag.getGraphics();
                    g.drawImage(img, 0, 0, null); // paint the cell
                    g.dispose();
                    InputMatrix[i][j] = IsBlank(tag, destWidth, destHeight) ? 1 : 0;
                }
            }
        }
    } catch (Exception e) {
        // Best effort: report the problem and return what has been filled in.
        e.printStackTrace();
    }
    return InputMatrix;
}
public final static boolean IsBlank(BufferedImage tag,int destWidth,int destHeight){
boolean blank=true;
int gray[][]=new int[destWidth][destHeight];
for (int x = 0; x < destWidth; x++) {
for (int y = 0; y < destHeight; y++) {
gray[x][y]=getGray(tag.getRGB(x, y));
}
}
for(int i=0;i
package imageprocess;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
public class binaryimage {
private int gray[][]=null;//存储图像灰度值
public int brimage[][]=null;//存储图像二值化后灰度值
private int gra[][]=null;//给图像添白框,方便去噪
public BufferedImage image;
public void brmatrix(BufferedImage bi) throws IOException {
int h=bi.getHeight();//获取图像的高
int w=bi.getWidth();//获取图像的宽
gray=new int[w][h];
brimage=new int[w][h];
for (int x = 0; x < w; x++) {
for (int y = 0; y < h; y++) {
gray[x][y]=getGray(bi.getRGB(x, y));
}
}
Brighter(gray,w,h);
gra=new int[w+4][h+4];
for(int i=0;i1&&i1&&jSW){
int max=new Color(255,255,255).getRGB();
nbi.setRGB(x, y, max);
brimage[x][y]=1;
}else{
int min=new Color(0,0,0).getRGB();
nbi.setRGB(x, y, min);
brimage[x][y]=0;
}
}
}
this.image=nbi;
System.gc();
}
/**
 * Returns the plain average of the red, green and blue channels of a packed
 * pixel. The channels are extracted with bit operations; the earlier hex-string
 * parsing threw on pixels whose alpha was below 0x10 and its results were
 * discarded anyway.
 *
 * @param rgb packed pixel as returned by {@code BufferedImage.getRGB}
 * @return gray level in the range 0..255
 */
private int getGray(int rgb){
    int r = (rgb >> 16) & 0xFF;
    int g = (rgb >> 8) & 0xFF;
    int b = rgb & 0xFF;
    return (r + g + b) / 3;
}
/**
 * Returns the sum of the 5x5 window whose top-left corner is {@code (x, y)},
 * i.e. cells {@code [x..x+4][y..y+4]}, divided by 25. {@code w} and {@code h}
 * are not used; the caller must make sure the window lies inside {@code gray}.
 */
private int getAverageColor(int[][] gray, int x, int y, int w, int h)
{
    int total = 0;
    for (int col = x; col < x + 5; col++) {
        for (int row = y; row < y + 5; row++) {
            total += gray[col][row];
        }
    }
    return total / 25;
}
public static void Brighter(int[][]gray,int w,int h){
for(int x=0;x255){
gray[x][y]=255;
}
}
}
}
}
训练界面:设定迭代次数,移动步长,隐层数量,动量调节数值
package main;
import java.awt.Color;
import java.awt.Cursor;
import java.awt.Dimension;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowEvent;
import java.awt.event.WindowListener;
import java.io.File;
import java.io.FileWriter;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JTextField;
import network.NetWork;
/**
 * Training window: collects the iteration count, learning rate, hidden-node
 * count and momentum, trains a new {@link NetWork} on a background thread,
 * shows the outcome and saves the parameters to {@code ./parameter.txt}.
 *
 * <p>Swing components are only touched on the event dispatch thread: input is
 * validated in the button handler and results are posted back with
 * {@code SwingUtilities.invokeLater}.
 */
public class trainwin extends JFrame implements WindowListener {
    private static final long serialVersionUID = 1L;
    private JTextField Limit;
    private JTextField LearningRate;
    private JTextField hLNeurons;
    private JButton txttrain;
    private JLabel lerror;
    private JLabel lters;
    private JLabel lsuccess;
    private JLabel MSE;
    private JLabel emnue;
    private JTextField emnuefile;
    private NetWork nn;

    /** Builds the window, wires the listeners and shows it. */
    public trainwin() {
        setTitle("自我训练");
        setSize(new Dimension(400, 440));
        setLocationRelativeTo(null);
        setDefaultCloseOperation(EXIT_ON_CLOSE);
        setResizable(false);
        setLayout(null);
        createRightSide();
        addWindowListener(this);
        setVisible(true);
    }

    /** Creates every label, text field and the train button, and adds them to the content pane. */
    private void createRightSide() {
        // Captions for the four inputs.
        JLabel lblIterLimit = new JLabel("迭代次数");
        lblIterLimit.setBounds(70, 10, 100, 50);
        JLabel lblLearningRate = new JLabel("学习率");
        lblLearningRate.setBounds(70, 40, 100, 50);
        JLabel lblHLNeurons = new JLabel("隐层数量");
        lblHLNeurons.setBounds(70, 70, 100, 50);
        emnue = new JLabel("动能调节");
        emnue.setBounds(70, 100, 100, 50);
        emnue.setCursor(new Cursor(Cursor.TEXT_CURSOR));

        // Captions for the three result rows.
        JLabel lblItersTxt = new JLabel("迭代结束:");
        lblItersTxt.setBounds(70, 230, 150, 50);
        JLabel lblSuccessTxt = new JLabel("成功率:");
        lblSuccessTxt.setBounds(70, 270, 100, 50);
        JLabel lblMSETxt = new JLabel("均方误差:");
        lblMSETxt.setBounds(70, 310, 150, 50);

        // Input fields.
        Limit = new JTextField();
        Limit.setBounds(165, 25, 100, 20);
        Limit.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        LearningRate = new JTextField();
        LearningRate.setBounds(165, 55, 100, 20);
        LearningRate.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        hLNeurons = new JTextField();
        hLNeurons.setBounds(165, 85, 100, 20);
        hLNeurons.setCursor(new Cursor(Cursor.TEXT_CURSOR));
        emnuefile=new JTextField();
        emnuefile.setBounds(165, 115, 100, 20);
        emnuefile.setCursor(new Cursor(Cursor.TEXT_CURSOR));

        txttrain = new JButton("训练开始");
        txttrain.setBounds(100, 150, 100, 30);
        txttrain.setFocusPainted(false);
        txttrain.setCursor(new Cursor(Cursor.HAND_CURSOR));
        txttrain.addActionListener(new TrainListener());

        // Output labels: red for errors, blue for results.
        lerror = new JLabel("");
        lerror.setBounds(70, 190, 300, 30);
        lerror.setForeground(Color.RED);
        lters = new JLabel("");
        lters.setBounds(160, 230, 100, 50);
        lters.setForeground(Color.BLUE);
        lsuccess = new JLabel("");
        lsuccess.setBounds(160, 270, 100, 50);
        lsuccess.setForeground(Color.BLUE);
        MSE = new JLabel("");
        MSE.setBounds(160, 310, 100, 50);
        MSE.setForeground(Color.BLUE);

        getContentPane().add(lblIterLimit);
        getContentPane().add(lblLearningRate);
        getContentPane().add(lblHLNeurons);
        getContentPane().add(Limit);
        getContentPane().add(LearningRate);
        getContentPane().add(hLNeurons);
        getContentPane().add(txttrain);
        getContentPane().add(lerror);
        getContentPane().add(lblItersTxt);
        getContentPane().add(lblSuccessTxt);
        getContentPane().add(lblMSETxt);
        getContentPane().add(lters);
        getContentPane().add(lsuccess);
        getContentPane().add(MSE);
        getContentPane().add(emnue);
        getContentPane().add(emnuefile);
    }

    /** Returns true when {@code digits} parses as an {@code int} (i.e. does not overflow). */
    private static boolean fitsInt(String digits) {
        try {
            Integer.parseInt(digits);
            return true;
        } catch (NumberFormatException tooLarge) {
            return false;
        }
    }

    /**
     * Checks the four input fields in order.
     *
     * @return the message for the first invalid field, or {@code null} when all are valid
     */
    private String validateInputs() {
        String iterations = Limit.getText();
        if (!iterations.matches("[1-9][0-9]*") || !fitsInt(iterations)) {
            return "请输入迭代次数.";
        }
        if (!LearningRate.getText().matches("[0-9]*\\.[0-9]+")) {
            return "请输入可用的学习率.";
        }
        String hidden = hLNeurons.getText();
        if (!hidden.matches("[1-9][0-9]*") || !fitsInt(hidden)) {
            return "请输入有隐层神经元数量 .";
        }
        if (!emnuefile.getText().matches("[0-9]*\\.[0-9]+")) {
            return "请输入有效动量调节值 .";
        }
        return null;
    }

    /** Writes the four training parameters, one per line, to {@code ./parameter.txt}. */
    private static void saveParameters(int iter, double rate, int hide, double momentum)
            throws java.io.IOException {
        // try-with-resources closes the writer even when a write fails.
        try (FileWriter pf = new FileWriter(new File("./parameter.txt"))) {
            pf.write(String.format("%d\n", iter));
            pf.write(String.format("%f\n", rate));
            pf.write(String.format("%d\n", hide));
            pf.write(String.format("%f\n", momentum));
        }
    }

    /**
     * Trains a new network (runs on the worker thread) and then posts the
     * outcome to the event dispatch thread.
     */
    private void runTraining(final int iter, double rate, int hide, double momentum) {
        boolean trained = false;
        double successPercent = 0;
        double mse = 0;
        String failure = null;
        try {
            NetWork net = new NetWork(iter, rate, hide, 0.05, momentum);
            nn = net;
            net.initNetwork();
            net.train();
            successPercent = net.getsuccess() * 100;
            mse = net.geterror();
            trained = true;
            saveParameters(iter, rate, hide, momentum);
        } catch (Exception e1) {
            e1.printStackTrace();
            failure = e1.toString();
        }

        final boolean showResults = trained;
        final double shownSuccess = successPercent;
        final double shownMse = mse;
        final String shownFailure = failure;
        javax.swing.SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                if (showResults) {
                    lters.setText(String.valueOf(iter));
                    lsuccess.setText(String.format("%f", shownSuccess));
                    MSE.setText(String.valueOf(shownMse));
                }
                if (shownFailure != null) {
                    // Surface the failure instead of leaving it only on stderr.
                    lerror.setText("训练失败: " + shownFailure);
                }
                txttrain.setEnabled(true);
                txttrain.setText("训练");
            }
        });
    }

    /** Handles the train button: validates on the EDT, then trains on a background thread. */
    private class TrainListener implements ActionListener {
        @Override
        public void actionPerformed(ActionEvent e) {
            // Clear everything left over from a previous run, including old errors.
            lerror.setText("");
            lters.setText("");
            lsuccess.setText("");
            MSE.setText("");

            String problem = validateInputs();
            if (problem != null) {
                lerror.setText(problem);
                txttrain.setText("训练");
                return;
            }

            final int iter = Integer.parseInt(Limit.getText());
            final double rate = Double.parseDouble(LearningRate.getText());
            final int hide = Integer.parseInt(hLNeurons.getText());
            final double momentum = Double.parseDouble(emnuefile.getText());

            txttrain.setEnabled(false);
            txttrain.setText("训练中...");
            new Thread() {
                @Override
                public void run() {
                    runTraining(iter, rate, hide, momentum);
                }
            }.start();
        }
    }

    @Override
    public void windowClosed(WindowEvent e) {}
    @Override
    public void windowOpened(WindowEvent e) {}
    @Override
    public void windowIconified(WindowEvent e) {}
    @Override
    public void windowDeiconified(WindowEvent e) {}
    @Override
    public void windowActivated(WindowEvent e) {}
    @Override
    public void windowDeactivated(WindowEvent e) {}

    /** Opens the training window on its own. */
    public static void main(String[] args) {
        new trainwin();
    }

    @Override
    public void windowClosing(WindowEvent windowevent) {
        // Nothing to do: the frame's default close operation handles closing.
    }
}
训练数据,以及迭代10次、学习率0.3、隐层节点50、动量0.5时训练得到的各参数数据见:https://download.csdn.net/download/dingyahui123/10636188