TFRecord reader and writer

This library allows reading and writing TFRecord files efficiently in Python. It also provides an IterableDataset reader of TFRecord files for PyTorch. Currently both uncompressed and gzip-compressed TFRecords are supported.

Installation

pip3 install tfrecord

Usage

It is recommended to create an index file for each TFRecord file. An index file must be provided when using multiple workers, otherwise the loader may return duplicate records.

python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
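The index is what lets each worker seek straight to its own slice of records. A minimal sketch of what an indexer has to do, assuming the standard TFRecord framing of `[length: uint64][crc(length): uint32][payload][crc(payload): uint32]`: walk the file and note each record's start offset and total span. The names `build_index` and `frame` are illustrative, not part of this library's API, and the CRC fields are left as zero placeholders.

```python
import io
import struct

def build_index(tfrecord_bytes):
    """Return (offset, span) pairs, one per record, by walking the TFRecord framing."""
    entries = []
    stream = io.BytesIO(tfrecord_bytes)
    while True:
        offset = stream.tell()
        header = stream.read(8)
        if len(header) < 8:
            break  # end of file
        length, = struct.unpack("<Q", header)
        # skip the length CRC, the payload, and the payload CRC
        stream.seek(4 + length + 4, io.SEEK_CUR)
        entries.append((offset, stream.tell() - offset))
    return entries

def frame(payload):
    """Wrap a payload in TFRecord framing (CRCs zeroed out for illustration)."""
    return struct.pack("<Q", len(payload)) + b"\x00" * 4 + payload + b"\x00" * 4

# A tiny fake file with two records: spans are 8 + 4 + len(payload) + 4 bytes.
data = frame(b"abc") + frame(b"defghij")
print(build_index(data))  # → [(0, 19), (19, 23)]
```

Given these offsets, a worker can seek directly to record *k* without parsing everything before it.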

Reading & Writing tf.train.Example

Reading tf.Example records in PyTorch

Use TFRecordDataset to read TFRecord files in PyTorch.

import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

Use MultiTFRecordDataset to read multiple TFRecord files. This class samples from the given tfrecord files with the given probabilities.
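The weighted sampling this performs can be sketched in plain Python: draw a dataset key according to the split probabilities, then take the next record from that dataset's iterator. The function name `sample_records` is illustrative; the real class streams from tfrecord files rather than in-memory lists.

```python
import itertools
import random

def sample_records(iterators, splits, rng, n):
    """Yield n records, choosing the source iterator with the given probabilities."""
    keys = list(splits)
    weights = [splits[k] for k in keys]
    for _ in range(n):
        key = rng.choices(keys, weights=weights)[0]
        yield next(iterators[key])

# Two infinite record streams standing in for two tfrecord files.
iterators = {
    "dataset1": itertools.cycle(["a1", "a2"]),
    "dataset2": itertools.cycle(["b1", "b2"]),
}
splits = {"dataset1": 0.8, "dataset2": 0.2}

rng = random.Random(0)
records = list(sample_records(iterators, splits, rng, 1000))
frac = sum(r.startswith("a") for r in records) / 1000
print(frac)  # close to 0.8
```

Note that the splits are sampling weights, so on any finite draw the observed proportions only approximate them.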

import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

Infinite and finite PyTorch dataset

By default, MultiTFRecordDataset is infinite, meaning that it samples the data forever. You can make it finite by providing the appropriate flag:

dataset = MultiTFRecordDataset(..., infinite=False)
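Conceptually, the infinite mode just restarts iteration over the underlying files once they are exhausted, while the finite mode makes a single pass. A rough stdlib sketch (the name `record_stream` is hypothetical):

```python
import itertools

def record_stream(files, infinite):
    """Yield records from each file in turn; loop forever when infinite."""
    source = itertools.cycle(files) if infinite else files
    for records in source:
        yield from records

files = [["r1", "r2"], ["r3"]]

# Finite: one pass over all files, then StopIteration.
print(list(record_stream(files, infinite=False)))  # → ['r1', 'r2', 'r3']

# Infinite: keeps cycling; take just the first 7 records.
print(list(itertools.islice(record_stream(files, infinite=True), 7)))
# → ['r1', 'r2', 'r3', 'r1', 'r2', 'r3', 'r1']
```

This is why the infinite default pairs naturally with step-based (rather than epoch-based) training loops.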

Shuffling the data

Both TFRecordDataset and MultiTFRecordDataset automatically shuffle the data when you provide a queue size:

dataset = TFRecordDataset(..., shuffle_queue_size=1024)
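Because the records are streamed, this is an approximate buffer shuffle rather than a full shuffle: a buffer of `shuffle_queue_size` records is kept, and each emitted record is drawn at random from it. A hedged stdlib sketch of the idea (the name `shuffle_stream` is illustrative, and details may differ from the library's internals):

```python
import random

def shuffle_stream(records, queue_size, rng):
    """Approximate shuffling with a fixed-size buffer, as done by streaming readers."""
    buffer = []
    for record in records:
        buffer.append(record)
        if len(buffer) >= queue_size:
            # emit a random element and keep the buffer full
            index = rng.randrange(len(buffer))
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    rng.shuffle(buffer)
    yield from buffer  # drain what is left at the end

rng = random.Random(0)
shuffled = list(shuffle_stream(range(10), queue_size=4, rng=rng))
print(shuffled)  # a permutation of 0..9
```

A larger queue gives better mixing at the cost of memory; records can only move within roughly `queue_size` positions of where they started.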

Transforming input data

You can optionally pass a function as the transform argument to perform post-processing of features before returning them. This can for example be used to decode images, normalize colors to a certain range, or pad variable-length sequences.

import tfrecord
import cv2

def decode_image(features):
    # get BGR image from bytes
    features["image"] = cv2.imdecode(features["image"], -1)
    return features


description = {
    "image": "byte",
}

dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)

data = next(iter(dataset))
print(data)

Writing tf.Example records in Python

import tfrecord

writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({
    "image": (image_bytes, "byte"),
    "label": (label, "float"),
    "index": (index, "int")
})
writer.close()

Reading tf.Example records in Python

import tfrecord

loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {
    "image": "byte",
    "label": "float",
    "index": "int"
})
for record in loader:
    print(record["label"])

Reading & Writing tf.train.SequenceExample

SequenceExamples can be read and written using the same methods shown above with an extra argument (sequence_description for reading and sequence_datum for writing), which causes the respective read/write functions to treat the data as a SequenceExample.

Writing SequenceExamples to file

import tfrecord

writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1], 'int')})
writer.close()

Reading SequenceExamples in Python

Reading from a SequenceExample yields a tuple containing two elements: the context features and the sequence features.

import tfrecord

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None,
                                  context_description,
                                  sequence_description=sequence_description)

for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])

Reading SequenceExamples in PyTorch

As described in the section on transforming input, one can pass a function as the transform argument to perform post-processing of features. This is especially useful for sequence features: these are variable-length sequences and need to be padded before batching.

import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

PAD_WIDTH = 5
def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)

Alternatively, you can choose to implement a custom collate_fn to assemble the batch, for example to perform dynamic padding:

import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    context, feats = zip(*batch)
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
data = next(iter(loader))
print(data)
