Convnet在Windows下的使用

最简单的方法是找个能调试的python IDE,读取各个文件之后看里面的内容,这里有个我写的方法,仔细阅读之后就可以使用了。事先安装PIL包。有错写评论即可。

第93,94行是建立各个data_batch_1等等

最后一行建立batches.meta文件

这个code生成7个data_batch文件,分为两类,一类为nopeople,一类为exist_people。需要修改网络配置文件[fc10]层outputs个数为2。

14-1-9补充:这个程序最好在linux下使用,windows下生成的文件貌似在使用时会报错。


[python]  view plain copy print ?
  1. import os  
  2. import cPickle  
  3. import pickle  
  4. import numpy as np  
  5. from numpy import array, append  
  6. from PIL import Image  
  7. import Image  
  8.   
  9. def makeBatch (load_path, save_path, data_size):  
  10.     data = []  
  11.     filenames = []  
  12.     class_list = []  
  13.     class_file = file('train-origin-pics-labels.txt''rb').readlines()  
  14.     file_list = os.listdir(load_path)  
  15.     num_sq = save_path[len(save_path)-1]  
  16.     for item in  file_list:  
  17.         if item.endswith(".jpg"):  
  18.             picture_number = item[0:len(item)-4]  
  19.             picture_num = int(picture_number)  
  20.             class_picture = class_file[picture_num-1][10:11]  
  21.             if int(picture_num)%100 == 0:  
  22.                 print picture_number  
  23.             n = os.path.join(load_path, item)  
  24.             inputImage = Image.open(n)  
  25.             (width,height) = inputImage.size  
  26.             #if  width > height:  
  27.             #    newwidth = width/height*128  
  28.             #    small_image = inputImage.resize((newwidth, 128),Image.ANTIALIAS)  
  29.             #else:  
  30.             #    newheight = height/width*128  
  31.             #    small_image = inputImage.resize((128, newheight),Image.ANTIALIAS)  
  32.             small_image = inputImage.resize((data_size, data_size),Image.ANTIALIAS)  
  33.             try:  
  34.                 r, g, b = small_image.split()  
  35.                 reseqImage = list(r.getdata()) + list(g.getdata()) + list(b.getdata())  
  36.                 data.append(reseqImage)  
  37.                 filenames.append(item)  
  38.                 class_list.append(class_picture)  
  39.             except:  
  40.                 print 'error' + picture_number  
  41.     data_array = np.array(data, dtype = np.uint8)  
  42.     T_data = data_array.T  
  43.     out_file = file(save_path, 'wb')  
  44.     dic = {'batch_label':'batch ' + num_sq + ' of 6''data':T_data, 'labels':class_list, 'filenames':filenames}  
  45.     pickle.dump(dic, out_file)  
  46.     out_file.close()  
  47.   
  48. def read_batch(batch_path, data_size):  
  49.     in_file = open(batch_path, 'r+')  
  50.     xx = cPickle.load(in_file)  
  51.     in_file.close()  
  52.     T_datas = xx['data']  
  53.     datas = T_datas.T  
  54.     c = np.zeros((1, data_size*data_size*3), dtype=np.float32)  
  55.     i  = 0  
  56.     for data in datas:  
  57.         i += 1  
  58.         c = c + data  
  59.     return i, c  
  60.   
  61. def add_all(data_size, path):  
  62.     count = 0  
  63.     totalc = np.zeros((1, data_size*data_size*3), dtype=np.float32)  
  64.     for idx in range(17):  
  65.         print 'reading batch'+str(idx)  
  66.         path += '/data_batch_' + str(idx)  
  67.         curcount, curc = read_batch(path, data_size)  
  68.         count += curcount  
  69.         totalc = totalc + curc  
  70.   
  71.     return count, totalc  
  72.   
  73. def write_data(data_size, path):  
  74.     cout, total = add_all(data_size)  
  75.     a  = []  
  76.     for i in range(0, len(total[0])):  
  77.         c = total[0][i] / cout  
  78.         a.append( [c])  
  79.     a_array = array(a, dtype = np.float32)  
  80.     return a_array  
  81.   
  82. def main(data_size, path):  
  83.     data_mean = write_data(data_size, path)  
  84.     label_names = ['nopeople''exist_people']  
  85.     num1 = 5000  
  86.     num2 = data_size*data_size*3  
  87.     dic = {'data_mean':data_mean, 'label_names':label_names, 'num_cases_per_batch':num1, 'num_vis':num2}  
  88.     out_file = open(path+'/batches.meta''w+')  
  89.     cPickle.dump(dic, out_file)  
  90.     out_file.close()  
  91.   
  92. data_size = 64  
  93. for i in range(17):  
  94.     makeBatch('./train-origin-pics-part'+str(i), 'baidu_data_size_'+str(data_size)+'/data_batch_'+str(i), data_size)  
  95. main(data_size, 'baidu_data_size_'+str(data_size))  

转载地址:http://blog.csdn.net/xuanwu_yan/article/details/16948385

你可能感兴趣的:(python)