pytorch学习笔记(一):ignite(训练模型的高级API)

例子

from argparse import ArgumentParser
from torch import nn
from torch.optim import SGD
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engines import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss

def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    """Train ``Net`` with an ignite trainer and validate after every epoch.

    Relies on ``torch``, ``F``, ``Net`` and ``get_data_loaders`` being
    available in the enclosing module; they are not defined in this snippet.

    Args:
        train_batch_size: batch size passed to ``get_data_loaders`` for the
            training loader.
        val_batch_size: batch size passed to ``get_data_loaders`` for the
            validation loader.
        epochs: forwarded to ``trainer.run`` as ``max_epochs``.
        lr: learning rate of the SGD optimizer.
        momentum: momentum of the SGD optimizer.
        log_interval: the training loss is printed whenever the position
            within the current epoch is a multiple of this value.
    """
    cuda = torch.cuda.is_available()
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)

    model = Net()
    if cuda:
        model = model.cuda()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, cuda=cuda)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            cuda=cuda)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # The iteration counter lives on engine.state, next to epoch and
        # output. The modulo folds it into a 1-based position within the
        # current epoch. Named batch_idx so the builtin iter() is not shadowed.
        batch_idx = (engine.state.iteration - 1) % len(train_loader) + 1
        if batch_idx % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, batch_idx, len(train_loader), engine.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Run the evaluator over the validation loader and read the metrics
        # registered above by name.
        metrics = evaluator.run(val_loader).metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))

    trainer.run(train_loader, max_epochs=epochs)
--------------------- 

流程

创建模型, 创建 Dataloader

创建 trainer

创建 evaluator

为一些事件注册函数, @trainer.on()

trainer.run()

Event

"""
类似枚举类, 定义了几个事件
"""
class Events(Enum):
    """Events that can be fired during a run; handlers are registered against them."""

    # Fired when a new epoch starts.
    EPOCH_STARTED = "epoch_started"
    # Fired when an epoch ends.
    EPOCH_COMPLETED = "epoch_completed"
    # Fired when training of the model begins.
    STARTED = "started"
    # Fired when training ends.
    COMPLETED = "completed"
    # Fired when an iteration starts.
    ITERATION_STARTED = "iteration_started"
    # Fired when an iteration ends.
    ITERATION_COMPLETED = "iteration_completed"
    # Fired when an exception occurs.
    EXCEPTION_RAISED = "exception_raised"
--------------------- 

State

class State(object):
    """Container for the progress of a run.

    Every keyword argument is stored as an attribute of the same name and
    may override the defaults set below.
    """

    def __init__(self, **kwargs):
        # Iteration counter.
        self.iteration = 0
        # Output of the current iteration; for a supervised trainer this is
        # the loss.
        self.output = None
        # Mini-batch of samples for the current iteration.
        self.batch = None
        # Any other state the caller wants recorded. Applied last so it
        # takes precedence over the defaults above.
        for name, value in kwargs.items():
            setattr(self, name, value)
--------------------- 

create_supervised_trainer

def create_supervised_trainer(model, optimizer, loss_fn, cuda=False):
    """
    Factory function for creating a trainer for supervised models.

    Only the signature and docstring are reproduced here; the function
    body is not shown in this excerpt.

    Args:
        model (torch.nn.Module): the model to train
        optimizer (torch.optim.Optimizer): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Trainer: a trainer instance with supervised update function
    """
--------------------- 

create_supervised_evaluator

def create_supervised_evaluator(model, metrics={}, cuda=False):
    """
    Factory function for creating an evaluator for supervised models.

    Only the signature and docstring are reproduced here; the function
    body is not shown in this excerpt.

    NOTE(review): the ``{}`` default for ``metrics`` is a single dict object
    shared by every call that omits the argument; confirm the implementation
    never mutates it.

    Args:
        model (torch.nn.Module): the model to evaluate
        metrics (dict of str: Metric): a map of metric names to Metrics
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Evaluator: an evaluator instance with supervised inference function
    """
--------------------- 

Trainer

# Trainer inherits from Engine.
def __init__(self, process_function):
    # process_function is the per-batch callable; its expected signature and
    # steps are described in the note that follows this stub.
    pass

"""
process_function 的 signature 是 func(batch)->anything
def func(batch): # batch会保存在 state.batch 中
    1. process batch
    2. forward computation
    3. compute loss
    4. compute gradient
    5. update parameters
    6. return loss or else # 返回的值会被保存在 state.output 中

"""


""" 为某事件注册函数, 当事件发生时, 此函数就会被调用
函数的 signature 必须是 def func(trainer), 状态可通过 trainer.state 访问
"""
@trainer.on(...)
def some_func(trainer):
    # Placeholder handler: registered for the event passed to trainer.on()
    # and called with the trainer as its only argument.
    pass

Trainer.run() # 训练模型
--------------------- 

Evaluator

# Evaluator inherits from Engine.
def __init__(self, process_function):
    # process_function is the per-batch callable; its expected signature and
    # steps are described in the note that follows this stub.
    pass

"""
process_function 的 signature 是 func(batch)->anything
def func(batch): # batch会保存在 state.batch 中
    1. process batch
    2. forward computation
    3. return something # 返回的值会被保存在 state.output 中,
    #  用来计算 Metric
"""



# Register a function for some of the evaluator's events.
@evaluator.on(...) 
def func(evaluator):
    # Placeholder handler: called with the evaluator as its only argument.
    pass

Evaluator.run() # 执行计算. 返回 state
state.metrics # 验证集上 metrics 计算的结果都保存在这里
--------------------- 

 

你可能感兴趣的:(Python)