主要包含3个类:
1、Dataset
2、DataLoader
3、DataLoaderIter
3者关系:1被封装进2,2被封装进3
是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:
__getitm__() ##读取数据(如图像)
__len__() ##返回数据长度
def __getitem__(self, index):
img_path, label = self.data[index].img_path, self.data[index].label
img = Image.open(img_path)
return img, label
定义:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
主要参数:
1、dataset : 即上面自定义的dataset.
2、collate_fn: 这个函数用来打包batch,
3、num_worker: 非常简单的多线程方法, 设置为>=1, 则多线程预读数据.
定义成员变量,赋值给DataLoaderIter;
def __iter__(self):
return DataLoaderIter(self)
DataLoader是DataLoaderIter的框架
例子:
class CustomDataset(Dataset):
# 自定义自己的dataset
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
# training...
在for 循环里, 总共有三点操作:
在利用DataLoader加载数据时,一般:
先定义数据集,然后加载数据。如:
train_set = ****Dataset(root_dir=…, …) #自定义
train_loader = DataLoader(train_set, batch_size=…, pin_memory=True, num_workers=…, shuffle=True, collate_fn=…, drop_last=True)
if eval:
eval_set = …
eval_loader = DataLoader(eval_set, …)
logger.info 可输出信息
在train一个任务时,步骤:
optimizer = optim.Adam(model.parameters(), lr=..., weight_decay=...)
## optimizer = optim.Adam(model.parameters(), lr=..., weight_decay=..., momentum=...)
或者别的方法
5. 如果用多个GPU,可以:
model = nn.DataParallel(model)
model.cuda()
pure_model = model.module if isinstance(model, torch.nn.DataParallel) else model
checkpoint = torch.load(filename) ## 包含'epoch'、'it'、‘model_state’和‘optimizer_state’等信息)
model.load_state_dict(checkpoint['model_state']) ##此处的model应该为pure_model
optimizer.load_state_dict(checkpoint['optimizer_state'])