PyTorch (5): engine

Overview

(figure: overview diagram)

engine

(figure: engine diagram)

  • inference.py: run at test / inference time
  • trainer.py: run when training the model

trainer.py

The full code is available via the link; to keep things simple, the variable assignments, the logging setup, and the SummaryWriter operations are omitted here:

1. Imports

import logging

import matplotlib.pyplot as plt
import numpy as np
import torch
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import Loss, RunningAverage
from tensorboardX import SummaryWriter

from data.transforms import build_untransform
from data.transforms.transforms import COLORMAP
from utils.metric import Label_Accuracy

plt.switch_backend('agg')

2. Define the do_train function

def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        loss_fn
):  # all of the arguments above are already-constructed objects

Creating the trainer engine

trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)

The trainer is created with this factory function. It drives the training of the model and abstracts away the usual "for index, data in enumerate(dataloader)" style loop:

def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              output_transform=lambda x, y, y_pred, loss: loss.item()):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.  # the default is fine
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.  # a callable
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

    Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
        of the processed batch by default.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):  # zero grads -> prepare batch -> forward -> loss -> backward -> step -> return loss
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)
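
For comparison, this is roughly the hand-written loop that the factory replaces for one epoch. It is only a sketch reusing the model, optimizer, loss_fn, train_loader and device objects from do_train; the real iteration happens inside Engine.run:

# rough hand-written equivalent of one training epoch
model.train()
for batch in train_loader:
    x, y = batch
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    # loss.item() is what create_supervised_trainer stores in engine.state.output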

Creating the evaluator
Here the metrics must be passed in as a dict:

evaluator = create_supervised_evaluator(model, metrics={'mean_iu': Label_Accuracy(cfg.MODEL.NUM_CLASSES),
                                                            'loss': Loss(loss_fn)}, device=device)
def _prepare_batch(batch, device=None, non_blocking=False):
    """Prepare batch for training: pass to a device with options.

    """
    x, y = batch
    return (convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking))


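create_supervised_evaluator follows the same pattern; roughly (a simplified sketch rather than the verbatim ignite source), it wraps a no-grad inference step in an Engine and attaches every metric from the dict:

def create_supervised_evaluator(model, metrics={}, device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred: (y_pred, y)):
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            y_pred = model(x)
            return output_transform(x, y, y_pred)  # (y_pred, y) is what the metrics consume

    engine = Engine(_inference)

    # each metric registers its own reset/update/compute callbacks on the engine
    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine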

Loss

Note the Loss() metric passed into create_supervised_evaluator:

class Loss(Metric):
    """
    Calculates the average loss according to the passed loss_fn.

    Args:
        loss_fn (callable): a callable taking a prediction tensor, a target
            tensor, optionally other arguments, and returns the average loss
            over all observations in the batch.
        output_transform (callable): a callable that is used to transform the
            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
            form expected by the metric.
            This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            The output is expected to be a tuple (prediction, target) or
            (prediction, target, kwargs) where kwargs is a dictionary of extra
            keywords arguments.
        batch_size (callable): a callable taking a target tensor that returns the
            first dimension size (usually the batch size).

    """

    def __init__(self, loss_fn, output_transform=lambda x: x,
                 batch_size=lambda x: x.shape[0]):
        super(Loss, self).__init__(output_transform)
        self._loss_fn = loss_fn
        self._batch_size = batch_size

    def reset(self):
        self._sum = 0
        self._num_examples = 0

    def update(self, output):
        if len(output) == 2:
            y_pred, y = output
            kwargs = {}
        else:
            y_pred, y, kwargs = output
        average_loss = self._loss_fn(y_pred, y, **kwargs)

        if len(average_loss.shape) != 0:
            raise ValueError('loss_fn did not return the average loss.')

        N = self._batch_size(y)
        self._sum += average_loss.item() * N
        self._num_examples += N

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError(
                'Loss must have at least one example before it can be computed.')
        return self._sum / self._num_examples
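
A minimal standalone check of how this metric accumulates (hypothetical tensors; any criterion that returns the batch-mean loss works):

import torch
from torch import nn
from ignite.metrics import Loss

loss_metric = Loss(nn.CrossEntropyLoss())
loss_metric.reset()

# two "batches" of logits/targets with different batch sizes
loss_metric.update((torch.randn(4, 10), torch.randint(0, 10, (4,))))
loss_metric.update((torch.randn(2, 10), torch.randint(0, 10, (2,))))

# example-weighted average over all 6 samples: _sum / _num_examples
print(loss_metric.compute())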

Handlers

1. ModelCheckpoint

The class used to periodically save checkpoint files:

 checkpointer = ModelCheckpoint(output_dir, 'fcn', checkpoint_period, n_saved=10, require_empty=False)

The class definition:

class ModelCheckpoint(object):
    """ ModelCheckpoint handler can be used to periodically save objects to disk.

    This handler expects two arguments:

        - an :class:`~ignite.engine.Engine` object
        - a `dict` mapping names (`str`) to objects that should be saved to disk.

    See Notes and Examples for further details.

    Args:
        dirname (str):
            Directory path where objects will be saved.
        filename_prefix (str):
            Prefix for the filenames to which objects will be saved. See Notes
            for more details.
        save_interval (int, optional):
            if not None, objects will be saved to disk every `save_interval` calls to the handler.
            Exactly one of (`save_interval`, `score_function`) arguments must be provided.
        score_function (callable, optional):
            if not None, it should be a function taking a single argument,
            an :class:`~ignite.engine.Engine` object,
            and return a score (`float`). Objects with highest scores will be retained.
            Exactly one of (`save_interval`, `score_function`) arguments must be provided.
        score_name (str, optional):
            if `score_function` not None, it is possible to store its absolute value using `score_name`. See Notes for
            more details.
        n_saved (int, optional):
            Number of objects that should be kept on disk. Older files will be removed.
        atomic (bool, optional):
            If True, objects are serialized to a temporary file,
            and then moved to final destination, so that files are
            guaranteed to not be damaged (for example if an exception occurs during saving).
        require_empty (bool, optional):
            If True, will raise exception if there are any files starting with `filename_prefix`
            in the directory 'dirname'.
        create_dir (bool, optional):
            If True, will create directory 'dirname' if it doesn't exist.
        save_as_state_dict (bool, optional):
            If True, will save only the `state_dict` of the objects specified, otherwise the whole object will be saved.

    Notes:
          This handler expects two arguments: an :class:`~ignite.engine.Engine` object and a `dict`
          mapping names to objects that should be saved.

          These names are used to specify filenames for saved objects.
          Each filename has the following structure:
          `{filename_prefix}_{name}_{step_number}.pth`.
          Here, `filename_prefix` is the argument passed to the constructor,
          `name` is the key in the aforementioned `dict`, and `step_number`
          is incremented by `1` with every call to the handler.

          If `score_function` is provided, user can store its absolute value using `score_name` in the filename.
          Each filename can have the following structure:
          `{filename_prefix}_{name}_{step_number}_{score_name}={abs(score_function_result)}.pth`.
          For example, `score_name="val_loss"` and `score_function` that returns `-loss` (as objects with highest scores
          will be retained), then saved models filenames will be `model_resnet_10_val_loss=0.1234.pth`.

    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import ModelCheckpoint
        >>> from torch import nn
        >>> trainer = Engine(lambda batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', save_interval=2, n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pth', 'myprefix_mymodel_6.pth']
    """

2. Timer

The class for timing:

  timer = Timer(average=True)

The class definition:

class Timer:
    """ Timer object can be used to measure (average) time between events.

    Args:
        average (bool, optional): if True, then when ``.value()`` method is called, the returned value
            will be equal to total time measured, divided by the value of internal counter.

    Attributes:
        total (float): total time elapsed when the Timer was running (in seconds).
        step_count (int): internal counter, useful to measure average time, e.g. of processing a single batch.
            Incremented with the ``.step()`` method.
        running (bool): flag indicating if timer is measuring time.

    Notes:
        When using ``Timer(average=True)`` do not forget to call ``timer.step()`` every time an event occurs. See
        the examples below.

    Examples:

        Measuring total time of the epoch:

        >>> from ignite.handlers import Timer
        >>> import time
        >>> work = lambda : time.sleep(0.1)
        >>> idle = lambda : time.sleep(0.1)
        >>> t = Timer(average=False)
        >>> for _ in range(10):
        ...    work()
        ...    idle()
        ...
        >>> t.value()
        2.003073937026784

        Measuring average time of the epoch:

        >>> t = Timer(average=True)
        >>> for _ in range(10):
        ...    work()
        ...    idle()
        ...    t.step()
        ...
        >>> t.value()
        0.2003182829997968

        Measuring average time it takes to execute a single ``work()`` call:

        >>> t = Timer(average=True)
        >>> for _ in range(10):
        ...    t.resume()
        ...    work()
        ...    t.pause()
        ...    idle()
        ...    t.step()
        ...
        >>> t.value()
        0.10016545779653825

        Using the Timer to measure average time it takes to process a single batch of examples:

        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import Timer
        >>> trainer = Engine(training_update_function)
        >>> timer = Timer(average=True)
        >>> timer.attach(trainer,
        ...              start=Events.EPOCH_STARTED,
        ...              resume=Events.ITERATION_STARTED,
        ...              pause=Events.ITERATION_COMPLETED,
        ...              step=Events.ITERATION_COMPLETED)
    """

3. RunningAverage

    # automatically adding handlers via a special `attach` method of `RunningAverage` handler
    # this behaves like an exponential moving average of the loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

Let's look at the class definition to see what attach does:

class RunningAverage(Metric):
    """Compute running average of a metric or the output of process function.

    Args:
        src (Metric or None): input source: an instance of :class:`~ignite.metrics.Metric` or None. The latter
            corresponds to `engine.state.output` which holds the output of process function.
        alpha (float, optional): running average decay factor, default 0.98
        output_transform (callable, optional): a function to use to transform the output if `src` is None and
            corresponds the output of process function. Otherwise it should be None.

    Examples:

    .. code-block:: python

        alpha = 0.98
        acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
        acc_metric.attach(trainer, 'running_avg_accuracy')

        avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
        avg_output.attach(trainer, 'running_avg_loss')

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_running_avg_metrics(engine):
            print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
            print("running avg loss:", engine.state.metrics['running_avg_loss'])

    """

    def __init__(self, src=None, alpha=0.98, output_transform=None):   # accepts a Metric object or None
        if not (isinstance(src, Metric) or src is None):
            raise TypeError("Argument src should be a Metric or None.")
        if not (0.0 < alpha <= 1.0):
            raise ValueError("Argument alpha should be a float between 0.0 and 1.0.")

        if isinstance(src, Metric):  # if src is a Metric, output_transform must be None
            if output_transform is not None:
                raise ValueError("Argument output_transform should be None if src is a Metric.")
            self.src = src
            self._get_src_value = self._get_metric_value
            self.iteration_completed = self._metric_iteration_completed
        else:
            if output_transform is None:
                raise ValueError("Argument output_transform should not be None if src corresponds "
                                 "to the output of process function.")
            self._get_src_value = self._get_output_value
            self.update = self._output_update

        self.alpha = alpha
        super(RunningAverage, self).__init__(output_transform=output_transform)

    def reset(self):
        self._value = None

    def update(self, output):
        # Implement abstract method
        pass

    def compute(self):
        if self._value is None:
            self._value = self._get_src_value()
        else:
            self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value()
        return self._value

    def attach(self, engine, name):
        # restart average every epoch
        engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        # compute metric
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
        # apply running average
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name)

    def _get_metric_value(self):
        return self.src.compute()

    def _get_output_value(self):
        return self.src

    def _metric_iteration_completed(self, engine):
        self.src.started(engine)
        self.src.iteration_completed(engine)

    def _output_update(self, output):
        self.src = output
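
To make the compute() recurrence concrete, here is the same exponential moving average applied to a few made-up batch losses (alpha = 0.98, the default used above):

# standalone illustration of RunningAverage.compute() on hypothetical losses
alpha = 0.98
batch_losses = [2.0, 1.5, 1.2, 1.0]

value = None
for loss in batch_losses:
    value = loss if value is None else value * alpha + (1.0 - alpha) * loss
    print(round(value, 4))
# 2.0 -> 1.99 -> 1.9742 -> 1.9547: the avg_loss shown in the logs reacts slowly to new batches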

Adding handlers to events

  1. Attach the checkpoint handler to the EPOCH_COMPLETED event:
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict()})

The Engine class's add_event_handler method:

 def add_event_handler(self, event_name, handler, *args, **kwargs):
        """Add an event handler to be executed when the specified event is fired.

        Args:
            event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events`
                or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`.
            handler (callable): the callable event handler that should be invoked
            *args: optional args to be passed to `handler`.
            **kwargs: optional keyword args to be passed to `handler`.

        Notes:
              The handler function's first argument will be `self`, the :class:`~ignite.engine.Engine` object it
              was bound to.

              Note that other arguments can be passed to the handler in addition to the `*args` and  `**kwargs`
              passed here, for example during :attr:`~ignite.engine.Events.EXCEPTION_RAISED`.

        Example usage:

        .. code-block:: python

            engine = Engine(process_function)

            def print_epoch(engine):
                print("Epoch: {}".format(engine.state.epoch))

            engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)

        """
        if event_name not in self._allowed_events:
            self._logger.error("attempt to add event handler to an invalid event %s.", event_name)
            raise ValueError("Event {} is not a valid event for this Engine.".format(event_name))

        event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else ()
        self._check_signature(handler, 'handler', *(event_args + args), **kwargs)

        self._event_handlers[event_name].append((handler, args, kwargs))
        self._logger.debug("added handler for event %s.", event_name)

As for the dict {'model': model.state_dict(), 'optimizer': optimizer.state_dict()} above: it is passed to add_event_handler as an extra positional argument, so when EPOCH_COMPLETED fires it is forwarded, together with the engine, to the checkpointer, whose __call__ method performs the actual saving (sketched below).
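
A conceptual sketch of that call (not the verbatim ignite internals; the file names follow the pattern described in the ModelCheckpoint notes above):

# what the engine effectively executes at every EPOCH_COMPLETED (conceptual)
checkpointer(trainer, {'model': model.state_dict(),
                       'optimizer': optimizer.state_dict()})
# ModelCheckpoint.__call__ writes each dict entry to
# {output_dir}/fcn_{name}_{step}.pth and prunes older files beyond n_saved=10
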
2. Attach the timer to the trainer

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

The Timer's attach method:

    def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
        """ Register callbacks to control the timer.

        Args:
            engine (Engine):
                Engine that this timer will be attached to.
            start (Events):
                Event which should start (reset) the timer.
            pause (Events):
                Event which should pause the timer.
            resume (Events, optional):
                Event which should resume the timer.
            step (Events, optional):
                Event which should call the `step` method of the counter.

        Returns:
            self (Timer)

        """

        engine.add_event_handler(start, self.reset)
        engine.add_event_handler(pause, self.pause)

        if resume is not None:
            engine.add_event_handler(resume, self.resume)

        if step is not None:
            engine.add_event_handler(step, self.step)

3. Register the logging handlers via the trainer.on decorator API:

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}"
                        .format(engine.state.epoch, iter, len(train_loader), engine.state.metrics['avg_loss']))
            writer.add_scalars("loss", {'train': engine.state.metrics['avg_loss']}, engine.state.iteration)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        mean_iu = metrics['mean_iu']
        avg_loss = metrics['loss']
        logger.info("Training Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}"
                    .format(engine.state.epoch, mean_iu, avg_loss))
        writer.add_scalars("mean_iu", {'train': mean_iu}, engine.state.epoch)

    if val_loader is not None:
        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            mean_iu = metrics['mean_iu']
            avg_loss = metrics['loss']
            logger.info("Validation Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}"
                        .format(engine.state.epoch, mean_iu, avg_loss)
                        )
            writer.add_scalars("loss", {'validation': avg_loss}, engine.state.iteration)
            writer.add_scalars("mean_iu", {'validation': mean_iu}, engine.state.epoch)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def plot_output(engine):
        model.eval()
        dataset = val_loader.dataset
        idx = np.random.choice(np.arange(len(dataset)), size=1).item()
        val_x, val_y = dataset[idx]
        val_x = val_x.to(device)
        with torch.no_grad():
            pred_y = model(val_x.unsqueeze(0))

        orig_img, val_y = untransform(val_x.cpu().data, val_y)
        pred_y = pred_y.max(1)[1].cpu().data[0].numpy()
        pred_val = cm[pred_y]
        seg_val = cm[val_y]

        # matplotlib
        fig = plt.figure(figsize=(9, 3))
        plt.subplot(131)
        plt.imshow(orig_img)
        plt.axis("off")

        plt.subplot(132)
        plt.imshow(seg_val)
        plt.axis("off")

        plt.subplot(133)
        plt.imshow(pred_val)
        plt.axis("off")
        writer.add_figure('show_result', fig, engine.state.iteration)

    trainer.run(train_loader, max_epochs=epochs)
    writer.close()
