今天在写NCF代码的时候,发现网络上的代码有一种新的数据读取方式,这里将对应的片段剪出来给大家分享下。
NCF的文章参考:https://www.jianshu.com/p/6173dbde4f53
原始数据
我们的原始数据保存在npy文件中,是一个字典类型,有三个key,分别是user,item和label:
data = np.load('data/test_data.npy').item()
print(type(data))
#output
构建tf的Dataset
使用 tf.data.Dataset.from_tensor_slices方法,将我们的数据变成tensorflow的DataSet:
dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
#output
进一步,将我们的Dataset变成一个BatchDataset,这样的话,在迭代数据的时候,就可以一次返回一个batch大小的数据:
dataset = dataset.shuffle(1000).batch(100)
print(type(dataset))
#output
可以看到,我们在变成batch之前使用了一个shuffle对数据进行打乱,100表示buffersize,即每取1000个打乱一次。
此时dataset有两个属性,分别是output_shapes和output_types,我们将根据这两个属性来构造迭代器,用于迭代数据。
print(dataset.output_shapes)
print(dataset.output_types)
#output
{'user': TensorShape([Dimension(None)]), 'item': TensorShape([Dimension(None)]), 'label': TensorShape([Dimension(None)])}
{'user': tf.int32, 'item': tf.int32, 'label': tf.int32}
构造迭代器
我们使用上面提到的两个dataset的属性,并使用tf.data.Iterator.from_structure方法来构造一个迭代器:
iterator = tf.data.Iterator.from_structure(dataset.output_types,
dataset.output_shapes)
迭代器需要初始化:
sess.run(iterator.make_initializer(dataset))
此时,就可以使用get_next(),方法来源源不断的读取batch大小的数据了
def getBatch():
sample = iterator.get_next()
print(sample)
user = sample['user']
item = sample['item']
return user,item
使用迭代器的正确姿势
我们这里来计算返回的每个batch中,user和item的平均值:
users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)
迭代器iterator只能往前遍历,如果遍历完之后还调用get_next()的话,会报tf.errors.OutOfRangeError错误,因此需要使用try-catch。
try:
while True:
print(sess.run([usersum,itemsum]))
except tf.errors.OutOfRangeError:
print("outOfRange")
如果想要多次遍历数据的话,初始化外面包裹一层循环即可:
for i in range(2):
sess.run(iterator.make_initializer(dataset))
try:
while True:
print(sess.run([usersum,itemsum]))
except tf.errors.OutOfRangeError:
print("outOfRange")
完整代码
import numpy as np
import tensorflow as tf
data = np.load('data/test_data.npy').item()
print(type(data))
dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
dataset = dataset.shuffle(10000).batch(100)
print(type(dataset))
print(dataset.output_shapes)
print(dataset.output_types)
iterator = tf.data.Iterator.from_structure(dataset.output_types,
dataset.output_shapes)
print(type(iterator))
def getBatch():
sample = iterator.get_next()
print(sample)
user = sample['user']
item = sample['item']
return user,item
users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(2):
sess.run(iterator.make_initializer(dataset))
try:
while True:
print(sess.run([usersum,itemsum]))
except tf.errors.OutOfRangeError:
print("outOfRange")