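Machine learning is very popular these days, and neural networks are one of the more important families of machine learning algorithms. I have spent some time on them recently and picked up the basics, so I am writing up some study notes along the way.
There are many textbooks on the basic theory of artificial neural networks. The one I am reading is 《人工神经网络导论》 (Introduction to Artificial Neural Networks) by Professor Jiang Zongli (蒋宗礼). I chose it mainly because it is thin; I simply cannot get through a thick book. It is also written in an accessible way, which makes it a good fit for beginners.
While reading, I also looked online for neural network libraries. FANN looked quite good, so I learned how to use it at the same time.
FANN is an open-source artificial neural network library implemented in C. Because it is written in standard C, it makes few demands on the operating system and runs on a wide range of platforms. It also supports fixed-point arithmetic, so on a CPU without a floating-point unit it can run much faster than libraries that lack fixed-point support.
Although FANN is written in pure C, it is structured along object-oriented lines and its interface is well designed. The documentation is fairly detailed, and the library is convenient to use. Bindings exist for more than 20 programming languages and environments, including C#, Java, Delphi, Python, PHP, Perl, Ruby, JavaScript, MATLAB, and R.
Below is a very simple example: a neural network that models the AND of two Boolean variables. The training data goes into a file named "and.data".
Its contents are: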
4 2 1
0 0
0
0 1
0
1 0
0
1 1
1
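The first line, 4 2 1, says that the training set contains 4 training pairs, each with 2 inputs and 1 output. The pairs follow, with each line of inputs followed by a line holding the corresponding output.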
AND can be implemented by a single-layer perceptron. In FANN's terms that is a 2-layer network: one input layer and one output layer.
#include <stdio.h>
#include <doublefann.h>

const unsigned int NUM_INPUT = 2;
const unsigned int NUM_OUTPUT = 1;
const unsigned int NUM_LAYERS = 2;
const unsigned int NUM_NEURONS_HIDDEN = 1;
const float DESIRED_ERROR = 0.0001f;
const unsigned int MAX_EPOCHS = 1000;
const unsigned int EPOCHS_BETWEEN_REPORTS = 10;

int main(int argc, char *argv[])
{
    struct fann *ann;
    struct fann_train_data *data;
    fann_type input[2];
    fann_type *calc_out;

    printf("Creating network.\n");
    /* fann_create_standard reads exactly NUM_LAYERS layer sizes. With
     * NUM_LAYERS = 2 only the first two sizes are used, so the output layer
     * gets NUM_NEURONS_HIDDEN (= 1) neurons and NUM_OUTPUT is ignored.
     * That works here only because both values are 1. */
    ann = fann_create_standard(NUM_LAYERS, NUM_INPUT, NUM_NEURONS_HIDDEN, NUM_OUTPUT);
    data = fann_read_train_from_file("q:\\and.data");
    if (ann == NULL || data == NULL) {
        printf("Could not create the network or read the training data.\n");
        return 1;
    }

    printf("Training network.\n");
    fann_train_on_data(ann, data, MAX_EPOCHS, EPOCHS_BETWEEN_REPORTS, DESIRED_ERROR);
    printf("Testing network. %f\n", fann_test_data(ann, data));
    // fann_save(ann, "q:\\and_float.net");

    /* Run the trained network on all four input combinations. */
    input[0] = 0; input[1] = 0;
    calc_out = fann_run(ann, input);
    printf("and test (%f,%f) -> %f\n", input[0], input[1], calc_out[0]);
    input[0] = 0; input[1] = 1;
    calc_out = fann_run(ann, input);
    printf("and test (%f,%f) -> %f\n", input[0], input[1], calc_out[0]);
    input[0] = 1; input[1] = 0;
    calc_out = fann_run(ann, input);
    printf("and test (%f,%f) -> %f\n", input[0], input[1], calc_out[0]);
    input[0] = 1; input[1] = 1;
    calc_out = fann_run(ann, input);
    printf("and test (%f,%f) -> %f\n", input[0], input[1], calc_out[0]);

    /* Release the training data and the network. */
    fann_destroy_train(data);
    fann_destroy(ann);
    return 0;
}
The program is simple. One line in it is commented out:
fann_save(ann, "q:\\and_float.net");
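This line saves the trained network to a file for later use. In most neural network applications, training and use are separate steps: training is usually time-consuming, but a network that has been trained can be used again and again, so there is no need to retrain it before every use.
Once FANN's state has been saved, loading it again later is easy, for example: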
struct fann *ann;
ann = fann_create_from_file("q:\\and_float.net");
Running the training program above produces the following output:
Creating network.
Training network.
Max epochs 1000. Desired error: 0.0001000000.
Epochs 1. Current error: 0.2458046824. Bit fail 4.
Epochs 10. Current error: 0.1185930669. Bit fail 2.
Epochs 20. Current error: 0.0314348340. Bit fail 0.
Epochs 30. Current error: 0.0145781031. Bit fail 0.
Epochs 40. Current error: 0.0055742161. Bit fail 0.
Epochs 50. Current error: 0.0024458752. Bit fail 0.
Epochs 60. Current error: 0.0015401742. Bit fail 0.
Epochs 70. Current error: 0.0008485225. Bit fail 0.
Epochs 80. Current error: 0.0004349701. Bit fail 0.
Epochs 90. Current error: 0.0002433052. Bit fail 0.
Epochs 100. Current error: 0.0001541690. Bit fail 0.
Epochs 103. Current error: 0.0000989792. Bit fail 0.
Testing network. 0.000095
and test (0.000000,0.000000) -> 0.000000
and test (0.000000,1.000000) -> 0.012569
and test (1.000000,0.000000) -> 0.007818
and test (1.000000,1.000000) -> 0.987361
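As you can see, the fit is very good. We also know that a single-layer perceptron cannot represent XOR, so we can test that with an experiment: replace AND with XOR. The training data becomes: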
4 2 1
0 0
0
0 1
1
1 0
1
1 1
0
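The program needs almost no changes (in the run below, the label printed by the test lines was changed from "and" to "xor"). The result looks like this: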
Creating network.
Training network.
Max epochs 1000. Desired error: 0.0001000000.
Epochs 1. Current error: 0.2500112057. Bit fail 4.
Epochs 10. Current error: 0.2502280176. Bit fail 4.
Epochs 20. Current error: 0.2500012517. Bit fail 4.
Epochs 30. Current error: 0.2500000000. Bit fail 4.
Epochs 40. Current error: 0.2500000000. Bit fail 4.
Epochs 50. Current error: 0.2500000000. Bit fail 4.
Epochs 60. Current error: 0.2500000000. Bit fail 4.
Epochs 70. Current error: 0.2500000000. Bit fail 4.
Epochs 80. Current error: 0.2500000000. Bit fail 4.
Epochs 90. Current error: 0.2500000000. Bit fail 4.
Epochs 100. Current error: 0.2500000298. Bit fail 4.
Epochs 110. Current error: 0.2500000000. Bit fail 4.
Epochs 120. Current error: 0.2500000000. Bit fail 4.
Epochs 130. Current error: 0.2500000000. Bit fail 4.
Epochs 140. Current error: 0.2499999851. Bit fail 4.
Epochs 150. Current error: 0.2500000000. Bit fail 4.
Epochs 160. Current error: 0.2500000000. Bit fail 4.
Epochs 170. Current error: 0.2500000000. Bit fail 4.
Epochs 180. Current error: 0.2500000000. Bit fail 4.
Epochs 190. Current error: 0.2500000000. Bit fail 4.
Epochs 200. Current error: 0.2500000000. Bit fail 4.
Epochs 210. Current error: 0.2500000000. Bit fail 4.
Epochs 220. Current error: 0.2500000000. Bit fail 4.
Epochs 230. Current error: 0.2500000000. Bit fail 4.
Epochs 240. Current error: 0.2500000000. Bit fail 4.
Epochs 250. Current error: 0.2500000000. Bit fail 4.
Epochs 260. Current error: 0.2500000000. Bit fail 4.
Epochs 270. Current error: 0.2500000000. Bit fail 4.
Epochs 280. Current error: 0.2500000000. Bit fail 4.
Epochs 290. Current error: 0.2500000000. Bit fail 4.
Epochs 300. Current error: 0.2500000000. Bit fail 4.
Epochs 310. Current error: 0.2500000000. Bit fail 4.
Epochs 320. Current error: 0.2500000000. Bit fail 4.
Epochs 330. Current error: 0.2500000000. Bit fail 4.
Epochs 340. Current error: 0.2500000000. Bit fail 4.
Epochs 350. Current error: 0.2500000000. Bit fail 4.
Epochs 360. Current error: 0.2500000000. Bit fail 4.
Epochs 370. Current error: 0.2500000000. Bit fail 4.
Epochs 380. Current error: 0.2500000000. Bit fail 4.
Epochs 390. Current error: 0.2500000000. Bit fail 4.
Epochs 400. Current error: 0.2500000000. Bit fail 4.
Epochs 410. Current error: 0.2500000298. Bit fail 4.
Epochs 420. Current error: 0.2500000000. Bit fail 4.
Epochs 430. Current error: 0.2500000000. Bit fail 4.
Epochs 440. Current error: 0.2500000000. Bit fail 4.
Epochs 450. Current error: 0.2499999851. Bit fail 4.
Epochs 460. Current error: 0.2500000000. Bit fail 4.
Epochs 470. Current error: 0.2500000000. Bit fail 4.
Epochs 480. Current error: 0.2500000000. Bit fail 4.
Epochs 490. Current error: 0.2500000000. Bit fail 4.
Epochs 500. Current error: 0.2500000000. Bit fail 4.
Epochs 510. Current error: 0.2500000000. Bit fail 4.
Epochs 520. Current error: 0.2500000000. Bit fail 4.
Epochs 530. Current error: 0.2500000000. Bit fail 4.
Epochs 540. Current error: 0.2500000000. Bit fail 4.
Epochs 550. Current error: 0.2500000000. Bit fail 4.
Epochs 560. Current error: 0.2500000000. Bit fail 4.
Epochs 570. Current error: 0.2500000000. Bit fail 4.
Epochs 580. Current error: 0.2500000000. Bit fail 4.
Epochs 590. Current error: 0.2500000000. Bit fail 4.
Epochs 600. Current error: 0.2500000000. Bit fail 4.
Epochs 610. Current error: 0.2500000000. Bit fail 4.
Epochs 620. Current error: 0.2500000000. Bit fail 4.
Epochs 630. Current error: 0.2500000000. Bit fail 4.
Epochs 640. Current error: 0.2500000000. Bit fail 4.
Epochs 650. Current error: 0.2500000000. Bit fail 4.
Epochs 660. Current error: 0.2500000000. Bit fail 4.
Epochs 670. Current error: 0.2500000000. Bit fail 4.
Epochs 680. Current error: 0.2500000000. Bit fail 4.
Epochs 690. Current error: 0.2500000000. Bit fail 4.
Epochs 700. Current error: 0.2500000000. Bit fail 4.
Epochs 710. Current error: 0.2500000000. Bit fail 4.
Epochs 720. Current error: 0.2500000298. Bit fail 4.
Epochs 730. Current error: 0.2500000000. Bit fail 4.
Epochs 740. Current error: 0.2500000000. Bit fail 4.
Epochs 750. Current error: 0.2500000000. Bit fail 4.
Epochs 760. Current error: 0.2499999851. Bit fail 4.
Epochs 770. Current error: 0.2500000000. Bit fail 4.
Epochs 780. Current error: 0.2500000000. Bit fail 4.
Epochs 790. Current error: 0.2500000000. Bit fail 4.
Epochs 800. Current error: 0.2500000000. Bit fail 4.
Epochs 810. Current error: 0.2500000000. Bit fail 4.
Epochs 820. Current error: 0.2500000000. Bit fail 4.
Epochs 830. Current error: 0.2500000000. Bit fail 4.
Epochs 840. Current error: 0.2500000000. Bit fail 4.
Epochs 850. Current error: 0.2500000000. Bit fail 4.
Epochs 860. Current error: 0.2500000000. Bit fail 4.
Epochs 870. Current error: 0.2500000000. Bit fail 4.
Epochs 880. Current error: 0.2500000000. Bit fail 4.
Epochs 890. Current error: 0.2500000000. Bit fail 4.
Epochs 900. Current error: 0.2500000000. Bit fail 4.
Epochs 910. Current error: 0.2500000000. Bit fail 4.
Epochs 920. Current error: 0.2500000000. Bit fail 4.
Epochs 930. Current error: 0.2500000000. Bit fail 4.
Epochs 940. Current error: 0.2500000000. Bit fail 4.
Epochs 950. Current error: 0.2500000000. Bit fail 4.
Epochs 960. Current error: 0.2500000000. Bit fail 4.
Epochs 970. Current error: 0.2500000000. Bit fail 4.
Epochs 980. Current error: 0.2500000000. Bit fail 4.
Epochs 990. Current error: 0.2500000000. Bit fail 4.
Epochs 1000. Current error: 0.2500000000. Bit fail 4.
Testing network. 0.250000
xor test (0.000000,0.000000) -> 0.500042
xor test (0.000000,1.000000) -> 0.500058
xor test (1.000000,0.000000) -> 0.500064
xor test (1.000000,1.000000) -> 0.500080
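As the output shows, training does not converge: the error stays at 0.25 and all four outputs sit near 0.5. If the network is changed to 3 layers, it gains the ability to represent XOR.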
To switch to a 3-layer network, only two lines of the program need to change:
const unsigned int NUM_LAYERS = 3;
const unsigned int NUM_NEURONS_HIDDEN = 2;
These two lines change the number of layers to 3 and give the middle (hidden) layer 2 neurons. Running the program now gives:
Creating network.
Training network.
Max epochs 1000. Desired error: 0.0001000000.
Epochs 1. Current error: 0.2500394285. Bit fail 4.
Epochs 10. Current error: 0.2503248155. Bit fail 4.
Epochs 20. Current error: 0.2500005960. Bit fail 4.
Epochs 30. Current error: 0.2500000596. Bit fail 4.
Epochs 40. Current error: 0.2500004768. Bit fail 4.
Epochs 50. Current error: 0.2487613261. Bit fail 4.
Epochs 60. Current error: 0.2448616028. Bit fail 4.
Epochs 70. Current error: 0.2345527709. Bit fail 4.
Epochs 80. Current error: 0.1928165257. Bit fail 2.
Epochs 90. Current error: 0.0843421519. Bit fail 1.
Epochs 100. Current error: 0.0168493092. Bit fail 0.
Epochs 110. Current error: 0.0046176752. Bit fail 0.
Epochs 120. Current error: 0.0023012348. Bit fail 0.
Epochs 130. Current error: 0.0014233238. Bit fail 0.
Epochs 140. Current error: 0.0008981032. Bit fail 0.
Epochs 150. Current error: 0.0005876040. Bit fail 0.
Epochs 160. Current error: 0.0003150564. Bit fail 0.
Epochs 170. Current error: 0.0001641331. Bit fail 0.
Epochs 175. Current error: 0.0000839768. Bit fail 0.
Testing network. 0.000091
xor test (0.000000,0.000000) -> 0.011072
xor test (0.000000,1.000000) -> 0.993215
xor test (1.000000,0.000000) -> 0.992735
xor test (1.000000,1.000000) -> 0.011975
The results are very good: all four outputs are within about 0.012 of their targets.