本程序用于训练集,将训练集割成两部分,一部分训练,一部分测 试。
由于训练集格式是先放答案再放数据(就是28*28图片展开的784维向量),因此先读入答案,再读入数据。
注意:运行程序前,请删除数据的第一行(label,pixel0,pixel1……)不然会RE。
#include
using namespace std;
const int A=784,B=28,C=10;
const double L=0.2;
class BP{
private:
const int IN,HN,ON;
double lambda;
bool isFirstTime;
struct neuron{
double I,O,theta;
};
vector InputNeurons;
vector HiddenNeurons;
vector OutputNeurons;
double WeightIH[A+1][B+1],WeightHO[B+1][C+1];//两个邻接矩阵用于表示边权
double e[C+1];//e[i]=T[i]-OutputNeurons[i].O
double rand(const double &x,const double &y);//产生[x,y]的随机实数
double f(const double &x);//Sigmoid函数
void setWeightRandomly();//随机赋边权
void FeedForward();//前向传输
//逆向反馈,T是实际的答案,用于监督
void BackPropagation(const vector<double> &T);
public:
BP(int a=0,int b=0,int c=0,double d=0.5):
IN(a),HN(b),ON(c),lambda(d),isFirstTime(true){
InputNeurons.resize(IN+1);//方便下标从1开始
HiddenNeurons.resize(HN+1);
OutputNeurons.resize(ON+1);
for(int i=1;i<=IN;++i)
InputNeurons[i]=(neuron){0,0,0};
for(int i=1;i<=HN;++i)
HiddenNeurons[i]=(neuron){0,0,0};
for(int i=1;i<=ON;++i)
OutputNeurons[i]=(neuron){0,0,0};
srand(time(0));//你懂得
}
//用一组数据训练这个模型
void train(const vector<double> &data,const vector<double> &ans);
//用一组数据检测这个模型
vector<double> test(const vector<double> &data);
//设定lambda的值
void setLambda(const double &x);
};
inline double BP::rand(const double &x,const double &y){
return (double)std::rand()*1.0/RAND_MAX*(y-x)+x;
}
inline double BP::f(const double &x){
return 1.0/(1+exp(-x));
}
inline void BP::setWeightRandomly(){//给每条边赋一个(-1,1)的随机的值
int i,j;
for(i=1;i<=IN;++i)
for(j=1;j<=HN;++j)
WeightIH[i][j]=rand(-1,1);
for(i=1;i<=HN;++i)
for(j=1;j<=ON;++j)
WeightHO[i][j]=rand(-1,1);
for(i=1;i<=HN;++i)HiddenNeurons[i].theta=rand(0,1);
for(i=1;i<=ON;++i)OutputNeurons[i].theta=rand(0,1);
}
inline void BP::FeedForward(){//前向传输
int i,j;
for(j=1;j<=HN;++j){
neuron &p=HiddenNeurons[j];
for(i=1,p.I=0;i<=IN;++i)
p.I+=WeightIH[i][j]*InputNeurons[i].O;
p.O=f(p.I+=p.theta);
}
for(j=1;j<=ON;++j){
neuron &p=OutputNeurons[j];
for(i=1,p.I=0;i<=HN;++i)
p.I+=WeightHO[i][j]*HiddenNeurons[i].O;
p.O=f(p.I+=p.theta);
}
}
inline void BP::BackPropagation(const vector<double> &T){
int i,j,k;
for(i=1;i<=ON;++i)e[i]=T[i]-OutputNeurons[i].O;
for(k=1;k<=ON;++k)for(j=1;j<=HN;++j){
WeightHO[j][k]+=lambda*e[k]*HiddenNeurons[j].O;
}
for(k=1;k<=ON;++k) OutputNeurons[k].theta+=lambda*e[k];
for(j=1;j<=HN;++j){
double sum;
for(k=1,sum=0;k<=ON;++k)sum+=e[k]*WeightHO[j][k];
for(i=1;i<=IN;++i){
WeightIH[i][j]+=
lambda*HiddenNeurons[j].O*(1-HiddenNeurons[j].O)*InputNeurons[i].O*sum;
}
HiddenNeurons[j].theta+=lambda*HiddenNeurons[j].O*(1-HiddenNeurons[j].O)*sum;
}
}
inline void BP::train(const vector<double> &data,const vector<double> &ans){
int i;
for(i=1;i<=IN;++i)InputNeurons[i].O=data[i];
if(isFirstTime) setWeightRandomly(),isFirstTime=false;
FeedForward();
BackPropagation(ans);
}
inline vector<double> BP::test(const vector<double> &data){
int i;
for(i=1;i<=IN;++i)InputNeurons[i].O=data[i];//
FeedForward();
vector<double> ans;
ans.push_back(0);
for(i=1;i<=ON;++i) ans.push_back(OutputNeurons[i].O);
return ans;
}
inline void BP::setLambda(const double &x){lambda=x;}
inline int read(){
int x=0,ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch))x=x*10+ch-48,ch=getchar();
return x;
}
vector<double> create1(int pos){
vector<double> v(11,0);v[pos]=1;
return v;
}
int maxpos(const vector<double> &a){
double tmp=-1e9;
int i,pos;
for(i=1;iif(a[i]>tmp) tmp=a[i],pos=i;
return pos;
}
BP solver(A,B,C,L);
int main(){//归一化非常重要!!不归BUG砸死你
int n,i,j,right;
vector<double> input,ans;
scanf("%d",&n);
for(i=1;i<=n;++i){
input.clear();input.push_back(0);
ans .clear();ans .push_back(0);
{int x;scanf("%d",&x);ans=create1(x);}
for(j=1;j<=A;++j){
double x;
scanf(",%lf",&x);
input.push_back(x/255.0);
}
solver.train(input,ans);
}
puts("学习完成!");
scanf("%d",&n);
for(i=1,right=0;i<=n;++i){
input.clear();input.push_back(0);
ans .clear();ans .push_back(0);
int standardAns,calcAns;
scanf("%d",&standardAns);
for(j=1;j<=A;++j){double x;scanf(",%lf",&x);input.push_back(x);}
ans=solver.test(input);
calcAns=maxpos(ans);
if(calcAns==standardAns) right++;
printf("%d\n",calcAns);
}
fclose(stdin);fclose(stdout);
return 0;
}