Tensorflow入门系列(二)——读取csv文件代码详解

例子——波士顿房价预测CSV文件读取

下面介绍的方法均以这个数据集为例,完整代码参考教程

tf.train.string_input_producer()

Tensorflow对于数据的读取有三种方式:
1、一种是通过占位符的方式feeding,这种一般是通过PIL或Numpy接收数据,在来喂入神经网络。
2、一种是读取文件数据,适合大型数据集的使用。
3、最后一种是利用常量或变量存储数据,达到预加载的数据的效果,适用于数据量比较小。

  • string_input_producer(string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None):

这个函数是一个针对文件的tensor生成器
1、string_tensor 输入是文件名字符串的一个一维tensor
2、num_epochs 生成器可迭代次数
3、shuffle 每个epoch是否打乱顺序
4、seed 随机数种子,shuffle为True时使用
5、capacity 队列的容量
6、shared_name 不同的上下文环境(Session)中可以通过这个名字共享生成的tensor
7、cancel_op 取消队列的操作
返回:一个文件队列生成器f_queue

tf.TextLineReader()

文本阅读器这个类继承于ReaderBase,默认按行读取,比如csv格式文件。下面介绍它的常用方法

  • init(skip_header_lines=None, name=None)

TextLineReader类的初始化方法,用于创建阅读器的实例,参数skip_header_lines表示每次阅读文件要忽略的行。
返回:一个阅读器实例reader

  • read(queue, name=None)方法

读取器从队列中读取,返回key和value,每次返回一行数据,key表示读取文件名:行号value表示读取那一行的数据,不同列以逗号隔开,类型是字符串
queue是一个队列,常用tf.train.string_input_producer()生成的文件队列

import tensorflow as tf 
filename = r'boston_housing_data.csv'
filename_queue = tf.train.string_input_producer([filename])#创建文件名队列
reader = tf.TextLineReader(skip_header_lines=1)#建立阅读器,并且跳过第一行
key,value = reader.read(filename_queue) #从文件名队列中读取文件
with tf.Session() as sess:
    sess.run(tf.initialize_local_variables())
    tf.train.start_queue_runners()
    for _ in range(5):#读取5行数据
        s_key,s_value = sess.run([key,value])
        print(s_key,s_value)

b’/home/bd/wanglong/boston_housing_data.csv:2’ b’0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24’
b’/home/bd/wanglong/boston_housing_data.csv:3’ b’0.02731,0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6’
b’/home/bd/wanglong/boston_housing_data.csv:4’ b’0.02729,0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7’
b’/home/bd/wanglong/boston_housing_data.csv:5’ b’0.03237,0,2.18,0,0.458,6.998,45.8,6.0622,3,222,18.7,394.63,2.94,33.4’
b’/home/bd/wanglong/boston_housing_data.csv:6’ b’0.06905,0,2.18,0,0.458,7.147,54.2,6.0622,3,222,18.7,396.9,5.33,36.2’

tf.decode_csv()

该函数用于将csv记录转换为Tensor,每一列映射到一个张量。

  • decode_csv(records,
    record_defaults,
    field_delim=",",
    use_quote_delim=True,
    name=None,
    na_value="",
    select_cols=None)

1、records:字符串类型的张量。每个字符串都是csv中的记录/行,所有记录都应该具有相同的格式。
2、record_defaults 用于指定每一个样本的每一列的默认值和类型
3、fiels_delim 默认分隔符时是逗号

上面例子结果如下:
[0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98, 24.0]
[0.02731, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14, 21.6]
[0.02729, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03, 34.7]
[0.03237, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94, 33.4]
[0.06905, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33, 36.2]

tf.gather_nd()

该函数作用是从一个多维/tensor中取出某部分。

  • gather_nd(params, indices, name=None)

params是待取出的tensor
indices是取出的地方
返回的是一个numpy的ndarray类型的数据

