import caffe
from caffe import layers as L
import os
import numpy as np
# Solver definition file handed to caffe.SGDSolver in main().
Solver_PATH = "auto_solver.prototxt"
def change_env():
    """Change the process working directory to the directory of this file.

    Makes relative file names (prototxts, image lists) resolve next to the
    script regardless of where it was launched from. Prints the new root.
    """
    # abspath first: on older Pythons, running ``python script.py`` from the
    # script's own directory gives a bare file name, so dirname() returns ''
    # and os.chdir('') raises FileNotFoundError.
    root = os.path.dirname(os.path.abspath(__file__))
    os.chdir(root)
    print("current work root:->", root)
def net(img_list, batch_size, mean_value=0):
    """Build a two-layer fully connected classifier and return its NetParameter.

    The input is an ImageData layer that reads ``img_list``, resizes images
    to 28x28 and applies a transform with scale 1/255 and ``mean_value``.
    Topology: data -> InnerProduct(50) -> ReLU (in place) -> InnerProduct(10)
    -> SoftmaxWithLoss against the label top.

    Args:
        img_list: path to the image list file given to ImageData as ``source``.
        batch_size: ImageData batch size.
        mean_value: value passed as ``transform_param.mean_value``.

    Returns:
        The protobuf produced by ``caffe.NetSpec.to_proto()``.
    """
    spec = caffe.NetSpec()
    preprocess = {"scale": 1 / 255.0, "mean_value": mean_value}
    # Attribute names on `spec` become the layer/top names in the prototxt.
    spec.data, spec.label = L.ImageData(
        source=img_list,
        batch_size=batch_size,
        new_width=28,
        new_height=28,
        ntop=2,
        transform_param=preprocess,
    )
    spec.ip1 = L.InnerProduct(spec.data, num_output=50,
                              weight_filler={"type": "xavier"})
    spec.relu1 = L.ReLU(spec.ip1, in_place=True)
    spec.ip2 = L.InnerProduct(spec.relu1, num_output=10,
                              weight_filler={"type": "xavier"})
    spec.loss = L.SoftmaxWithLoss(spec.ip2, spec.label)
    return spec.to_proto()
def file_write(path1="auto_train.prototxt", path2="auto_test.prototxt"):
    """Serialize the train and test networks built by net() to prototxt files.

    ``path1`` receives the net for "train.imglist" with batch size 200 and
    ``path2`` the net for "test.imglist" with batch size 50. Existing files
    are overwritten.
    """
    jobs = (
        (path1, "train.imglist", 200),
        (path2, "test.imglist", 50),
    )
    for out_path, list_file, batch in jobs:
        with open(out_path, "w") as handle:
            handle.write(str(net(list_file, batch)))
def main():
    """Write the net prototxts, train with SGD, and print per-iteration loss.

    Generates the train/test prototxts via file_write(), loads the solver
    from Solver_PATH, runs ``iternum`` single solver steps and records the
    value of the 'loss' blob after each one. The loss array is printed.
    """
    # change_env()  # uncomment to resolve relative paths against the script dir
    file_write()
    solver = caffe.SGDSolver(Solver_PATH)
    # Step the solver manually from iteration 0. Do not call solver.solve()
    # first: it trains all the way to the solver's max_iter, so loss_iter
    # would only capture extra post-training steps, not the training curve.
    iternum = 100
    loss_iter = np.zeros(iternum)
    for it in range(iternum):
        solver.step(1)
        # .item() extracts the single loss value as a Python float.
        loss_iter[it] = solver.net.blobs['loss'].data.item()
    print(loss_iter)
if __name__ == "__main__":
    # Train only when executed as a script, not when imported.
    main()
# Reference (参考):
# Course slides and code (配套课件与代码): 链接 / link: https://pan.baidu.com/s/1rok-dDvKnelFEpPRdE_f8A
# Extraction code (提取码): 7re8