The full code is available at the link. For brevity, the variable assignments, logging setup, and SummaryWriter setup are omitted here (which is why names like device, output_dir, and epochs appear below without definitions):
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import Loss, RunningAverage
from tensorboardX import SummaryWriter
from data.transforms import build_untransform
from data.transforms.transforms import COLORMAP
from utils.metric import Label_Accuracy
plt.switch_backend('agg')
def do_train(
cfg,
model,
train_loader,
val_loader,
optimizer,
loss_fn
):  # all of the arguments above are instances created by the caller
Create the trainer engine
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
This function creates the trainer, which drives model training and abstracts away the familiar for index, data in enumerate(dataloader) style loop:
def create_supervised_trainer(model, optimizer, loss_fn,
device=None, non_blocking=False,
prepare_batch=_prepare_batch,
output_transform=lambda x, y, y_pred, loss: loss.item()):
"""
Factory function for creating a trainer for supervised models.
Args:
model (`torch.nn.Module`): the model to train.
optimizer (`torch.optim.Optimizer`): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to both model and batches.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.  # the default is fine here
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.  # i.e. a callable the user can override
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
    Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
of the processed batch by default.
Returns:
Engine: a trainer engine with supervised update function.
"""
if device:
model.to(device)
    def _update(engine, batch):  # zero grads -> prepare batch -> forward -> compute loss -> backward -> optimizer step -> return loss
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return output_transform(x, y, y_pred, loss)
return Engine(_update)
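A minimal standalone usage sketch of this factory (the toy model, data, and hyperparameters are illustrative only, not from this project):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(torch.randn(64, 10),
                                  torch.randint(0, 2, (64,))),
                    batch_size=16)

trainer = create_supervised_trainer(model, optimizer, loss_fn)
state = trainer.run(loader, max_epochs=2)   # state.output is the last loss.item()
print(state.epoch, state.output)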
Create the evaluator. Here the metrics must be passed in as a dict:
evaluator = create_supervised_evaluator(model, metrics={'mean_iu': Label_Accuracy(cfg.MODEL.NUM_CLASSES),
'loss': Loss(loss_fn)}, device=device)
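Both factory functions default prepare_batch to the helper below, which simply moves each tensor of the batch to the target device: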
def _prepare_batch(batch, device=None, non_blocking=False):
"""Prepare batch for training: pass to a device with options.
"""
x, y = batch
return (convert_tensor(x, device=device, non_blocking=non_blocking),
convert_tensor(y, device=device, non_blocking=non_blocking))
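The evaluator factory follows the same pattern; a condensed sketch of create_supervised_evaluator under the same ignite version (details may differ slightly from the installed source):

def create_supervised_evaluator(model, metrics={}, device=None, non_blocking=False,
                                prepare_batch=_prepare_batch):
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()                      # eval mode: freeze dropout / batchnorm stats
        with torch.no_grad():             # inference only, no gradient bookkeeping
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            y_pred = model(x)
            return y_pred, y              # metrics like Loss consume (y_pred, y)

    engine = Engine(_inference)
    for name, metric in metrics.items():  # each Metric registers its own event handlers
        metric.attach(engine, name)
    return engine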
Note the Loss() metric passed into create_supervised_evaluator:
class Loss(Metric):
"""
Calculates the average loss according to the passed loss_fn.
Args:
loss_fn (callable): a callable taking a prediction tensor, a target
tensor, optionally other arguments, and returns the average loss
over all observations in the batch.
output_transform (callable): a callable that is used to transform the
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric.
This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
            The output is expected to be a tuple (prediction, target) or
(prediction, target, kwargs) where kwargs is a dictionary of extra
keywords arguments.
batch_size (callable): a callable taking a target tensor that returns the
first dimension size (usually the batch size).
"""
def __init__(self, loss_fn, output_transform=lambda x: x,
batch_size=lambda x: x.shape[0]):
super(Loss, self).__init__(output_transform)
self._loss_fn = loss_fn
self._batch_size = batch_size
def reset(self):
self._sum = 0
self._num_examples = 0
def update(self, output):
if len(output) == 2:
y_pred, y = output
kwargs = {}
else:
y_pred, y, kwargs = output
average_loss = self._loss_fn(y_pred, y, **kwargs)
if len(average_loss.shape) != 0:
raise ValueError('loss_fn did not return the average loss.')
N = self._batch_size(y)
self._sum += average_loss.item() * N
self._num_examples += N
def compute(self):
if self._num_examples == 0:
raise NotComputableError(
'Loss must have at least one example before it can be computed.')
return self._sum / self._num_examples
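A quick standalone check of how Loss accumulates (the shapes and loss function below are illustrative, not from this project):

import torch
import torch.nn as nn
from ignite.metrics import Loss

metric = Loss(nn.CrossEntropyLoss())
metric.reset()
for _ in range(3):                  # pretend we saw three validation batches
    y_pred = torch.randn(4, 10)     # batch of 4 samples, 10 classes
    y = torch.randint(0, 10, (4,))
    metric.update((y_pred, y))
print(metric.compute())             # example-weighted average over all 12 samples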
The class used for saving checkpoint files:
checkpointer = ModelCheckpoint(output_dir, 'fcn', checkpoint_period, n_saved=10, require_empty=False)
The class definition:
class ModelCheckpoint(object):
""" ModelCheckpoint handler can be used to periodically save objects to disk.
This handler expects two arguments:
- an :class:`~ignite.engine.Engine` object
- a `dict` mapping names (`str`) to objects that should be saved to disk.
See Notes and Examples for further details.
Args:
dirname (str):
Directory path where objects will be saved.
filename_prefix (str):
Prefix for the filenames to which objects will be saved. See Notes
for more details.
save_interval (int, optional):
if not None, objects will be saved to disk every `save_interval` calls to the handler.
Exactly one of (`save_interval`, `score_function`) arguments must be provided.
score_function (callable, optional):
if not None, it should be a function taking a single argument,
an :class:`~ignite.engine.Engine` object,
and return a score (`float`). Objects with highest scores will be retained.
Exactly one of (`save_interval`, `score_function`) arguments must be provided.
score_name (str, optional):
if `score_function` not None, it is possible to store its absolute value using `score_name`. See Notes for
more details.
n_saved (int, optional):
Number of objects that should be kept on disk. Older files will be removed.
atomic (bool, optional):
If True, objects are serialized to a temporary file,
and then moved to final destination, so that files are
            guaranteed not to be damaged (for example if an exception occurs during saving).
require_empty (bool, optional):
If True, will raise exception if there are any files starting with `filename_prefix`
in the directory 'dirname'.
create_dir (bool, optional):
            If True, will create directory 'dirname' if it doesn't exist.
save_as_state_dict (bool, optional):
If True, will save only the `state_dict` of the objects specified, otherwise the whole object will be saved.
Notes:
This handler expects two arguments: an :class:`~ignite.engine.Engine` object and a `dict`
mapping names to objects that should be saved.
These names are used to specify filenames for saved objects.
Each filename has the following structure:
`{filename_prefix}_{name}_{step_number}.pth`.
Here, `filename_prefix` is the argument passed to the constructor,
`name` is the key in the aforementioned `dict`, and `step_number`
is incremented by `1` with every call to the handler.
If `score_function` is provided, user can store its absolute value using `score_name` in the filename.
Each filename can have the following structure:
`{filename_prefix}_{name}_{step_number}_{score_name}={abs(score_function_result)}.pth`.
For example, `score_name="val_loss"` and `score_function` that returns `-loss` (as objects with highest scores
will be retained), then saved models filenames will be `model_resnet_10_val_loss=0.1234.pth`.
Examples:
>>> import os
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import ModelCheckpoint
>>> from torch import nn
>>> trainer = Engine(lambda batch: None)
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', save_interval=2, n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})
>>> trainer.run([0], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_4.pth', 'myprefix_mymodel_6.pth']
"""
The class for time measurement:
timer = Timer(average=True)
The class definition:
class Timer:
""" Timer object can be used to measure (average) time between events.
Args:
average (bool, optional): if True, then when ``.value()`` method is called, the returned value
will be equal to total time measured, divided by the value of internal counter.
Attributes:
total (float): total time elapsed when the Timer was running (in seconds).
        step_count (int): internal counter, useful to measure average time, e.g. of processing a single batch.
Incremented with the ``.step()`` method.
running (bool): flag indicating if timer is measuring time.
Notes:
        When using ``Timer(average=True)`` do not forget to call ``timer.step()`` every time an event occurs. See
the examples below.
Examples:
Measuring total time of the epoch:
>>> from ignite.handlers import Timer
>>> import time
>>> work = lambda : time.sleep(0.1)
>>> idle = lambda : time.sleep(0.1)
>>> t = Timer(average=False)
>>> for _ in range(10):
... work()
... idle()
...
>>> t.value()
2.003073937026784
Measuring average time of the epoch:
>>> t = Timer(average=True)
>>> for _ in range(10):
... work()
... idle()
... t.step()
...
>>> t.value()
0.2003182829997968
Measuring average time it takes to execute a single ``work()`` call:
>>> t = Timer(average=True)
>>> for _ in range(10):
... t.resume()
... work()
... t.pause()
... idle()
... t.step()
...
>>> t.value()
0.10016545779653825
Using the Timer to measure average time it takes to process a single batch of examples:
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import Timer
>>> trainer = Engine(training_update_function)
>>> timer = Timer(average=True)
>>> timer.attach(trainer,
... start=Events.EPOCH_STARTED,
... resume=Events.ITERATION_STARTED,
... pause=Events.ITERATION_COMPLETED,
... step=Events.ITERATION_COMPLETED)
"""
# automatically adding handlers via a special `attach` method of `RunningAverage` handler
    # this behaves like an exponential moving average of the per-iteration loss
RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')
Let's look at the definition of RunningAverage and its attach method:
class RunningAverage(Metric):
"""Compute running average of a metric or the output of process function.
Args:
src (Metric or None): input source: an instance of :class:`~ignite.metrics.Metric` or None. The latter
corresponds to `engine.state.output` which holds the output of process function.
alpha (float, optional): running average decay factor, default 0.98
output_transform (callable, optional): a function to use to transform the output if `src` is None and
corresponds the output of process function. Otherwise it should be None.
Examples:
.. code-block:: python
alpha = 0.98
acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
acc_metric.attach(trainer, 'running_avg_accuracy')
avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
avg_output.attach(trainer, 'running_avg_loss')
@trainer.on(Events.ITERATION_COMPLETED)
def log_running_avg_metrics(engine):
print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
print("running avg loss:", engine.state.metrics['running_avg_loss'])
"""
    def __init__(self, src=None, alpha=0.98, output_transform=None):  # accepts a Metric instance or None
if not (isinstance(src, Metric) or src is None):
raise TypeError("Argument src should be a Metric or None.")
if not (0.0 < alpha <= 1.0):
raise ValueError("Argument alpha should be a float between 0.0 and 1.0.")
        if isinstance(src, Metric):  # if src is a Metric, output_transform must be None
if output_transform is not None:
raise ValueError("Argument output_transform should be None if src is a Metric.")
self.src = src
self._get_src_value = self._get_metric_value
self.iteration_completed = self._metric_iteration_completed
else:
if output_transform is None:
raise ValueError("Argument output_transform should not be None if src corresponds "
"to the output of process function.")
self._get_src_value = self._get_output_value
self.update = self._output_update
self.alpha = alpha
super(RunningAverage, self).__init__(output_transform=output_transform)
def reset(self):
self._value = None
def update(self, output):
# Implement abstract method
pass
def compute(self):
if self._value is None:
self._value = self._get_src_value()
else:
self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value()
return self._value
def attach(self, engine, name):
# restart average every epoch
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
# compute metric
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
# apply running average
engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name)
def _get_metric_value(self):
return self.src.compute()
def _get_output_value(self):
return self.src
def _metric_iteration_completed(self, engine):
self.src.started(engine)
self.src.iteration_completed(engine)
def _output_update(self, output):
self.src = output
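The compute() method above is a plain exponential moving average, seeded with the first observed value; a toy trace with made-up numbers:

# Toy trace of the recurrence in compute():
# v_t = alpha * v_{t-1} + (1 - alpha) * x_t, seeded with the first value
alpha = 0.98
v = None
for x in [1.0, 2.0, 3.0]:
    v = x if v is None else alpha * v + (1 - alpha) * x
print(round(v, 4))  # 1.0596 -- with alpha=0.98 the average moves slowly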
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
'optimizer': optimizer.state_dict()})
The add_event_handler method of the Engine class:
def add_event_handler(self, event_name, handler, *args, **kwargs):
"""Add an event handler to be executed when the specified event is fired.
Args:
event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events`
or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`.
handler (callable): the callable event handler that should be invoked
*args: optional args to be passed to `handler`.
**kwargs: optional keyword args to be passed to `handler`.
Notes:
The handler function's first argument will be `self`, the :class:`~ignite.engine.Engine` object it
was bound to.
Note that other arguments can be passed to the handler in addition to the `*args` and `**kwargs`
passed here, for example during :attr:`~ignite.engine.Events.EXCEPTION_RAISED`.
Example usage:
.. code-block:: python
engine = Engine(process_function)
def print_epoch(engine):
print("Epoch: {}".format(engine.state.epoch))
engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
"""
if event_name not in self._allowed_events:
self._logger.error("attempt to add event handler to an invalid event %s.", event_name)
raise ValueError("Event {} is not a valid event for this Engine.".format(event_name))
event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else ()
self._check_signature(handler, 'handler', *(event_args + args), **kwargs)
self._event_handlers[event_name].append((handler, args, kwargs))
self._logger.debug("added handler for event %s.", event_name)
The dict {'model': model.state_dict(), 'optimizer': optimizer.state_dict()} above is passed to add_event_handler as an extra positional argument. When EPOCH_COMPLETED fires, the engine invokes the checkpointer's __call__ with the engine itself and that dict, and ModelCheckpoint then serializes each entry to disk.
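From the body of add_event_handler above, handlers are stored as (handler, args, kwargs) tuples and the engine is prepended at call time; a toy mock of that dispatch (a paraphrase, not the ignite source):

# Toy mock of the event dispatch: stored (handler, args, kwargs) tuples are
# invoked with the engine as the first argument.
event_handlers = {'EPOCH_COMPLETED': []}        # stands in for self._event_handlers

def add_event_handler(event_name, handler, *args, **kwargs):
    event_handlers[event_name].append((handler, args, kwargs))

def fire_event(engine, event_name):
    for handler, args, kwargs in event_handlers[event_name]:
        handler(engine, *args, **kwargs)

add_event_handler('EPOCH_COMPLETED',
                  lambda eng, to_save: print('saving:', list(to_save)),
                  {'model': 'model state', 'optimizer': 'optimizer state'})
fire_event('trainer', 'EPOCH_COMPLETED')        # prints: saving: ['model', 'optimizer']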
2. Attach the timer to the trainer
# automatically adding handlers via a special `attach` method of `Timer` handler
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
The timer's attach member function:
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
""" Register callbacks to control the timer.
Args:
engine (Engine):
Engine that this timer will be attached to.
start (Events):
Event which should start (reset) the timer.
pause (Events):
Event which should pause the timer.
resume (Events, optional):
Event which should resume the timer.
step (Events, optional):
Event which should call the `step` method of the counter.
Returns:
self (Timer)
"""
engine.add_event_handler(start, self.reset)
engine.add_event_handler(pause, self.pause)
if resume is not None:
engine.add_event_handler(resume, self.resume)
if step is not None:
engine.add_event_handler(step, self.step)
Logging events:
# adding handlers using `trainer.on` decorator API
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
iter = (engine.state.iteration - 1) % len(train_loader) + 1
if iter % log_period == 0:
logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}"
.format(engine.state.epoch, iter, len(train_loader), engine.state.metrics['avg_loss']))
writer.add_scalars("loss", {'train': engine.state.metrics['avg_loss']}, engine.state.iteration)
# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
evaluator.run(train_loader)
metrics = evaluator.state.metrics
mean_iu = metrics['mean_iu']
avg_loss = metrics['loss']
logger.info("Training Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}"
.format(engine.state.epoch, mean_iu, avg_loss))
writer.add_scalars("mean_iu", {'train': mean_iu}, engine.state.epoch)
if val_loader is not None:
# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
evaluator.run(val_loader)
metrics = evaluator.state.metrics
mean_iu = metrics['mean_iu']
avg_loss = metrics['loss']
logger.info("Validation Results - Epoch: {} Mean IU: {:.3f} Avg Loss: {:.3f}"
.format(engine.state.epoch, mean_iu, avg_loss)
)
writer.add_scalars("loss", {'validation': avg_loss}, engine.state.iteration)
writer.add_scalars("mean_iu", {'validation': mean_iu}, engine.state.epoch)
# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def print_times(engine):
logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
.format(engine.state.epoch, timer.value() * timer.step_count,
train_loader.batch_size / timer.value()))
timer.reset()
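Note on the arithmetic: the timer was attached with average=True and step() fires on every ITERATION_COMPLETED, so timer.value() is the average seconds per iteration. Multiplying by timer.step_count therefore gives the total time of the epoch (despite the "Time per batch" wording in the message), while train_loader.batch_size / timer.value() is the throughput in samples per second.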
@trainer.on(Events.EPOCH_COMPLETED)
def plot_output(engine):
model.eval()
dataset = val_loader.dataset
idx = np.random.choice(np.arange(len(dataset)), size=1).item()
val_x, val_y = dataset[idx]
val_x = val_x.to(device)
with torch.no_grad():
pred_y = model(val_x.unsqueeze(0))
orig_img, val_y = untransform(val_x.cpu().data, val_y)
pred_y = pred_y.max(1)[1].cpu().data[0].numpy()
pred_val = cm[pred_y]
seg_val = cm[val_y]
# matplotlib
fig = plt.figure(figsize=(9, 3))
plt.subplot(131)
plt.imshow(orig_img)
plt.axis("off")
plt.subplot(132)
plt.imshow(seg_val)
plt.axis("off")
plt.subplot(133)
plt.imshow(pred_val)
plt.axis("off")
writer.add_figure('show_result', fig, engine.state.iteration)
trainer.run(train_loader, max_epochs=epochs)
writer.close()