This library allows reading and writing TFRecord files efficiently in Python. It also provides an IterableDataset reader of TFRecord files for PyTorch. Currently uncompressed and gzip-compressed TFRecords are supported.
pip3 install tfrecord
It is recommended to create an index file for each TFRecord file. An index file must be provided when using multiple workers, otherwise the loader may return duplicate records.
python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
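The index can also be built from Python; a minimal sketch, assuming the create_index helper in tfrecord.tools.tfrecord2idx is the function behind the CLI above:

from tfrecord.tools.tfrecord2idx import create_index  # assumed helper behind the CLI

# write /tmp/data.index next to the record file
create_index("/tmp/data.tfrecord", "/tmp/data.index")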
Use TFRecordDataset to read TFRecord files in PyTorch.
import torch
from tfrecord.torch.dataset import TFRecordDataset
tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
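When using multiple workers, pass the index file so that each worker can read its own shard of the records; a sketch, assuming an index built as shown above at /tmp/data.index:

import torch
from tfrecord.torch.dataset import TFRecordDataset

dataset = TFRecordDataset("/tmp/data.tfrecord", "/tmp/data.index",
                          {"image": "byte", "label": "float"})
# without the index file, workers could return duplicate records
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)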
Use MultiTFRecordDataset to read multiple TFRecord files. This class samples from the given tfrecord files with the given probabilities.
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset
tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Infinite and finite PyTorch dataset
By default, MultiTFRecordDataset is infinite, meaning that it samples the data forever. You can make it finite by providing the appropriate flag:
dataset = MultiTFRecordDataset(..., infinite=False)
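With infinite=False the iterator stops after a single pass over the files, so the dataset can be consumed like any finite iterable; a sketch reusing the patterns, splits, and description defined above:

dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits,
                               description, infinite=False)
num_records = sum(1 for _ in dataset)  # exhausts after one pass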
Shuffling the data
Both TFRecordDataset and MultiTFRecordDataset automatically shuffle the data when you provide a queue size. Shuffling draws from a fixed-size buffer, so a larger queue gives a more thorough (but still approximate) shuffle:
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
Transforming input. You can optionally pass a function as the transform argument to perform post-processing of features before returning them. This can, for example, be used to decode images, normalize colors to a certain range, or pad variable-length sequences.
import tfrecord
import cv2
def decode_image(features):
    # decode the raw bytes into a BGR image
    features["image"] = cv2.imdecode(features["image"], -1)
    return features

description = {
    "image": "byte",
}

dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)
data = next(iter(dataset))
print(data)
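Use TFRecordWriter to write TFRecord files from Python; image_bytes, label, and index below stand in for your own data: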
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({
    "image": (image_bytes, "byte"),
    "label": (label, "float"),
    "index": (index, "int")
})
writer.close()
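Use tfrecord_loader to read TFRecord files in plain Python, without PyTorch: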
import tfrecord
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {
    "image": "byte",
    "label": "float",
    "index": "int"
})

for record in loader:
    print(record["label"])
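Gzip-compressed records are read the same way; a sketch, assuming a compression_type argument that accepts "gzip", per the compressed-TFRecord support mentioned at the top:

loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord.gz", None,
                                  {"image": "byte"},
                                  compression_type="gzip")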
SequenceExamples can be read and written using the same methods shown above with an extra argument (sequence_description for reading and sequence_datum for writing), which causes the respective read/write functions to treat the data as a SequenceExample.
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1], 'int')})
writer.close()
Reading a SequenceExample yields a tuple containing two elements:
import tfrecord
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None,
                                  context_description,
                                  sequence_description=sequence_description)

for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])
As described in the section on transforming input, you can pass a function as the transform argument to perform post-processing of features. This should be used in particular for sequence features, since these are variable-length sequences that need to be padded before batching.
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
PAD_WIDTH = 5
def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        # pad every sequence feature along the time axis up to PAD_WIDTH steps
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}

dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Alternatively, you can choose to implement a custom collate_fn to assemble the batch, for example to perform dynamic padding.
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    context, feats = zip(*batch)
    # gather each sequence feature across the batch, then pad to the longest
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}

dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
data = next(iter(loader))
print(data)
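Compared with the fixed-width transform above, padding inside collate_fn is dynamic: each batch is only padded to the length of its longest sequence, rather than to a global PAD_WIDTH.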