Capsules胶囊神经网络代码测试(1)——测试mnist数据

原地址:https://github.com/Sarasra/models/tree/master/research/capsules

运行必须要有GPU

下载两个数据包,第一个是 MNIST tfrecords,第二个是已经训练好的模型文件。

  • Download and extract MNIST tfrecords to $DATA_DIR/ from:https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz
  • Download and extract MNIST model checkpoint to $CKPT_DIR from:https://storage.googleapis.com/capsule_toronto/mnist_checkpoints.tar.gz

直接放入对应的文件夹下,结构如图所示:

Capsules胶囊神经网络代码测试(1)——测试mnist数据_第1张图片

首先利用训练好的模型进行测试(由于Python版本原因,可能需要把几个py文件中的xrange全部修改为range)在文件夹下打开命令行,输入:

python experiment.py --data_dir=./mnist_data/ --train=false --summary_dir=./tmp/ --checkpoint=./mnist_checkpoint/model.ckpt-1

 结果如图:

Capsules胶囊神经网络代码测试(1)——测试mnist数据_第2张图片

准确率达到了100%-0.24%=99.76%

然后自己训练网络,输入:

python experiment.py --data_dir=./mnist_data/ --max_steps=300000 --summary_dir=./tmp/attempt0/ 

等待结果,为了节省时间,我把max_steps设置成3000,代码设置的是每1500次自动保存一次模型

Capsules胶囊神经网络代码测试(1)——测试mnist数据_第3张图片

训练完成后在.\capsules\tmp\attempt0\train文件夹下会生成相应的模型文件

Capsules胶囊神经网络代码测试(1)——测试mnist数据_第4张图片

model.ckpt-3000.data-00000-of-00001 改名 model.ckpt-1.data-00000-of-00001

model.ckpt-3000.index 改名 model.ckpt-1.index

model.ckpt-3000.meta 改名 model.ckpt-1.meta

然后放到mnist_checkpoint文件夹下,替换原文件,再次输入

python experiment.py --data_dir=./mnist_data/ --train=false --summary_dir=./tmp/ --checkpoint=./mnist_checkpoint/model.ckpt-1

测试结果:

Capsules胶囊神经网络代码测试(1)——测试mnist数据_第5张图片

准确率达到了100%-0.94%=99.06%,迭代3000次就有这么高的准确率,可见胶囊神经网络确实很强大。

下一篇:Capsules胶囊神经网络代码测试(2)——测试cifar10数据

你可能感兴趣的:(TensorFlow,胶囊神经网络)