from argparse import ArgumentParser
from torch import nn
from torch.optim import SGD
from torchvision.transforms import Compose, ToTensor, Normalize
from ignite.engines import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    """Train ``Net`` with SGD, printing training loss and validation metrics.

    Args:
        train_batch_size (int): passed to ``get_data_loaders`` for the training loader.
        val_batch_size (int): passed to ``get_data_loaders`` for the validation loader.
        epochs (int): number of epochs to train for (``max_epochs`` of the trainer).
        lr (float): SGD learning rate.
        momentum (float): SGD momentum.
        log_interval (int): print the training loss every ``log_interval``
            iterations within an epoch.
    """
    # The original snippet used `torch` and `F` without ever importing them.
    import torch
    import torch.nn.functional as F

    cuda = torch.cuda.is_available()
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    if cuda:
        model = model.cuda()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, cuda=cuda)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            cuda=cuda)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # The counter is kept on engine.state (not on the engine itself). The
        # modulo maps it back to a 1-based position inside the current epoch.
        iter_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter_in_epoch % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, iter_in_epoch, len(train_loader), engine.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Run a full pass over the validation set after every training epoch.
        metrics = evaluator.run(val_loader).metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))

    trainer.run(train_loader, max_epochs=epochs)
---------------------
使用 ignite 的基本流程 (即上面 run 函数的结构):
1. 创建模型, 创建 DataLoader
2. 创建 trainer
3. 创建 evaluator
4. 通过 @trainer.on(Events.XXX) 为需要的事件注册处理函数
5. 调用 trainer.run() 开始训练
"""
Events 是一个枚举类 (继承自 Enum), 定义了 Engine 运行过程中会触发的几个事件
"""
# NOTE(review): `Enum` is not imported in this excerpt; presumably
# `from enum import Enum` appears elsewhere in the module — confirm.
class Events(Enum):
    """Events fired by an engine while it runs.

    Handlers are registered for them with ``engine.on(<event>)``.
    """
    EPOCH_STARTED = "epoch_started"  # fired when a new epoch starts
    EPOCH_COMPLETED = "epoch_completed"  # fired when an epoch finishes
    STARTED = "started"  # fired when the run starts (for a trainer: when training begins)
    COMPLETED = "completed"  # fired when the run (e.g. training) finishes
    ITERATION_STARTED = "iteration_started"  # fired when an iteration starts
    ITERATION_COMPLETED = "iteration_completed"  # fired when an iteration finishes
    EXCEPTION_RAISED = "exception_raised"  # fired when an exception is raised
---------------------
class State(object):
    """Mutable bag of attributes describing the progress of an engine run.

    Every instance starts with ``iteration`` (0), ``output`` (None) and
    ``batch`` (None). Any keyword arguments are then set as extra attributes,
    in the order given. A keyword that reuses one of those three names
    overrides its starting value.
    """

    def __init__(self, **kwargs):
        # (attribute name, initial value) for the fields every run tracks:
        #   iteration - iteration counter
        #   output    - result of the current iteration (the loss, for a supervised trainer)
        #   batch     - the mini-batch handled in the current iteration
        builtin_fields = (
            ("iteration", 0),
            ("output", None),
            ("batch", None),
        )
        for attr_name, initial in builtin_fields:
            setattr(self, attr_name, initial)
        # Caller-supplied extras are applied last, so they win over the defaults.
        for attr_name, initial in kwargs.items():
            setattr(self, attr_name, initial)
---------------------
def create_supervised_trainer(model, optimizer, loss_fn, cuda=False):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (torch.nn.Module): the model to train
        optimizer (torch.optim.Optimizer): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Trainer: a trainer instance with supervised update function
    """
    # NOTE(review): the implementation is elided in this excerpt. As written,
    # the body is only this docstring, so the function returns None rather than
    # the documented Trainer.
---------------------
def create_supervised_evaluator(model, metrics={}, cuda=False):
    """
    Factory function for creating an evaluator for supervised models.

    Args:
        model (torch.nn.Module): the model to evaluate
        metrics (dict of str: Metric, optional): a map of metric names to Metrics
            (default: empty dict)
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Evaluator: an evaluator instance with supervised inference function
    """
    # NOTE(review): the implementation is elided in this excerpt. As written,
    # the body is only this docstring, so the function returns None rather than
    # the documented Evaluator.
    # NOTE(review): `metrics={}` is a mutable default shared by every call. It is
    # only safe while the body never mutates it; prefer `metrics=None` and build
    # a fresh dict inside.
---------------------
# Trainer: subclass of Engine (the class header itself is not shown in this excerpt).
def __init__(self, process_function):
    """Create a trainer around ``process_function``.

    ``process_function`` has the signature ``func(batch) -> anything``::

        def func(batch):  # the batch is also stored in state.batch
            # 1. process the batch
            # 2. forward computation
            # 3. compute the loss
            # 4. compute gradients
            # 5. update parameters
            # 6. return the loss (or anything else); the returned value
            #    is stored in state.output
    """


# Register a handler for an event; the handler is called whenever that event fires.
# The handler receives the engine as its only argument, i.e. ``def func(trainer)``.
# Run state such as the epoch, iteration and output is read from ``trainer.state``.
@trainer.on(...)
def some_func(trainer):
    pass


Trainer.run()  # train the model
---------------------
# Evaluator: subclass of Engine (the class header itself is not shown in this excerpt).
def __init__(self, process_function):
    """Create an evaluator around ``process_function``.

    ``process_function`` has the signature ``func(batch) -> anything``::

        def func(batch):  # the batch is also stored in state.batch
            # 1. process the batch
            # 2. forward computation
            # 3. return something; the returned value is stored in
            #    state.output and is what the Metrics are computed from
    """


# Register handlers for evaluator events; the handler receives the evaluator
# as its only argument.
@evaluator.on(...)
def func(evaluator):
    pass


Evaluator.run()  # run the evaluation; returns the state
state.metrics  # the metric values computed on the validation set are stored here
---------------------