本文根据西瓜书(周志华《机器学习》)第五章给出的公式编写。书中给出了全连接神经网络的实现逻辑,本文在此基础上编写了MNIST手写数字(0–9共10类)识别的案例,网上也有一些其他手写数字识别的例子可供参考。demo没有使用UE,而是使用Unity编写,方便且易于查错。
该案例仅作为学习,博主也只是业余时间自学一些机器学习知识,欢迎各位路过的大佬留下建议~
源码下载地址:
https://download.csdn.net/download/grayrail/87802798
首先理顺西瓜书第五章中的各符号的意义:
西瓜书第五章直接讲了反向传播,所以在这之前简单讲一下正向传播。
以上图为例,输入层的维度是[3];输入层到隐层的连接权矩阵维度是[n,3],隐层到输出层的连接权矩阵维度是[4,n],因此最终输出维度是[4]。输入层通常就是原始输入信息;隐层负责中间环节的计算,其连接权和阈值是需要训练的参数(而不是超参数)。输入层到隐层的连接权矩阵维度是n × m,其中m是输入数据自身的维度,n是隐层神经元的个数,可以把n理解为n种可能的中间特征(博主自己的理解)。例如隐层有50个神经元(即n=50),就相当于假设了50种可能性进行训练。
基于此,那么正向传播的流程如下:
反向传播的难点之一是链式求导,西瓜书中已经帮我们把求导过程写好了,这里我先讲tips,再梳理反向传播流程。
书中对损失函数E只提了一次,后续推导中直接把输出误差 $(y_j - \hat{y}_j)$ 代入sigmoid的导数公式,即 $g_j = \hat{y}_j(1 - \hat{y}_j)(y_j - \hat{y}_j)$,看起来有点让人糊涂。
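后来查询了ChatGPT和别人的一些文章,了解到这个差值来自对均方误差 $E = \frac{1}{2}\sum_j(\hat{y}_j - y_j)^2$ 求偏导:$\frac{\partial E}{\partial \hat{y}_j} = \hat{y}_j - y_j$。书中取的是负梯度方向,所以写成 $y_j - \hat{y}_j$。如果不用 ½ MSE 作为损失函数,就需要把这一项换成对应损失函数的偏导数。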
书中的公式有点乱,下面给出按照顺序的梳理图:
基于此,那么反向传播的流程如下:
后来看书时发现,神经网络除了正向传播和反向传播,还有两个比较重要的组成部分:损失函数和优化器。损失函数定义了要最小化的目标,反向传播据此计算出每个参数(连接权和阈值)的梯度,优化器则决定如何利用这些梯度更新参数。
例如本文末尾给出的C#代码就用到了动量(momentum),它也是一种优化器。TensorFlow的一些demo中常用的Adam优化器,大致可以看作动量与RMSProp的结合:
- 一方面用梯度的指数移动平均(动量)决定更新方向;
- 另一方面用梯度平方的指数移动平均为每个参数动态调整步长;
- 并对这两个估计做偏差修正。
了解RMSProp优化器可以看一下《动手学深度学习》7.6.1章节,有相关实现。
以MNIST案例为例,该案例使用神经网络识别28×28像素图片中的0–9手写数字。接下来给出C#版本的MNIST代码实现,脚本挂载后有3种模式:
- 训练模式(Train Mode):从数据目录读取样本并训练;
- 绘制模式(Draw Image Mode):手绘数字并保存为训练样本;
- 使用模式(User Mode):手绘数字并进行识别。
该案例在西瓜书的基础上又加入了momentum动量、softmax和Dropout,并将初始随机值范围改为(-1,1)。softmax使用《深度学习入门 基于Python的理论与实现》一书中提供的公式(先减去最大值,以防止指数溢出)。经过若干轮训练后的运行结果:
c#代码如下:
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using UnityEngine;
/// <summary>
/// A fully connected network with one hidden layer (the BP network from chapter 5 of
/// Zhou Zhihua's "Machine Learning"), extended with momentum, inverted dropout on the
/// output layer and a softmax, used to classify hand-drawn 28x28 digits.
/// The IMGUI has three modes: Train (load samples from a folder and train),
/// DrawImage (paint a digit and save it as a sample) and User (paint a digit and classify it).
/// </summary>
public class TestMnist10 : MonoBehaviour
{
    /// <summary>Which panel <see cref="OnGUI"/> shows.</summary>
    public enum EMode { Train, DrawImage, User }

    /// <summary>
    /// Probability that an output unit is KEPT (not dropped) while dropout is enabled.
    /// Kept activations are divided by this value (inverted dropout).
    /// </summary>
    const float kDropoutProb = 0.4f;

    /// <summary>Width and height, in cells, of the drawable digit image.</summary>
    const int kImageSide = 28;

    /// <summary>Number of passes over the training set made by the "Train 10" button.</summary>
    const int kTrainEpochs = 10;

    /// <summary>Number of input neurons (d).</summary>
    int d;

    /// <summary>Number of hidden neurons (q).</summary>
    int q;

    /// <summary>Number of output neurons (l).</summary>
    int l;

    /// <summary>Input vector of the most recent forward pass; read again by back-propagation.</summary>
    float[] x;

    /// <summary>Input-to-hidden connection weights, indexed [hidden, input].</summary>
    float[][] v;

    /// <summary>Previous update applied to each v weight (the momentum term).</summary>
    float[][] lastVMomentum;

    /// <summary>Hidden-to-output connection weights, indexed [output, hidden].</summary>
    float[][] w;

    /// <summary>Previous update applied to each w weight (the momentum term).</summary>
    float[][] lastWMomentum;

    /// <summary>Per-output dropout mask of the current sample: 1 = kept, 0 = dropped.</summary>
    float[] wDropout;

    /// <summary>Output-layer gradient terms g_j.</summary>
    float[] g;

    /// <summary>Hidden-layer gradient terms e_h.</summary>
    float[] e;

    /// <summary>Hidden-layer activations b_h of the most recent forward pass (length q).</summary>
    List<float> b;

    /// <summary>Output activations (after dropout and softmax) of the most recent forward pass (length l).</summary>
    List<float> yhats;

    /// <summary>Thresholds of the output neurons.</summary>
    float[] theta;

    /// <summary>Thresholds of the hidden neurons.</summary>
    float[] gamma;

    EMode _Mode;
    int[] _DrawNumberImage;
    bool _EnableDropout;
    string _DataPath;

    /// <summary>
    /// Allocates all buffers and randomly initialises weights and thresholds in [-1, 1].
    /// Momentum buffers and the dropout mask start at zero.
    /// </summary>
    /// <param name="inputLayerCount">Number of input neurons (d).</param>
    /// <param name="hiddenLayerCount">Number of hidden neurons (q).</param>
    /// <param name="outputLayerCount">Number of output neurons (l).</param>
    public void Init(int inputLayerCount, int hiddenLayerCount, int outputLayerCount)
    {
        d = inputLayerCount;
        q = hiddenLayerCount;
        l = outputLayerCount;
        x = new float[inputLayerCount];
        b = new List<float>(1024);
        yhats = new List<float>(1024);
        e = new float[hiddenLayerCount];
        g = new float[outputLayerCount];
        v = GenDimsArray(typeof(float), new int[] { q, d }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];
        w = GenDimsArray(typeof(float), new int[] { l, q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];
        wDropout = GenDimsArray(typeof(float), new int[] { l }, 0, null) as float[];
        lastVMomentum = GenDimsArray(typeof(float), new int[] { q, d }, 0, null) as float[][];
        lastWMomentum = GenDimsArray(typeof(float), new int[] { l, q }, 0, null) as float[][];
        theta = GenDimsArray(typeof(float), new int[] { l }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];
        gamma = GenDimsArray(typeof(float), new int[] { q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];
    }

    /// <summary>
    /// Runs one forward pass and caches x, b, yhats and the dropout mask for
    /// <see cref="BackPropagation"/>.
    /// </summary>
    /// <param name="input">Input vector; must hold at least d values (only the first d are read).</param>
    /// <param name="output">Index of the largest output activation, i.e. the predicted class.</param>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="input"/> has fewer than d values.</exception>
    public void ForwardPropagation(float[] input, out int output)
    {
        if (input == null)
            throw new ArgumentNullException(nameof(input));
        if (input.Length < d)
            throw new ArgumentException($"Expected at least {d} input values, got {input.Length}.", nameof(input));

        x = input;

        // A fresh mask is drawn on every pass; it is only applied while dropout is enabled.
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            wDropout[jIndex] = UnityEngine.Random.value < kDropoutProb ? 1f : 0f;
        }

        // Hidden layer: b_h = sigmoid(sum_i v[h][i] * x_i - gamma_h).
        b.Clear();
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            var sum = 0f;
            for (int iIndex = 0; iIndex < d; ++iIndex)
            {
                sum += input[iIndex] * v[hIndex][iIndex];
            }
            b.Add(Sigmoid(sum - gamma[hIndex]));
        }

        // Output layer: yhat_j = sigmoid(sum_h w[j][h] * b_h - theta_j).
        yhats.Clear();
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var sum = 0f;
            for (int hIndex = 0; hIndex < q; ++hIndex)
            {
                sum += b[hIndex] * w[jIndex][hIndex];
            }
            var r = Sigmoid(sum - theta[jIndex]);
            // Dropout is a training-time regulariser; keep it disabled for inference.
            if (_EnableDropout)
            {
                r *= wDropout[jIndex];
                r /= kDropoutProb;
            }
            yhats.Add(r);
        }

        // NOTE(review): softmax is applied on top of the sigmoid outputs, but BackPropagation
        // still uses the sigmoid derivative yhat * (1 - yhat) on these softmax-normalised values,
        // so the gradient does not exactly match this forward pass. Left as-is to preserve the
        // original training behaviour — confirm whether this is intended.
        var softmaxResult = Softmax(yhats.ToArray());
        for (int i = 0; i < yhats.Count; i++)
        {
            yhats[i] = softmaxResult[i];
        }

        // Argmax; ties resolve to the lowest index.
        int index = 0;
        float maxValue = yhats[0];
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            if (yhats[jIndex] > maxValue)
            {
                maxValue = yhats[jIndex];
                index = jIndex;
            }
        }
        output = index;
    }

    /// <summary>
    /// Applies one step of standard back-propagation with momentum, using the values cached
    /// by the most recent <see cref="ForwardPropagation"/> call.
    /// </summary>
    /// <param name="correct">Target vector (one-hot); must hold at least l values.</param>
    /// <exception cref="ArgumentNullException"><paramref name="correct"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="correct"/> has fewer than l values.</exception>
    public void BackPropagation(float[] correct)
    {
        if (correct == null)
            throw new ArgumentNullException(nameof(correct));
        if (correct.Length < l)
            throw new ArgumentException($"Expected at least {l} target values, got {correct.Length}.", nameof(correct));

        const float kEta1 = 0.03f;    // learning rate for w and theta
        const float kEta2 = 0.01f;    // learning rate for v and gamma
        const float kMomentum = 0.3f; // fraction of the previous update carried over

        // Output-layer terms: g_j = yhat_j * (1 - yhat_j) * (y_j - yhat_j).
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var yhat = yhats[jIndex];
            g[jIndex] = yhat * (1f - yhat) * (correct[jIndex] - yhat);
        }

        // Hidden-layer terms: e_h = b_h * (1 - b_h) * sum_j w[j][h] * g_j.
        // Computed before w is updated, so the gradient uses this pass's weights.
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            var bh = b[hIndex];
            var sum = 0f;
            // Key step: each hidden neuron collects the gradient of every output neuron it feeds,
            // weighted by the connecting weight.
            for (int jIndex = 0; jIndex < l; ++jIndex)
            {
                sum += w[jIndex][hIndex] * g[jIndex];
            }
            e[hIndex] = bh * (1f - bh) * sum;
        }

        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            theta[jIndex] += -kEta1 * g[jIndex];
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            for (int jIndex = 0; jIndex < l; ++jIndex)
            {
                var bh = b[hIndex];
                var delta = kMomentum * lastWMomentum[jIndex][hIndex] + kEta1 * g[jIndex] * bh;
                // Only weights feeding kept output units are updated while dropout is on.
                if (_EnableDropout)
                {
                    delta *= wDropout[jIndex];
                    delta /= kDropoutProb;
                }
                w[jIndex][hIndex] += delta;
                lastWMomentum[jIndex][hIndex] = delta;
            }
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            gamma[hIndex] += -kEta2 * e[hIndex];
        }

        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            for (int iIndex = 0; iIndex < d; ++iIndex)
            {
                var delta = kMomentum * lastVMomentum[hIndex][iIndex] + kEta2 * e[hIndex] * x[iIndex];
                v[hIndex][iIndex] += delta;
                lastVMomentum[hIndex][iIndex] = delta;
            }
        }
    }

    void Start()
    {
        Init(kImageSide * kImageSide, 64, 10);
    }

    /// <summary>Logistic function 1 / (1 + e^-val).</summary>
    float Sigmoid(float val)
    {
        return 1f / (1f + Mathf.Exp(-val));
    }

    /// <summary>
    /// Softmax over <paramref name="inputs"/>; the maximum is subtracted first so the
    /// exponentials cannot overflow. Returns a new array.
    /// </summary>
    float[] Softmax(float[] inputs)
    {
        float[] outputs = new float[inputs.Length];
        float maxInput = inputs.Max();
        for (int i = 0; i < inputs.Length; i++)
        {
            outputs[i] = Mathf.Exp(inputs[i] - maxInput);
        }
        float expSum = outputs.Sum();
        for (int i = 0; i < outputs.Length; i++)
        {
            outputs[i] /= expSum;
        }
        return outputs;
    }

    /// <summary>
    /// Builds the 10-element one-hot target for a sample whose label is the first
    /// character of <paramref name="input"/> (the file name). '0'–'8' map to their digit;
    /// anything else (including '9', an empty or null name) maps to 9.
    /// </summary>
    float[] GetOneHot(string input)
    {
        var oneHot = new float[10];
        int digit = 9;
        if (!string.IsNullOrEmpty(input) && input[0] >= '0' && input[0] <= '8')
            digit = input[0] - '0';
        oneHot[digit] = 1f;
        return oneHot;
    }

    /// <summary>Uniformly shuffles <paramref name="cardList"/> in place (Fisher–Yates).</summary>
    void Shuffle<T>(List<T> cardList)
    {
        for (int i = cardList.Count - 1; i > 0; --i)
        {
            // Int overload of Random.Range excludes the max, so j is in [0, i].
            int j = UnityEngine.Random.Range(0, i + 1);
            (cardList[i], cardList[j]) = (cardList[j], cardList[i]);
        }
    }

    /// <summary>
    /// Creates a jagged array of element type <paramref name="type"/> whose level k has
    /// length <paramref name="dims"/>[k], starting from level <paramref name="deepIndex"/>.
    /// Leaf elements are filled from <paramref name="initFunc"/> when it is non-null,
    /// otherwise they keep their default value.
    /// </summary>
    Array GenDimsArray(Type type, int[] dims, int deepIndex, Func<object> initFunc = null)
    {
        if (deepIndex < dims.Length - 1)
        {
            // One jagged-array rank per remaining level, e.g. float -> float[] for a 2-level array.
            var elementType = type;
            for (int k = deepIndex + 1; k < dims.Length; ++k)
                elementType = elementType.MakeArrayType();

            var current = Array.CreateInstance(elementType, dims[deepIndex]);
            for (int i = 0; i < dims[deepIndex]; ++i)
                current.SetValue(GenDimsArray(type, dims, deepIndex + 1, initFunc), i);
            return current;
        }

        var arr = Array.CreateInstance(type, dims[deepIndex]);
        if (initFunc != null)
        {
            for (int i = 0; i < arr.Length; ++i)
                arr.SetValue(initFunc(), i);
        }
        return arr;
    }

    /// <summary>Default training-data folder: "TrainData" under the current working directory.</summary>
    string DefaultDataPath()
    {
        return Directory.GetCurrentDirectory() + "/TrainData";
    }

    void OnGUI()
    {
        if (_DrawNumberImage == null)
            _DrawNumberImage = new int[kImageSide * kImageSide];
        // Train is the default mode, so the path must be usable before "Train Mode" is ever pressed.
        if (_DataPath == null)
            _DataPath = DefaultDataPath();

        GUILayout.BeginHorizontal();
        if (GUILayout.Button("Draw Image Mode"))
        {
            _Mode = EMode.DrawImage;
            Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);
        }
        if (GUILayout.Button("User Mode"))
        {
            _Mode = EMode.User;
            Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);
        }
        if (GUILayout.Button("Train Mode"))
        {
            _Mode = EMode.Train;
            _DataPath = DefaultDataPath();
        }
        GUILayout.EndHorizontal();
        var lastRect = GUILayoutUtility.GetLastRect();

        switch (_Mode)
        {
            case EMode.Train:
                DrawTrainPanel();
                break;
            case EMode.DrawImage:
                lastRect.y += 50f;
                DrawImageModePanel(lastRect);
                break;
            case EMode.User:
                lastRect.y += 150f;
                DrawUserPanel(lastRect);
                break;
        }
    }

    /// <summary>Data-path field, dropout toggle and the "Train 10" button.</summary>
    void DrawTrainPanel()
    {
        GUILayout.BeginHorizontal();
        GUILayout.Label("Data Path: ");
        _DataPath = GUILayout.TextField(_DataPath);
        GUILayout.EndHorizontal();

        if (GUILayout.Button("dropout(" + (_EnableDropout ? "True" : "False") + ")"))
            _EnableDropout = !_EnableDropout;

        if (GUILayout.Button("Train 10"))
            Train(LoadSamples(_DataPath), kTrainEpochs);
    }

    /// <summary>
    /// Loads every file in <paramref name="directory"/> as one comma-separated sample of d values,
    /// labelled by its file name (without extension). Malformed files are skipped with a warning;
    /// a missing directory logs an error and yields an empty list.
    /// </summary>
    List<(string, float[])> LoadSamples(string directory)
    {
        var samples = new List<(string, float[])>(512);
        if (string.IsNullOrEmpty(directory) || !Directory.Exists(directory))
        {
            Debug.LogError("Training data directory not found: " + directory);
            return samples;
        }

        foreach (var file in Directory.GetFiles(directory))
        {
            var fields = File.ReadAllText(file).Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries);
            var pixels = new float[fields.Length];
            bool ok = fields.Length == d;
            for (int i = 0; ok && i < fields.Length; ++i)
                ok = float.TryParse(fields[i], NumberStyles.Float, CultureInfo.InvariantCulture, out pixels[i]);

            if (!ok)
            {
                Debug.LogWarning("Skipping malformed training sample: " + file);
                continue;
            }
            samples.Add((Path.GetFileNameWithoutExtension(file), pixels));
        }
        return samples;
    }

    /// <summary>Trains for <paramref name="epochs"/> passes, reshuffling the samples before each pass.</summary>
    void Train(List<(string, float[])> datas, int epochs)
    {
        for (int s = 0; s < epochs; ++s)
        {
            Shuffle(datas);
            for (int i = 0; i < datas.Count; ++i)
            {
                ForwardPropagation(datas[i].Item2, out int output);
                Debug.Log(" Input Number: " + datas[i].Item1 + " output: " + output);
                BackPropagation(GetOneHot(datas[i].Item1));
            }
        }
    }

    /// <summary>Drawing grid plus a button that writes the image to Assets/tmp.txt.</summary>
    void DrawImageModePanel(Rect gridOrigin)
    {
        DrawPixelGrid(gridOrigin);
        if (GUILayout.Button("Save"))
        {
            File.WriteAllText(Directory.GetCurrentDirectory() + "/Assets/tmp.txt", string.Join(",", _DrawNumberImage));
        }
    }

    /// <summary>Drawing grid plus a button that classifies the drawn digit and logs the result.</summary>
    void DrawUserPanel(Rect gridOrigin)
    {
        DrawPixelGrid(gridOrigin);
        if (GUILayout.Button("Recognize"))
        {
            // Dropout must not be applied at inference time, even if it was left enabled for training.
            var dropoutWasEnabled = _EnableDropout;
            _EnableDropout = false;
            try
            {
                ForwardPropagation(Array.ConvertAll(_DrawNumberImage, m => (float)m), out int output);
                Debug.Log("output: " + output);
            }
            finally
            {
                _EnableDropout = dropoutWasEnabled;
            }
        }
    }

    /// <summary>
    /// Draws the kImageSide x kImageSide pixel grid starting at <paramref name="origin"/>.
    /// While the mouse hovers a cell, the left button paints it (1) and the right button erases it (0).
    /// </summary>
    void DrawPixelGrid(Rect origin)
    {
        const float cellSize = 20f;
        const float cellSpacing = 2f;
        const float cellStep = cellSize + cellSpacing;
        var mousePosition = Event.current.mousePosition;
        var leftPressed = Input.GetMouseButton(0);
        var rightPressed = Input.GetMouseButton(1);

        for (int row = 0, i = 0; row < kImageSide; ++row)
        {
            for (int col = 0; col < kImageSide; ++col, ++i)
            {
                var rect = new Rect(origin.x + col * cellStep, origin.y + row * cellStep, cellSize, cellSize);
                GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);
                if (rect.Contains(mousePosition))
                {
                    if (leftPressed)
                        _DrawNumberImage[i] = 1;
                    else if (rightPressed)
                        _DrawNumberImage[i] = 0;
                }
            }
        }
    }
}
参考文章