在python中编写caffe的prototxt文件

利用 Python 可以创建 Caffe 网络定义的 prototxt 文件。这种方法的一个好处是:可以保证 training、testing 和 deploy 网络的一致性。

下面是利用 Python 编写 train 和 test 的 prototxt 文件的一个小示例:

import caffe
from caffe import layers as L
from caffe import params as P


# Function: Set lenet net
def init_net_lenet(netName, netType, batch_size, prototxt_root, data_root):
    """Define a LeNet-style network with caffe.NetSpec and write its prototxt.

    The network consists of an ImageData input layer, two
    Convolution -> TanH -> MAX-Pooling stages, three InnerProduct layers
    (120, 84 and 2 outputs, the first two followed by TanH) and a
    SoftmaxWithLoss layer fed by ``ip5`` and the labels.

    Args:
        netName: Name of the net, e.g. ``'lenet'``.
        netType: Kind of net, e.g. ``'train'`` or ``'test'``. It is used both
            as the prefix of the output file name and as the stem of the
            image list file.
        batch_size: Batch size passed to the ImageData layer.
        prototxt_root: String prepended as-is to the output file name
            (no path separator is inserted, so include a trailing one).
        data_root: String prepended as-is to ``netType + '.data'`` to form
            the ImageData ``source``.

    Returns:
        The string ``'ip5'``, the name of the last InnerProduct layer
        (intended for use when predicting).

    Side effects:
        Writes ``prototxt_root + netType + '_' + netName + '.prototxt'`` and
        prints two progress messages to stdout.
    """
    # Type and name of the net, e.g. 'train_lenet' or 'test_lenet'.
    net_type_name = netType + '_' + netName
    # Build the path once so the log message and the file actually written
    # can never disagree.
    prototxt_path = prototxt_root + net_type_name + '.prototxt'

    n = caffe.NetSpec()

    # Input layer: grayscale images (is_color=False), shuffled, with pixel
    # values scaled by 1/255; ntop=2 yields the data and label tops.
    n.data, n.label = L.ImageData(
        image_data_param={"batch_size": batch_size, "is_color": False, "shuffle": True},
        source=data_root + netType + ".data",
        transform_param=dict(scale=1./255),
        ntop=2)

    # Each TanH is declared in_place, so the layer that follows takes the
    # activation's input blob (e.g. n.conv1) as its bottom.
    n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=6, weight_filler=dict(type='xavier'))
    n.tanh1 = L.TanH(n.conv1, in_place=True)
    n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX)

    n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=16, weight_filler=dict(type='xavier'))
    n.tanh2 = L.TanH(n.conv2, in_place=True)
    n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)

    n.ip3 = L.InnerProduct(n.pool2, num_output=120, weight_filler=dict(type='xavier'))
    n.tanh3 = L.TanH(n.ip3, in_place=True)

    n.ip4 = L.InnerProduct(n.ip3, num_output=84, weight_filler=dict(type='xavier'))
    n.tanh4 = L.TanH(n.ip4, in_place=True)

    n.ip5 = L.InnerProduct(n.ip4, num_output=2, weight_filler=dict(type='xavier'))
    n.loss = L.SoftmaxWithLoss(n.ip5, n.label)

    # Write the prototxt file. print(...) with a single argument behaves the
    # same on Python 2 and Python 3, unlike the print statement.
    print('Writing net to %s' % prototxt_path)
    with open(prototxt_path, 'w') as f:
        f.write(str(n.to_proto()))
    print('done...')

    # Name of the output layer (used for predicting).
    return 'ip5'

你可能感兴趣的:(caffe,prototxt)