上面例子结果如下:
[ 6.575 15.3 4.98 ]
[ 6.421 17.8 9.14 ]
[ 7.185 17.8 4.03 ]
[ 6.998 18.7 2.94 ]
[ 7.147 18.7 5.33 ]

    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']
    indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]
    indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]

tf.stack()

这是一个矩阵拼接的函数,与之相反的是unstack()函数。参考博客
该函数将 values 中的张量列表打包成一个张量,该张量比 values 中的每个张量都高一个秩,通过沿 axis 维度打包。

  • stack(values, axis=0, name=“stack”)

1、values是待拼接的矩阵,一般是一个列表,元素是各个张量
2、axis是按哪个维度进行拼接

tf.train.shuffle_batch()

作用:函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.

shuffle_batch(tensors_list, batch_size, capacity, min_after_dequeue,
       num_threads=1, seed=None, enqueue_many=False, shapes=None,
    allow_smaller_final_batch=False, shared_name=None, name=None)

tensors_list:张量列表
batch_size:一次读取的行数
capacity:队列的最大容量
min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别
num_threads:线程数量
seed:队列内随机乱序的种子值
allow_smaller_final_batch:为True时,若队列中没有足够的项目,则允许最终批次更小.(可选项)
shared_name:如果设置,则队列将在多个会话中以给定名称共享.(可选项)
name:操作的名称.(可选项)

data_fen = tf.gather_nd(data,[[5],[10],[12]])
feature = tf.stack(data_fen)
label = data[-1]
feature_batch,label_batch = tf.train.shuffle_batch([feature,label],batch_size=10,min_after_dequeue=100,capacity=200)

例子结果如下:打印feature_batch和label_batch
[[ 5.841 19.2 11.41 ]
[ 6.286 18.7 8.94 ]
[ 5.949 21. 8.26 ]
[ 5.95 21. 27.71 ]
[ 6.015 18.5 12.86 ]
[ 6.096 21. 10.26 ]
[ 6.674 21. 11.98 ]
[ 6.142 21. 18.72 ]
[ 6.456 19.7 6.73 ]
[ 6.121 18.5 8.44 ]]
(10, 3)

[20. 21.4 20.4 13.2 22.5 18.2 21. 15.2 22.2 22.2]
(10,)

tf.train.Coordinator()、tf.train.start_queue_runners()

TensorFlow的Session对象是支持多线程的,可以在同一个会话(Session)中创建多个线程,并行执行。在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。参考博客

TensorFlow提供了两个类来实现对Session中多线程的管理:tf.Coordinator和 tf.QueueRunner,这两个类往往一起使用。

使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象,用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。

QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程同时将多个tensor(训练数据)推送入文件名称队列中,具体执行函数是 tf.train.start_queue_runners

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in  range(5):
            feature,label = sess.run([feature_batch,label_batch])
            print(feature)
        coord.request_stop()
        coord.join(threads)

使用coord.request_stop()来发出终止所有线程的命令,使用coord.join(threads)把线程加入主线程,等待threads结束。

tf.equal(x, y, name=None)

用于判断两个张量是否相等,相等返回True,不相等返回False

tf.where(condition, x=None, y=None,name=None)

condition:一个Tensor,数据类型为tf.bool类型
如果x、y均为空,那么返回的值为condition中值为True的位置

如果x、y不为空,那么x、y必须有相同的形状。如果x、y是标量,那么condition参数也必须是标量。如果x、y是向量,那么condition必须和x的第一维有相同的形状或者和x形状一致。

condition = tf.equal(data[13],tf.constant(24.0))
data_1 = tf.where(condition,tf.zeros(14),data)

True
()

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
(14,)

False
()

[2.7310e-02 0.0000e+00 7.0700e+00 0.0000e+00 4.6900e-01 6.4210e+00
7.8900e+01 4.9671e+00 2.0000e+00 2.4200e+02 1.7800e+01 3.9690e+02
9.1400e+00 2.1600e+01]
(14,)

你可能感兴趣的:(Tensorflow)