Using the PyTorch DataLoader

When feeding data into a PyTorch model, you cannot simply define a placeholder and pass data in through feed_dict as in TensorFlow; the data has to go through a DataLoader instead.

With a DataLoader, you can control the minibatch size, whether the data is shuffled, and the number of CPU worker processes used for parallel loading, all through its constructor arguments; a quick sketch follows.
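For example, a minimal sketch using the built-in TensorDataset (the toy tensors here are made up purely for illustration):

import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data: 100 samples with 8 features each, plus an integer label.
x = torch.randn(100, 8)
y = torch.randint(0, 2, (100,))

loader = DataLoader(
    TensorDataset(x, y),
    batch_size=16,     # minibatch size
    shuffle=True,      # reshuffle at the start of every epoch
    num_workers=2)     # worker processes loading batches in parallel

for xb, yb in loader:
    print(xb.shape, yb.shape)   # torch.Size([16, 8]) torch.Size([16])
    break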

To use a DataLoader, you also need a Dataset, which organizes the data into a structure the DataLoader can understand.

A Dataset is made up of three parts: the __init__, __getitem__, and __len__ methods.

__getitem__ returns one training sample together with its label, and __len__ returns the number of samples. A minimal skeleton follows, and then the full example.
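As a bare-bones sketch of those three methods (ToyDataset and its fields are illustrative names, not part of the example below):

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    def __init__(self, data, labels):
        # store (or open handles to) whatever __getitem__ will need
        self.data = data
        self.labels = labels

    def __getitem__(self, idx):
        # return one (sample, label) pair for the given index
        return self.data[idx], self.labels[idx]

    def __len__(self):
        # number of samples; the DataLoader uses this to build its index list
        return len(self.data)

The fuller example below follows the same structure: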

 

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os

class dataset(Dataset):
    def __init__(self, data_dir, split_path, config, phase='train', split_comber=None):
        assert phase in ('train', 'val', 'test')
        self.phase = phase
        idcs = np.load(split_path)          # list of file ids for this split
        print("length of idcs is:", len(idcs))

        self.filenames = [os.path.join(data_dir, idx) for idx in idcs]

        # Load the label array that goes with each image file;
        # an all-zero label file is treated as "no annotation".
        labels = []
        for idx in idcs:
            name = idx.split('_')[0] + '_' + idx.split('_')[1]
            l = np.load(os.path.join(data_dir, name + '_label.npy'))
            if np.all(l == 0):
                l = np.array([])
            labels.append(l)
        self.labels = labels

    def __getitem__(self, idx):
        # The DataLoader passes in the sample index; with shuffle=True the
        # indices already arrive in random order, so there is no need to
        # re-sample a random index here.
        imgs = np.load(self.filenames[idx])
        imgs = imgs[np.newaxis, ...]         # add a channel dimension: (1, D, H, W)
        imgs = imgs.astype(np.float32)
        label = self.labels[idx]
        # label[0]: first annotation row (assumes a non-empty label
        # for the samples used in training)
        return torch.from_numpy(imgs), label[0]
     

    def __len__(self):
        return len(self.labels)
        
         

Usage looks like this:

    # Instantiate the Dataset defined above (in the original project the
    # class is called data.DataBowl3Detector), then wrap it in a DataLoader.
    train_dataset = dataset(
        datadir,
        'train.npy',
        config,
        phase='train')

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,   # CPU worker processes
        pin_memory=True)            # page-locked memory speeds up GPU transfer



    for epoch in range(start_epoch, args.epochs + 1):
        print('Epoch %d' % epoch)
        train(train_loader, net, loss, epoch, optimizer, get_lr, args.save_freq, save_dir)
        # val_loader is built the same way as train_loader, with phase='val'
        validate(val_loader, net, loss)
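train and validate themselves are not shown here. As a rough sketch (all names mirror the call above; the CUDA transfer with non_blocking=True is an assumption that pairs with pin_memory=True in the DataLoader), the training function might consume the loader like this:

def train(train_loader, net, loss, epoch, optimizer, get_lr, save_freq, save_dir):
    net.train()
    lr = get_lr(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    for i, (data, target) in enumerate(train_loader):
        # pin_memory=True in the DataLoader makes this copy asynchronous
        data = data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = net(data)
        batch_loss = loss(output, target)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # (checkpointing with save_freq/save_dir omitted for brevity)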

 
