活动地址:CSDN21天学习挑战赛
我们在转化数据集时经常会使用这个函数,他的所用是切分传入的 Tensor
的第一个维度,生成相应的 dataset
。
为什么要转换?
将python的list
和numpy数组
转换成tensorflow
的dataset
,才能被model.fit
函数训练
import tensorflow as tf
import numpy as np
x = np.random.uniform(size=(5, 3))
print(x)
print(type(x))
dataset = tf.data.Dataset.from_tensor_slices(x)
print(type(dataset))
for i in dataset:
print(i)
可以看到ndarray
类型的x被在第0维切分成了5个不同tensor
也就是5个相应的 dataset
import tensorflow as tf
import numpy as np
x = np.random.uniform(size=(5, 2))
print(x)
y = [1,2,3,4,5]
print(y)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
for i in dataset:
print(i)
可以看到:
x和y均在第0维被切分成了5个tensor,并且相应位置的元素在dataset中组成了一组。
这一点很重要,这样就可以实现特征 + 标签的dataset
dict_data = dict([('a', [11,22]), ('b', [33, 44]), ('c', [55, 66])])
print(dict_data)
dataset = tf.data.Dataset.from_tensor_slices(dict_data)
for i in dataset:
print(i)
运行结果:
{'a': [11, 22], 'b': [33, 44], 'c': [55, 66]}
{'a': <tf.Tensor: shape=(), dtype=int32, numpy=11>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=33>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=55>}
{'a': <tf.Tensor: shape=(), dtype=int32, numpy=22>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=44>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=66>}
由此可知,from_tensor_slices的大概使用方法,和支持的传入数据类型(元组)。