tf.data.Dataset
是一种高效好用的数据集载入工具. 使用它的 map
方法对数据进行处理也十分方便, 这个处理的函数最好是用 tf2 已经提供的 API 来实现.
对于一个用文件夹/文件名存储的数据集, 本人想对其文件路径进行解析(用正则库), 从而直接返回每个样本的数据和标签:
def load_data(file_name): # get tf.Tensor here
file_name = file_name.numpy().decode("utf8")
label = label_dict[pattern.search(file_name).group(0)]
data = np.loadtxt(file_name)[..., np.newaxis]
return tf.cast(data, tf.float32), tf.cast(label, tf.uint8)
ds_train = tf.data.Dataset.list_files("./data/dataset/train/*/*.txt", shuffle=True) \
.filter(lambda file_name: tf.strings.regex_full_match(file_name, source_list_regex)) \
.map(load_data, num_parallel_calls=64)
.batch(80) \
.prefetch(tf.data.experimental.AUTOTUNE) \
.cache()
可惜 tf2 没有内置的完整接口, 用 python 思维去如上去处理数据会有如下报错:
AttributeError: 'Tensor' object has no attribute 'numpy'
其主要原因是 map
方法默认情况下会把处理函数编译为静态图, 此时传入的参数类型为 tf1 里面的类型 tf.Tensor
, 而非 tf2 默认的 EagerTensor
, 并没有 numpy()
方法来拿到普通数据.
此时就需要把这个函数单独用 Eager 模式执行, 此处就可以使用 tf.py_function
或 tf.numpy_function
.
PS: 很多教程里说直接全局用 Eager 模式去解决, 虽然理论上可以 , 但是会影响效率, 还不如直接去 pytorch
参考 Tensorflow 官网 的文档, tf.py_function
的用法如下:
tf.py_function(
func, # 一个 Python 函数, 它的输入是 inp, 输出的类型为 Tout
inp, # func 的输入, 必须为一个列表或 Tuple
Tout, # func 的输出类型, 也必须为一个列表或元组
name=None
)
为了方便使用和更美观, 我包装成了一个装饰器 (类型标注部分可以删去, 不影响), 用法在最后面展示:
from typing import Union
import tensorflow as tf
def naive_function(input_signature: Union[tuple, list], type_out: Union[tuple, list]):
"""
Wraps a python function into a TensorFlow op that executes it eagerly.
Args:
input_signature: This is the input signature of the function. It is a list of TensorSpec objects.
type_out: The output types of the function.
Returns:
A decorator function
"""
if type_out is None:
raise TypeError("You must provide output types as a list or a tuple!")
def map_decorator(func):
@tf.function(input_signature=input_signature)
def wrapper(*args):
# Use a tf.py_function to prevent auto-graph from compiling the method
return tf.py_function(
func,
inp=args,
Tout=type_out
)
return wrapper
return map_decorator
同样地, 可以直接看官网例子使用或像我一样包装成装饰器:
from typing import Union
import tensorflow as tf
def np_function(input_signature: Union[tuple, list], type_out: Union[tuple, list]):
"""
It takes a function and returns a function that takes a tensor as input and returns a tensor as output
Args:
input_signature: This is the input signature list of the function. It is a list of TensorSpec objects.
type_out: The output types of the function.
Returns:
A decorator that takes a function and returns a wrapper function.
"""
if type_out is None:
raise TypeError("You must provide output types as a list or a tuple!")
def map_decorator(func):
@tf.function(input_signature=input_signature)
def wrapper(*args):
return tf.numpy_function(func, inp=args, Tout=type_out)
return wrapper
return map_decorator
将上面任意一种装饰器复制到你的个人工具包中并 import 之后(也可以直接复制到代码区, 但不美观), 可以像以下案例来使用, 来避免最上面所展示的 load_data()
报错.
只要给装饰器提供输入参数的类型(tf.TensorSpec
对象)和输出参数的类型即可(都要放入列表!).
装饰器作用后, file_name
已经是一个 Eager tensor, 最上面的函数几乎不需要改动就能用. 稳定起见, 函数最后来一个 tf.cast()
强调一下数据类型.
from mytools import naive_function
@naive_function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)], type_out=[tf.float32, tf.uint8])
def load_data(file_name): # get eager tensors here!!!
"""
It takes a file name, extracts the label from the file name, loads the data from the file, and returns the data
and label
Args:
file_name: The name of the file to load,whose type is eager tensor not tf.Tensor!!!
Returns:
The data and the label
Examples for test
--------
for a in ds_train.take(1):
print(a[1].numpy())
"""
file_name = file_name.numpy().decode("utf8") # .numpy() 取出数据
label = label_dict[pattern.search(file_name).group(0)]
data = np.loadtxt(file_name)[..., np.newaxis]
return tf.cast(data, tf.float32), tf.cast(label, tf.uint8)
基本同上,只有一处不同: 装饰器作用后, file_name
已经是一个 numpy 类型数据.
from mytools import np_function
@np_function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)], type_out=[tf.float32, tf.uint8])
def load_data(file_name): # get numpy types here
file_name = file_name.decode("utf8") # 直接使用数据
label = label_dict[pattern.search(file_name).group(0)]
data = np.loadtxt(file_name)[..., np.newaxis]
return tf.cast(data, tf.float32), tf.cast(label, tf.uint8)
这两种方式是 Eager 模式执行, 所以在序列化, 多线程和效率等方面有影响, 具体可以参考 官网 说明, 直接搜索 API 名字即可.
除了类型标注部分可以删去, 装饰器里的 wrapper()
上面的 @tf.function(input_signature)
(应该)也可以删除, 这样的话使用装饰器时只提供 type_out
即可. 这些修改很简单, 不再赘述.