H5文件读取:
import torch.utils.data as data
import torch
import h5py
class DatasetFromHdf5(data.Dataset):
def __init__(self, file_path):
super(DatasetFromHdf5, self).__init__()
hf = h5py.File(file_path)
self.data = hf.get('data')
self.target = hf.get('label')
def __getitem__(self, index):
return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
def __len__(self):
return self.data.shape[0]
调用的时候,先用DataLoader将数据装入 training_data_loader中
train_set = DatasetFromHdf5(r"D:\PycharmProjects\pytorch-vdsr-master\data\train.h5")
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
在使用数据训练的时候写一个循环,iteration只是一个计数的,从1开始计数,表示已经取第iteration个批次了,batch就是每次取出一个批次的数值。
input和target是取出的输入和希望得到的输出,这里的返回顺序是在上边的DatasetFromHdf5中定义的。
def __getitem__(self, index):
return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
所以batch[0]表示input(也就是存储的data),batch[1]表示label(也就是label)。
index在这里应该是每次按第一个维度取出data中的数值。data[index,:,:,:],本来是维度是1000×1×41×41,每次取的是1×1×41×41。按照batch来,每次取出的就是batch×1×41×41
for iteration, batch in enumerate(training_data_loader, 1):
input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)