PyTorch 之数据读取及任务训练

主要包含3个类:
1、Dataset
2、DataLoader
3、DataLoaderIter
3者关系:1被封装进2,2被封装进3

torch.utils.data.Dataset

是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

__getitm__()  ##读取数据(如图像)
__len__()  ##返回数据长度
def __getitem__(self, index): 
        img_path, label = self.data[index].img_path, self.data[index].label
        img = Image.open(img_path)
        return img, label

torch.utils.data.DataLoader

定义:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

主要参数:
1、dataset : 即上面自定义的dataset.
2、collate_fn: 这个函数用来打包batch,
3、num_worker: 非常简单的多线程方法, 设置为>=1, 则多线程预读数据.

定义成员变量,赋值给DataLoaderIter;

def __iter__(self):
        return DataLoaderIter(self)

torch.utils.data.DataLoaderIter

DataLoader是DataLoaderIter的框架
例子:

class CustomDataset(Dataset):
   # 自定义自己的dataset

dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)

for data in dataloader:
   # training...

在for 循环里, 总共有三点操作:

  1. 调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter
  2. 反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch. 中间还会涉及到shuffle , 以及sample 的方法等
  3. 当数据读完后, next()抛出一个StopIteration异常, for循环结束, dataloader 失效.

在利用DataLoader加载数据时,一般:
先定义数据集,然后加载数据。如:
train_set = ****Dataset(root_dir=…, …) #自定义
train_loader = DataLoader(train_set, batch_size=…, pin_memory=True, num_workers=…, shuffle=True, collate_fn=…, drop_last=True)
if eval:
eval_set = …
eval_loader = DataLoader(eval_set, …)

logger.info 可输出信息
在train一个任务时,步骤:

  1. 参数定义(包含一些超参数、路径设置、文件创建、logger定义)
  2. 加载数据(如利用DataLoader)
  3. 定义模型model
  4. 创建优化器:
optimizer = optim.Adam(model.parameters(), lr=..., weight_decay=...)
## optimizer = optim.Adam(model.parameters(), lr=..., weight_decay=..., momentum=...) 

或者别的方法
5. 如果用多个GPU,可以:

model = nn.DataParallel(model)
model.cuda()
  1. 如果可能(之前训练时保存的checkpoint),可以加载checkpoint:
pure_model = model.module if isinstance(model, torch.nn.DataParallel) else model
checkpoint = torch.load(filename)  ## 包含'epoch'、'it'、‘model_state’和‘optimizer_state’等信息)
model.load_state_dict(checkpoint['model_state'])   ##此处的model应该为pure_model
optimizer.load_state_dict(checkpoint['optimizer_state'])
  1. 创建scheduler(任务调度器), 总步数 total_steps= len(train_loader)*epochs.
  2. 训练

你可能感兴趣的:(PyTorch 之数据读取及任务训练)