Pytorch读取npy数据格式,编写dataset模块,可配合Dataloader进行使用

        在训练模型前,最重要的部分就是制作好数据集,有些情况下,由于图片数据过多,然后存储很不方便,我们就需要将数据制作成npy类型的数据格式。npy数据格式是一个四维的数组[N,H,W, C],其中N代表数据集的总数,H, W,C分别代表每一张图片对应的长、宽、以及通道数。

数据制作好之后,就是如何加载数据问题,TF中加载数据相对比较容易,但是Pytorch中,我们一般都是将数据制作成dataset,再传入Dataloader进行加载,因此就需要继承Dataset的类,然后编写读取npy的数据格式。Dataset中,我们需要定义三个函数。

一、__init__(self,data) 函数

主要是用来加载npy数据的,也可以加载数据预处理的函数,比如将数据转化为tensor之类的操作

 def __init__(self, data):
        self.data = np.load(data) #加载npy数据
        self.transforms = transform #转为tensor形式

二、__len__(self)函数

这个函数就是用来返回数据的总个数

 def __len__(self):
        return self.data.shape[0] #返回数据的总个数

三、 __getitem__(self,index)函数

这个是最要的函数,类似一个for循环,从头开始,每次读取一个保存在npy里面的数据,然后进行处理后,可以同时返回训练数据,以及对应的标签

    def __getitem__(self, index):
        hdct= self.data[index, :, :, :]  # 读取每一个npy的数据
        hdct = np.squeeze(hdct)  # 删掉一维的数据,就是把通道数这个维度删除
        ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声
        hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式
        ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式
        hdct= self.transforms(hdct)  #转为tensor形式
        ldct= self.transforms(ldct)  #转为tensor形式
        return ldct,hdct #返回数据还有标签

完整的代码如下:

import torch
import numpy as np
import skimage
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(1)  # reproducible

transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
])
'''NPY数据格式'''
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = np.load(data) #加载npy数据
        self.transforms = transform #转为tensor形式
    def __getitem__(self, index):
        hdct= self.data[index, :, :, :]  # 读取每一个npy的数据
        hdct = np.squeeze(hdct)  # 删掉一维的数据,就是把通道数这个维度删除
        ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声
        hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式
        ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式
        hdct= self.transforms(hdct)  #转为tensor形式
        ldct= self.transforms(ldct)  #转为tensor形式
        return ldct,hdct #返回数据还有标签
    def __len__(self):
        return self.data.shape[0] #返回数据的总个数

def main():
    dataset=MyDataset('.\data_npy\img_covid_poisson_glay_clean_BATCH_64_PATS_100.npy')
    data= DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)

if __name__ == '__main__':
	main()

 

你可能感兴趣的:(机器学习,图像处理,numpy,机器学习,Pytorch,npy数据格式,读取自己的数据集)