cs231 PyTorch Ignite 框架高级 API 应用案例:mnist_with_tensorboardx
mnist_with_tensorboardx.py
"""
MNIST example with training and validation monitoring using TensorboardX and Tensorboard.
Requirements:
TensorboardX (https://github.com/lanpa/tensorboard-pytorch): `pip install tensorboardX`
Tensorboard: `pip install tensorflow` (or just install tensorboard without the rest of tensorflow)
Usage:
Start tensorboard:
```bash
tensorboard --logdir=/tmp/tensorboard_logs/
```
Run the example:
```bash
python mnist_with_tensorboardx.py --log_dir=/tmp/tensorboard_logs
```
"""
from __future__ import print_function
from argparse import ArgumentParser
import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
try:
from tensorboardX import SummaryWriter
except ImportError:
raise RuntimeError("No tensorboardX package is found. Please install with the command: \npip install tensorboardX")
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
class Net(nn.Module):
    """Small convolutional classifier used by the MNIST example.

    Two 5x5 convolutions (each followed by 2x2 max-pooling and ReLU) feed two
    fully connected layers. The output is a log-probability per class.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Channel-wise dropout applied to the second convolution's output.
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Compute log-softmax scores over the 10 output units.

        Args:
            x: batch of single-channel images. The flatten to 320 features
                below matches 28x28 inputs (28 -> 24 -> 12 -> 8 -> 4, times
                20 channels).

        Returns:
            Tensor of shape (batch, 10) holding log-probabilities.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told explicitly whether we are training.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
def get_data_loaders(train_batch_size, val_batch_size):
    """Build the MNIST training and validation data loaders.

    Both datasets are stored under the current directory (``root="."``) and
    share the same ToTensor + Normalize transform.

    Args:
        train_batch_size: batch size of the (shuffled) training loader.
        val_batch_size: batch size of the (unshuffled) validation loader.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std.
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    train_set = MNIST(download=True, root=".", transform=data_transform, train=True)
    # download=True here as well, so the validation split does not depend on
    # the training dataset having been constructed first.
    val_set = MNIST(download=True, root=".", transform=data_transform, train=False)

    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader
def create_summary_writer(model, data_loader, log_dir):
    """Create a ``SummaryWriter`` and try to record the model graph in it.

    Recording the graph is best-effort: any failure is printed and the writer
    is still returned.

    Args:
        model: model passed to ``writer.add_graph``.
        data_loader: iterable yielding ``(input, target)`` batches; the input
            of its first batch is used as the example input for the graph.
        log_dir: directory handed to ``SummaryWriter``.

    Returns:
        The open writer; the caller is responsible for closing it.
    """
    writer = SummaryWriter(log_dir=log_dir)

    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        # An empty loader gives no example input; skip the graph instead of
        # letting StopIteration escape.
        print("Failed to save model graph: {}".format("data loader is empty"))
        return writer

    x, _ = first_batch
    try:
        writer.add_graph(model, x)
    except Exception as e:  # deliberate best-effort: graph export is optional
        print("Failed to save model graph: {}".format(e))
    return writer
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, log_dir):
    """Train ``Net`` on MNIST, reporting progress to stdout and TensorBoard.

    Args:
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        epochs: number of epochs passed to ``trainer.run`` as ``max_epochs``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_interval: the training loss is reported every ``log_interval``
            iterations within an epoch.
        log_dir: directory for the TensorBoard event files.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = create_summary_writer(model, train_loader, log_dir)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Move the model before building the optimizer so the optimizer holds the
    # on-device parameters; harmless if the trainer factory moves it again.
    model.to(device)

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    def _evaluate_and_log(loader, title, tag, epoch):
        # Run the evaluator over `loader`, then print and log accuracy / loss.
        evaluator.run(loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("{} Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(title, epoch, avg_accuracy, avg_nll))
        writer.add_scalar("{}/avg_loss".format(tag), avg_nll, epoch)
        writer.add_scalar("{}/avg_accuracy".format(tag), avg_accuracy, epoch)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based position of this iteration inside the current epoch.
        batch_idx = (engine.state.iteration - 1) % len(train_loader) + 1
        if batch_idx % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                  "".format(engine.state.epoch, batch_idx, len(train_loader), engine.state.output))
            writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        _evaluate_and_log(train_loader, "Training", "training", engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Tag spelled "validation" (the original wrote "valdation").
        _evaluate_and_log(val_loader, "Validation", "validation", engine.state.epoch)

    # kick everything off; always close the writer so buffered events are
    # flushed even if training raises.
    try:
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        writer.close()
if __name__ == "__main__":
    # Command-line entry point: parse the hyper-parameters and start training.
    parser = ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--val_batch_size', type=int, default=1000,
                        help='input batch size for validation (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='how many batches to wait before logging training status')
    parser.add_argument("--log_dir", type=str, default="tensorboard_logs",
                        help="log directory for Tensorboard log output")

    args = parser.parse_args()

    run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum,
        args.log_interval, args.log_dir)
运行结果:
Epoch[10] Iteration[740/938] Loss: 0.16
Epoch[10] Iteration[750/938] Loss: 0.25
Epoch[10] Iteration[760/938] Loss: 0.10
Epoch[10] Iteration[770/938] Loss: 0.10
Epoch[10] Iteration[780/938] Loss: 0.10
Epoch[10] Iteration[790/938] Loss: 0.29
Epoch[10] Iteration[800/938] Loss: 0.15
Epoch[10] Iteration[810/938] Loss: 0.21
Epoch[10] Iteration[820/938] Loss: 0.24
Epoch[10] Iteration[830/938] Loss: 0.39
Epoch[10] Iteration[840/938] Loss: 0.16
Epoch[10] Iteration[850/938] Loss: 0.10
Epoch[10] Iteration[860/938] Loss: 0.18
Epoch[10] Iteration[870/938] Loss: 0.42
Epoch[10] Iteration[880/938] Loss: 0.33
Epoch[10] Iteration[890/938] Loss: 0.17
Epoch[10] Iteration[900/938] Loss: 0.17
Epoch[10] Iteration[910/938] Loss: 0.28
Epoch[10] Iteration[920/938] Loss: 0.26
Epoch[10] Iteration[930/938] Loss: 0.10
Training Results - Epoch: 10 Avg accuracy: 0.98 Avg loss: 0.06
Validation Results - Epoch: 10 Avg accuracy: 0.98 Avg loss: 0.05
mnist_with_tensorboardx.py代码中使用了python的装饰器:
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Report the training loss every ``log_interval`` iterations of an epoch."""
    # 1-based position of this iteration inside the current epoch; named
    # batch_idx so the builtin iter() is not shadowed.
    batch_idx = (engine.state.iteration - 1) % len(train_loader) + 1
    if batch_idx % log_interval == 0:
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, batch_idx, len(train_loader), engine.state.output))
        writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
