Caffe 代码中的 solver.net.forward() , solver.test_nets[0].forward() 和 solver.step(1) 区别和作用。
三个函数都是将批量大小(batch_size)的图片送到网络, solver.net.forward() 和 solver.test_nets[0].forward() 是将batch_size个图片送到网络中去,只有前向传播(Forward Propagation,BP),solver.net.forward()作用于训练集,solver.test_nets[0].forward() 作用于测试集,一般用于获得测试集的正确率。solver.step(1) 也是将batch_size个图片送到网络中去,不过 solver.step(1) 不仅有FP,而且还有反向传播(Back Propagation,BP)!这样就可以更新整个网络的权值(weights),同时得到该batch的loss。
让我们用代码实例来验证一下上述函数的作用。
下面的例子来自我的博客 使用 Caffe Python 编写 LeNetB1 ,在 B1 中,我定义了网络,我们首先加载这个网路。该网络的训练集batch_size=64,测试集batch_size=100。为获得更高的效率,我在Jupyter Notebook 中实现如下代码。
为了方便,我先给出mnist训练集前192个数字和测试集前200个数字。
代码1:
from pylab import *
import caffe
%matplotlib inline
caffe.set_device(0)
caffe.set_mode_gpu() #设置GPU,不是GPU环境的使用caffe.set_mode_cpu()
solver = None
solver = caffe.SGDSolver('C:/Users/Admin512/Desktop/MyStudy/caffe_python/LeNet/mnist/lenet_auto_solver.prototxt')
在运行网络之前,我们先看一下网络的层次:
代码2:
[(k, v.data.shape) for k, v in solver.net.blobs.items()]
输出2:
[('data', (64L, 1L, 28L, 28L)),
('label', (64L,)),
('conv1', (64L, 20L, 24L, 24L)),
('pool1', (64L, 20L, 12L, 12L)),
('conv2', (64L, 50L, 8L, 8L)),
('pool2', (64L, 50L, 4L, 4L)),
('fc1', (64L, 500L)),
('score', (64L, 10L)),
('loss', ())]
看一下第一层“data”中batch图像中的第一个图像是哪个数字?
代码3:
A = solver.net.blobs['data']
print(A.data.shape)
A.data[0,0]
imshow(A.data[0,0], cmap='gray');
黑乎乎的一片,并没有显示数字,说明数据没有传入到网络中。
下面我们运行:solver.net.forward() :
代码4:
solver.net.forward() # train net
数据已经传进来了,很明显这是训练集第一个图。
我们再次运行代码4,然后运行代码3,输出如下:
是不是第65(索引为64)个?是!说明执行solver.net.forward() 后,训练集数据按照batch_size=64输入带网络的。
我们再来看 solver.test_nets[0].forward() :
代码5:
solver.test_nets[0].forward()
# 显示test_net传入的数据
B = solver.test_nets[0].blobs['data'].data
imshow(B[0,0], cmap='gray');
是测试集第1个图像。
猜想再次运行代码5,应该显示的是第101个图像(索引位100),我们运行一下:
猜想正确!
现在知道 solver.net.forward() 和 solver.test_nets[0].forward() 的作用了,下面我们看看 solver.step(1) 的作用。
首先初始化整个网络,就是再次运行代码1。
代码6:
solver.step(1)
A = solver.net.blobs['data']
print(A.data.shape)
A.data[0,0]
imshow(A.data[0,0], cmap='gray');
两次运行代码6,发现两次出现的图像和两次运行solver.net.forward() 的效果一样,说明 代码6 执行的是FP+DP。
完。