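PyTorch's ignite library is a high-level library that wraps training and evaluation code. Its objects and functions let us write the code that trains and tests a model much more concisely. Here is a concrete usage example: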
```python
import torch
from torch import nn
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# Net, get_data_loaders, train_batch_size, val_batch_size and log_interval are user-defined
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = nn.NLLLoss()

# Trainer engine: forward pass, loss, backward pass and optimizer step for each batch
trainer = create_supervised_trainer(model, optimizer, criterion)

# Metrics computed by the evaluator, keyed by name
val_metrics = {
    "accuracy": Accuracy(),
    "nll": Loss(criterion)
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics)


# Log the training loss every log_interval iterations
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))


# At the end of each epoch, evaluate on the training set
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))


# At the end of each epoch, evaluate on the validation set
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))


trainer.run(train_loader, max_epochs=100)
```
(1) First, define the basic components: model, dataloader, optimizer and loss_criterion.
(2) Use create_supervised_trainer to create the trainer engine, passing in model, optimizer and criterion.
(3) Use create_supervised_evaluator to create the evaluator engine, passing in model and metrics, where metrics is a dictionary that maps each metric's name to the metric object to compute.
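(4) @trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) registers log_training_loss as a handler that fires after every log_interval (an integer) iterations. The handler receives the engine that fired the event as its first argument, so whether that parameter is called trainer or engine makes no difference. The snippet

```python
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, engine.state.output))
```

is equivalent to

```python
def log_training_loss(engine):
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, engine.state.output))

# Keep the every= filter here as well; plain Events.ITERATION_COMPLETED would fire on every iteration
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=log_interval), log_training_loss)
```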
(5) Finally, trainer.run starts the trainer engine and trains for 100 epochs.
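To make the registration mechanics of step (4) concrete, here is a minimal pure-Python sketch of an event-driven engine. MiniEngine, its every parameter and the string event names are hypothetical simplifications, not ignite's actual classes. The point it illustrates is that the decorator form just calls add_event_handler, so the two registrations below behave identically as long as they carry the same every filter.

```python
from collections import defaultdict


class MiniEngine:
    """Hypothetical, stripped-down event-driven engine (not ignite's Engine)."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)  # event name -> [(handler, every), ...]
        self.iteration = 0
        self.output = None

    def add_event_handler(self, event, handler, every=1):
        # Keep the filter next to the handler; it is checked when the event fires
        self.handlers[event].append((handler, every))

    def on(self, event, every=1):
        # Decorator form: registers through add_event_handler, returns the function unchanged
        def decorator(handler):
            self.add_event_handler(event, handler, every)
            return handler
        return decorator

    def fire(self, event):
        for handler, every in self.handlers[event]:
            # Only call the handler on every `every`-th iteration
            if self.iteration % every == 0:
                handler(self)

    def run(self, data):
        for batch in data:
            self.iteration += 1
            self.output = self.process_function(self, batch)
            self.fire("iteration_completed")


engine = MiniEngine(lambda engine, batch: batch * 2)
calls_a, calls_b = [], []


@engine.on("iteration_completed", every=3)
def log_a(engine):
    calls_a.append(engine.iteration)


def log_b(engine):
    calls_b.append(engine.iteration)


engine.add_event_handler("iteration_completed", log_b, every=3)
engine.run(range(10))
print(calls_a, calls_b)  # [3, 6, 9] [3, 6, 9]
```

Dropping every=3 from only one of the two registrations would make that handler run on all 10 iterations, which is exactly the mistake to avoid when translating the decorator into an add_event_handler call.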
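Both factory functions above can be rewritten (and renamed) by following one rule: define an inner function def _update(engine, batch), keep its (engine, batch) signature unchanged while modifying the body as needed, and finally return Engine(_update), i.e. an Engine object constructed from _update. The version below expects a model that returns (score, feat) and a loss function that takes (score, feat, target), and it returns both the loss and the batch accuracy: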
```python
import torch
from torch import nn
from ignite.engine import Engine


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        # Use all visible GPUs when more than one is available
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        # Move the batch to the model's device whenever one was given
        if device:
            img = img.to(device)
            target = target.to(device)
        # The model returns classification scores and features; the loss uses both
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        loss.backward()
        optimizer.step()
        # Batch accuracy: fraction of samples whose arg-max score equals the target
        acc = (score.max(1)[1] == target).float().mean()
        # This tuple becomes engine.state.output
        return loss.item(), acc.item()

    return Engine(_update)
```
Original ignite implementation: https://pytorch.org/ignite/_modules/ignite/engine.html#create_supervised_evaluator
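The (engine, batch) rule follows from how an engine drives the process function: for each batch it calls the function with itself and the batch, then stores the return value as engine.state.output. The sketch below imitates that loop in plain Python; TinyEngine, add_iteration_handler and the toy _update are hypothetical stand-ins, not ignite code. Because the custom trainer above returns a (loss, acc) tuple, a handler such as log_training_loss from the first example would need to unpack engine.state.output before formatting it with {:.2f}.

```python
from types import SimpleNamespace


class TinyEngine:
    """Hypothetical minimal engine: calls process_function(engine, batch) per batch."""

    def __init__(self, process_function):
        self._process_function = process_function
        self._iteration_handlers = []
        self.state = SimpleNamespace(epoch=0, iteration=0, output=None)

    def add_iteration_handler(self, handler):
        self._iteration_handlers.append(handler)

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self.state.epoch += 1
            for batch in data:
                self.state.iteration += 1
                # Positional call with (engine, batch): the contract _update must follow
                self.state.output = self._process_function(self, batch)
                for handler in self._iteration_handlers:
                    handler(self)
        return self.state


def _update(engine, batch):
    # Toy stand-in for a training step; returns (loss, acc) like the custom trainer
    prediction, target = batch
    loss = abs(prediction - target)
    acc = 1.0 if prediction == target else 0.0
    return loss, acc


logs = []


def log_training_loss(engine):
    # state.output is a tuple here, so unpack it before formatting
    loss, acc = engine.state.output
    logs.append("Epoch[{}] Loss: {:.2f} Acc: {:.2f}".format(engine.state.epoch, loss, acc))


engine = TinyEngine(_update)
engine.add_iteration_handler(log_training_loss)
state = engine.run([(1.0, 1.0), (2.0, 0.5)], max_epochs=2)
print(logs[-1])  # Epoch[2] Loss: 1.50 Acc: 0.00
```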
```python
import torch
from torch import nn
from ignite.engine import Engine


def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            # Each batch carries the inputs plus the pids and camids labels
            data, pids, camids = batch
            if device:
                data = data.to(device)
            feat = model(data)
            # This tuple becomes engine.state.output and is what attached metrics receive
            return feat, pids, camids

    engine = Engine(_inference)
    # Bind each metric to the evaluator under its dictionary key
    for name, metric in metrics.items():
        metric.attach(engine, name)
    return engine
```
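The rule is the same as before: define an inner function def _inference(engine, batch), keep its (engine, batch) signature unchanged while modifying the body as needed, then build an Engine from it (Engine(_inference)), attach the metrics, and return that engine. Each metric passed through metric.attach() is an object from ignite.metrics; attach() binds it to the evaluator, the metric is computed while the evaluator runs, and its result can then be read from evaluator.state.metrics[name].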
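Roughly speaking, an ignite Metric is driven by the evaluator's events: attach() registers handlers that reset the metric when an epoch starts, update it with engine.state.output after every iteration, and write the result of compute() into engine.state.metrics[name] when the epoch completes. The sketch below mimics that lifecycle in plain Python; SketchEngine and SketchAccuracy are hypothetical and skip details such as output_transform. Since update() receives whatever _inference returns, a metric attached to the custom evaluator above has to accept its (feat, pids, camids) tuple.

```python
from collections import defaultdict
from types import SimpleNamespace


class SketchEngine:
    """Hypothetical minimal engine with epoch/iteration events (not ignite's Engine)."""

    def __init__(self, process_function):
        self._process_function = process_function
        self._handlers = defaultdict(list)
        self.state = SimpleNamespace(output=None, metrics={})

    def add_event_handler(self, event, handler, *args):
        self._handlers[event].append((handler, args))

    def _fire(self, event):
        for handler, args in self._handlers[event]:
            handler(self, *args)

    def run(self, data):
        self._fire("epoch_started")
        for batch in data:
            self.state.output = self._process_function(self, batch)
            self._fire("iteration_completed")
        self._fire("epoch_completed")
        return self.state


class SketchAccuracy:
    """Hypothetical accuracy metric following the reset/update/compute pattern."""

    def reset(self):
        self._correct = 0
        self._total = 0

    def update(self, output):
        y_pred, y = output
        self._correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self._total += len(y)

    def compute(self):
        return self._correct / self._total

    def attach(self, engine, name):
        # Reset at epoch start, accumulate after each iteration, publish at epoch end
        engine.add_event_handler("epoch_started", lambda e: self.reset())
        engine.add_event_handler("iteration_completed", lambda e: self.update(e.state.output))
        engine.add_event_handler("epoch_completed", self._completed, name)

    def _completed(self, engine, name):
        engine.state.metrics[name] = self.compute()


def _inference(engine, batch):
    # Toy "model": predicts class 1 for every sample
    inputs, targets = batch
    return [1] * len(inputs), targets


evaluator = SketchEngine(_inference)
SketchAccuracy().attach(evaluator, "accuracy")
state = evaluator.run([([0.1, 0.2], [1, 0]), ([0.3, 0.4], [1, 1])])
print(state.metrics["accuracy"])  # 0.75
```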
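The Events enum defines the events an engine can fire (a simplified excerpt; newer ignite releases define additional events):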
```python
from enum import Enum


class Events(Enum):
    EPOCH_STARTED = "epoch_started"              # fired when a new epoch starts
    EPOCH_COMPLETED = "epoch_completed"          # fired when an epoch ends
    STARTED = "started"                          # fired when the engine starts running
    COMPLETED = "completed"                      # fired when the run finishes
    ITERATION_STARTED = "iteration_started"      # fired when an iteration starts
    ITERATION_COMPLETED = "iteration_completed"  # fired when an iteration ends
    EXCEPTION_RAISED = "exception_raised"        # fired when an exception is raised during the run
```
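As a rough guide to when these events fire relative to each other, the sketch below traces a run loop in plain Python and records the event names (the enum values above) in order. trace_events and failing_step are hypothetical helpers that only mimic the ordering; ignite's real Engine does more, for example it supports early termination.

```python
def trace_events(data, max_epochs, process_function):
    """Hypothetical trace of the order in which a run loop fires the events above."""
    fired = ["started"]
    try:
        for _ in range(max_epochs):
            fired.append("epoch_started")
            for batch in data:
                fired.append("iteration_started")
                process_function(batch)
                fired.append("iteration_completed")
            fired.append("epoch_completed")
        fired.append("completed")
    except Exception:
        # An error inside the loop ends the run; no completion events follow
        fired.append("exception_raised")
    return fired


def failing_step(batch):
    # Toy step that fails on the second batch
    if batch == 2:
        raise ValueError("bad batch")
    return batch


normal = trace_events([1, 2], max_epochs=2, process_function=lambda batch: batch)
failed = trace_events([1, 2], max_epochs=2, process_function=failing_step)
print(failed)
# ['started', 'epoch_started', 'iteration_started', 'iteration_completed',
#  'iteration_started', 'exception_raised']
```

In a normal run, STARTED and COMPLETED each fire once, EPOCH_STARTED/EPOCH_COMPLETED once per epoch, and ITERATION_STARTED/ITERATION_COMPLETED once per batch within each epoch.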