class Engine(object)的on方法,封装了一个装饰器函数:
def on(self, event_name, *args, **kwargs):
    """Decorator shortcut for add_event_handler.

    Args:
        event_name: An event to attach the handler to. Valid events are from
            :class:`ignite.engine.Events` or any `event_name` added by
            :meth:`register_events`.
        *args: optional args to be passed to `handler`
        **kwargs: optional keyword args to be passed to `handler`

    Returns:
        A decorator that registers the decorated function through
        ``self.add_event_handler`` and returns that function unchanged.
    """
    def decorator(f):
        self.add_event_handler(event_name, f, *args, **kwargs)
        # Return f itself so the decorated name still refers to the handler.
        return f
    return decorator
decorator函数内部调用self.add_event_handler,后者把(handler, args, kwargs)追加到字典self._event_handlers中该事件名对应的列表里。
def add_event_handler(self, event_name, handler, *args, **kwargs):
    """Add an event handler to be executed when the specified event is fired.

    Args:
        event_name: An event to attach the handler to. Valid events are from
            :class:`ignite.engine.Events` or any `event_name` added by
            :meth:`register_events`.
        handler (Callable): the callable event handler that should be invoked
        *args: optional args to be passed to `handler`
        **kwargs: optional keyword args to be passed to `handler`

    Raises:
        ValueError: if `event_name` is not in ``self._allowed_events``.

    Notes:
        The handler function's first argument will be `self`, the `Engine`
        object it was bound to.

        Note that other arguments can be passed to the handler in addition to
        the `*args` and `**kwargs` passed here, for example during
        `Events.EXCEPTION_RAISED`.

    Example usage:

    .. code-block:: python

        engine = Engine(process_function)

        def print_epoch(engine):
            print("Epoch: {}".format(engine.state.epoch))

        engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
    """
    if event_name not in self._allowed_events:
        self._logger.error("attempt to add event handler to an invalid event %s ", event_name)
        raise ValueError("Event {} is not a valid event for this Engine".format(event_name))

    # For EXCEPTION_RAISED a placeholder exception is included in the
    # signature check ahead of the user-supplied args.
    event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else ()
    self._check_signature(handler, 'handler', *(event_args + args), **kwargs)

    self._event_handlers[event_name].append((handler, args, kwargs))
    self._logger.debug("added handler for event %s ", event_name)
事件的触发:
def _fire_event(self, event_name, *event_args, **event_kwargs):
    """Execute all the handlers associated with given event.

    This method executes all handlers associated with the event
    `event_name`. Optional positional and keyword arguments can be used to
    pass arguments to **all** handlers added with this event. These
    arguments take precedence over the keyword arguments passed using
    `add_event_handler`, for this call only.

    Nothing happens when `event_name` is not in ``self._allowed_events``.

    Args:
        event_name: event for which the handlers should be executed. Valid
            events are from :class:`ignite.engine.Events` or any `event_name`
            added by :meth:`register_events`.
        *event_args: optional args to be passed to all handlers.
        **event_kwargs: optional keyword args to be passed to all handlers.
    """
    if event_name in self._allowed_events:
        self._logger.debug("firing handlers for event %s ", event_name)
        for func, args, kwargs in self._event_handlers[event_name]:
            # Merge into a copy: updating the stored dict in place would leak
            # this firing's keyword arguments into every later firing.
            call_kwargs = dict(kwargs)
            call_kwargs.update(event_kwargs)
            func(self, *(event_args + args), **call_kwargs)