Batch-reading CSV files with TensorFlow for deep learning algorithms

I currently use two deep learning frameworks, TensorFlow and Deeplearning4j. Deep learning algorithms generally require data to be fed in batches. Earlier posts already walked through several Deeplearning4j examples; TensorFlow can likewise read data in batches and keep feeding it to the algorithm. I recently came across an example online at http://www.cnblogs.com/hunttown/p/6844477.html. First, the data format is as follows; it is the Iris dataset, which anyone doing machine learning should recognize:



  Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
  21,5.4,3.4,1.7,0.2,Iris-setosa
  22,5.1,3.7,1.5,0.4,Iris-setosa
  23,4.6,3.6,1.0,0.2,Iris-setosa
  24,5.1,3.3,1.7,0.5,Iris-setosa
  25,4.8,3.4,1.9,0.2,Iris-setosa
  26,5.0,3.0,1.6,0.2,Iris-setosa
  27,5.0,3.4,1.6,0.4,Iris-setosa
  28,5.2,3.5,1.5,0.2,Iris-setosa
  29,5.2,3.4,1.4,0.2,Iris-setosa
  30,4.7,3.2,1.6,0.2,Iris-setosa
  31,4.8,3.1,1.6,0.2,Iris-setosa
  32,5.4,3.4,1.5,0.4,Iris-setosa
  33,5.2,4.1,1.5,0.1,Iris-setosa
  34,5.5,4.2,1.4,0.2,Iris-setosa
  35,4.9,3.1,1.5,0.1,Iris-setosa
  36,5.0,3.2,1.2,0.2,Iris-setosa
  37,5.5,3.5,1.3,0.2,Iris-setosa
  39,5.5,4.2,1.4,0.2,Iris-virginica
  40,4.9,3.1,1.5,0.1,Iris-versicolor
  38,5.0,3.2,1.2,0.2,Iris-versicolor
  51,5.5,3.5,1.3,0.2,Iris-versicolor
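
Each row holds an integer Id, four float features, and a string label. The code below parses rows like these with tf.decode_csv, where the record_defaults argument fixes each column's dtype. As a minimal standalone sketch of just that call (my own illustration, assuming TF 1.x; the sample row is the first data row above):

  import tensorflow as tf

  # record_defaults fixes each column's dtype: int for Id, float for the
  # four measurements, string for Species.
  row = tf.constant("21,5.4,3.4,1.7,0.2,Iris-setosa")
  columns = tf.decode_csv(row, record_defaults=[[0], [0.], [0.], [0.], [0.], ['']])

  with tf.Session() as sess:
      print(sess.run(columns))  # roughly: [21, 5.4, 3.4, 1.7, 0.2, b'Iris-setosa']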

Here is the implementation:

  import tensorflow as tf

  path = "/Users/shuubiasahi/Desktop/业务相关文档/iris.csv"

  def read_data(file_queue):
      # Read one line at a time, skipping the CSV header.
      reader = tf.TextLineReader(skip_header_lines=1)
      key, value = reader.read(file_queue)
      # record_defaults: int Id, four float features, string Species.
      defaults = [[0], [0.], [0.], [0.], [0.], ['']]
      Id, SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm, Species = tf.decode_csv(value, defaults)
      # Map the species string to an integer class label, -1 for anything unexpected.
      preprocess_op = tf.case({
          tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0),
          tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1),
          tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2),
      }, lambda: tf.constant(-1), exclusive=True)
      return tf.stack([SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm]), preprocess_op

  def create_pipeline(filename, batch_size, num_epochs=None):
      file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
      example, label = read_data(file_queue)
      # shuffle_batch keeps at least min_after_dequeue elements buffered
      # so that the shuffling is reasonably random.
      min_after_dequeue = 1000
      capacity = min_after_dequeue + batch_size
      example_batch, label_batch = tf.train.shuffle_batch(
          [example, label], batch_size=batch_size, capacity=capacity,
          min_after_dequeue=min_after_dequeue
      )
      return example_batch, label_batch

  x_train_batch, y_train_batch = create_pipeline(path, 5, num_epochs=1000)
  x_test, y_test = create_pipeline(path, 60)

  init_op = tf.global_variables_initializer()
  # Local variables include the epoch counter created by string_input_producer.
  local_init_op = tf.local_variables_initializer()

  with tf.Session() as sess:
      sess.run(init_op)
      sess.run(local_init_op)
      # Start populating the filename queue.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)
      try:
          # while not coord.should_stop():
          for _ in range(6):
              example, label = sess.run([x_test, y_test])
              print(example)
              print(label)
      except tf.errors.OutOfRangeError:
          print('Done reading')
      finally:
          coord.request_stop()
          coord.join(threads)
      # The with-block closes the session, so no explicit sess.close() is needed.
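
Note that string_input_producer, TextLineReader and shuffle_batch belong to the TF 1.x queue-based input pipeline, which was later deprecated in favor of tf.data. As a rough sketch of the same pipeline rewritten with tf.data (my own adaptation, assuming TF 1.4+; not from the original post):

  import tensorflow as tf

  def create_dataset(filename, batch_size, num_epochs=None):
      # Skip the header line, then parse each CSV row.
      dataset = tf.data.TextLineDataset(filename).skip(1)

      def parse(line):
          Id, sl, sw, pl, pw, species = tf.decode_csv(
              line, record_defaults=[[0], [0.], [0.], [0.], [0.], ['']])
          label = tf.case({
              tf.equal(species, 'Iris-setosa'): lambda: tf.constant(0),
              tf.equal(species, 'Iris-versicolor'): lambda: tf.constant(1),
              tf.equal(species, 'Iris-virginica'): lambda: tf.constant(2),
          }, default=lambda: tf.constant(-1), exclusive=True)
          return tf.stack([sl, sw, pl, pw]), label

      return (dataset.map(parse)
                     .shuffle(buffer_size=1000)
                     .repeat(num_epochs)
                     .batch(batch_size))

  # Usage: a one-shot iterator replaces the queue runners and coordinator.
  iterator = create_dataset(path, 5, num_epochs=1000).make_one_shot_iterator()
  features, labels = iterator.get_next()

With tf.data there is no Coordinator or queue-runner bookkeeping to manage: the iterator raises tf.errors.OutOfRangeError by itself once the epochs are exhausted.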
