《如何制作类mnist的金融数据集》(完结篇)——4、如何使用生成的ubyte文件?

4、如何使用生成的ubyte文件?

        估计有很多同学拿到了ubyte文件后可能不知道怎么用,这里也简单说明一下吧。

       拿mnist数据集为例,通常来讲想使用mnist数据集时,会直接通过代码线上下载,然后使用dataloader去加载mnist数据集。如下述代码所示:

《如何制作类mnist的金融数据集》(完结篇)——4、如何使用生成的ubyte文件?_第1张图片

Dataloader的各个参数大家可以从torch官网中看到。

       然而当我数据集是已经下载好的线下数据集怎么办呢?其实就是自己创造dataloader。

import os
import numpy as np
import gzip
import torch.utils.data as Data
from torchvision import transforms

dataPath = 'E:/byte_creater'

def load_data(data_folder, data_name, label_name):
    """
        data_folder: 文件目录
        data_name: 数据文件名
        label_name:标签数据文件名
    """
    with open(os.path.join(data_folder, label_name), 'rb') as lbpath:  # rb表示的是读取二进制数据
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

    return (x_train, y_train)


class DealDataset(Data.Dataset):
    """
        读取数据、初始化数据
    """

    def __init__(self, folder, data_name, label_name, transform=None):
        (train_set, train_labels) = load_data(folder, data_name,
                                              label_name)  # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)


# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset(dataPath, "train-images-idx3-ubyte", "train-labels-idx1-ubyte",
                           transform=transforms.ToTensor())
testDataset = DealDataset(dataPath, "test-images-idx3-ubyte", "test-labels-idx1-ubyte",
                          transform=transforms.ToTensor())

# 训练数据和测试数据的装载
trainloader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=64,  # 一个批次可以认为是一个包,每个包中含有100张图片
    shuffle=True,
)

testloader = Data.DataLoader(
    dataset=testDataset,
    batch_size=64,
    shuffle=False,
)

那么通过上述代码便可以将ubyte文件形式的数据集加载到网络中了。

你可能感兴趣的:(制作类mnist金融数据集,人工智能,金融)