PyTorch训练模板

#! -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
from tensorboardX import SummaryWriter
import os


class Manager(object):
    """Drive training, evaluation and checkpointing of a classification model.

    The manager owns the test data loader, a cross-entropy criterion and a
    tensorboardX ``SummaryWriter``. The training loader is rebuilt through
    ``get_loader`` at the start of every epoch.

    Args:
        model: the ``nn.Module`` to train. It is called as ``model(inputs)``
            and its output is fed directly to ``nn.CrossEntropyLoss`` (so it
            should produce raw class scores, without a softmax layer).
        args: configuration object. Attributes read by this class:
            ``test_filename``, ``train_filename``, ``batch_size``, ``logdir``,
            ``cuda``, ``total_steps``, ``save_model_path``,
            ``load_model_path`` and ``lr`` (``lr`` is also written by
            ``adjust_learning_rate``).
        get_loader: callable ``get_loader(filename, batch_size, shuffle=...)``
            returning an iterable of ``(inputs, targets)`` batches.
    """

    # Print evaluation progress every this many test batches.
    EVAL_LOG_INTERVAL = 20
    # Log averaged training metrics every this many steps.
    TRAIN_LOG_INTERVAL = 50
    # Run an evaluation pass every this many training steps.
    EVAL_INTERVAL = 200
    # Training steps at which the learning rate is decayed.
    LR_DECAY_STEPS = (1000, 5000, 10000, 50000)

    def __init__(self, model, args, get_loader):
        self.model = model
        self.args = args
        self.get_loader = get_loader
        self.test_loader = get_loader(args.test_filename, args.batch_size, shuffle=False)

        self.criterion = nn.CrossEntropyLoss()
        self.writer = SummaryWriter(args.logdir)

    def eval(self, step):
        """Evaluate the model on the test loader and log the results.

        The reported loss and accuracy are the means of the per-batch values
        (not weighted by batch size). The model is switched back to training
        mode before returning.

        Args:
            step: training step used as the x-axis value for the logged
                scalars.

        Returns:
            float: mean per-batch accuracy, in percent.
        """
        print('Eval ...')
        self.model.eval()  # switch to evaluation mode
        if self.args.cuda:  # CPU -> GPU
            self.model.cuda()

        loss = []
        accuracy = []
        with torch.no_grad():  # no backward pass needed
            for index, (inputs, targets) in enumerate(self.test_loader):
                if self.args.cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()

                # Raw scores: CrossEntropyLoss applies the softmax itself.
                outputs = self.model(inputs)

                # By default this is the mean loss over the mini-batch.
                batch_loss = self.criterion(outputs, targets)

                batch_accuracy = calculate_corrects(outputs, targets)

                if index % self.EVAL_LOG_INTERVAL == 0:
                    print('batch {}, loss: {:.3f}, accuracy: {:.3f}%'.format(index, batch_loss.item(), batch_accuracy.item()))
                loss.append(batch_loss.item())
                accuracy.append(batch_accuracy.item())

        # Plain Python floats keep numpy scalars out of saved checkpoints.
        average_loss = float(np.array(loss).mean())
        average_acc = float(np.array(accuracy).mean())

        self.writer.add_scalar('Test/Loss', average_loss, step)
        self.writer.add_scalar('Test/Acc', average_acc, step)
        print('step: {}, eval loss: {:.3f}, accuracy: {:.3f}'.format(step, average_loss, average_acc))

        self.model.train()  # restore training mode for the caller
        return average_acc

    def train(self, optimizer, best_accuracy=0.0):
        """Run the training loop until at least ``args.total_steps`` steps.

        The step budget is only checked at the end of an epoch, so training
        always finishes the epoch in progress and may overshoot
        ``total_steps``. Averaged metrics are logged every
        ``TRAIN_LOG_INTERVAL`` steps, the model is evaluated every
        ``EVAL_INTERVAL`` steps (and saved when the evaluation accuracy
        improves), and the learning rate is decayed at ``LR_DECAY_STEPS``.

        Args:
            optimizer: optimizer updating the model parameters.
            best_accuracy: best evaluation accuracy seen so far; a checkpoint
                is written only when an evaluation beats it.

        Raises:
            ValueError: if the training loader yields no batches while steps
                are still required (the loop could otherwise never end).
        """
        print('Train ...')
        self.model.train()
        if self.args.cuda:
            self.model.cuda()

        step = 0
        loss = []
        accuracy = []
        while True:
            epoch_start_step = step
            # Rebuild the loader so every epoch gets a fresh shuffle.
            train_loader = self.get_loader(self.args.train_filename, self.args.batch_size, shuffle=True)
            for inputs, targets in train_loader:
                step += 1  # step counter
                if self.args.cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()

                optimizer.zero_grad()

                outputs = self.model(inputs)
                batch_loss = self.criterion(outputs, targets)
                batch_loss.backward()
                optimizer.step()

                batch_accuracy = calculate_corrects(outputs, targets)
                loss.append(batch_loss.item())
                accuracy.append(batch_accuracy.item())

                if step % self.TRAIN_LOG_INTERVAL == 0:
                    n_batch_loss = np.array(loss).mean()
                    n_batch_acc = np.array(accuracy).mean()

                    self.writer.add_scalar('Train/Loss', n_batch_loss, step)
                    self.writer.add_scalar('Train/Acc', n_batch_acc, step)
                    print('step: {}, train loss: {:.3f}, accuracy: {:.3f}%'.format(step, n_batch_loss, n_batch_acc))

                    # Start a fresh averaging window.
                    loss.clear()
                    accuracy.clear()

                if step % self.EVAL_INTERVAL == 0:
                    eval_acc = self.eval(step)
                    if eval_acc > best_accuracy:
                        # Store the NEW best score with the checkpoint so a
                        # later load_model() resumes from the right baseline.
                        self.save_model(step, eval_acc, self.args.save_model_path)
                        print('best model so far {:.3f} -> {:.3f}'.format(best_accuracy, eval_acc))
                        best_accuracy = eval_acc

                if step in self.LR_DECAY_STEPS:
                    self.adjust_learning_rate(optimizer)

            print('-------- epoch end --------')
            if step >= self.args.total_steps:
                break
            if step == epoch_start_step:
                # No batch was produced, so the step counter can never
                # reach total_steps.
                raise ValueError('train loader yielded no batches')

    def save_model(self, step, best_accuracy, model_path):
        """Write a checkpoint holding the step, best accuracy and weights.

        When ``args.cuda`` is set the model is moved to the CPU for saving
        and moved back to the GPU afterwards, even if saving fails.

        Args:
            step: training step at which the checkpoint is taken.
            best_accuracy: accuracy to record in the checkpoint; stored as a
                plain ``float``.
            model_path: destination passed to ``torch.save``.
        """
        if self.args.cuda:
            self.model.cpu()
        try:
            print('Save model at {}'.format(model_path))
            ckpt = {
                'step': step,
                # A built-in float avoids pickling numpy scalar types.
                'best_accuracy': float(best_accuracy),
                'state_dict': self.model.state_dict()
            }
            torch.save(ckpt, model_path)
        finally:
            if self.args.cuda:
                self.model.cuda()

    def load_model(self):
        """Restore weights from ``args.load_model_path`` or initialise them.

        If the checkpoint file exists its ``state_dict`` is loaded into the
        model; otherwise ``init_model`` is called.

        Returns:
            The best accuracy stored in the checkpoint, or ``0.0`` when no
            checkpoint was found.
        """
        best_accuracy = 0.0
        if os.path.exists(self.args.load_model_path):
            # Map to CPU so the file loads regardless of the device it was
            # saved from; load_state_dict copies into the model's tensors.
            ckpt = torch.load(self.args.load_model_path, map_location='cpu')
            step = ckpt['step']
            best_accuracy = ckpt['best_accuracy']
            state_dict = ckpt['state_dict']
            self.model.load_state_dict(state_dict)
            print('Load {} at {}, best accuracy: {:.3f}%'.format(self.args.load_model_path, step, best_accuracy))
        else:
            self.init_model()
        return best_accuracy

    def init_model(self):
        """Initialise every ``Conv2d`` and ``Linear`` layer of the model.

        Weights get Kaiming-normal initialisation and biases, when the layer
        has one, are zeroed. Other module types are left untouched.
        """
        print('Init model ...')
        for module in self.model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                # Layers built with bias=False have no bias tensor.
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        print('Init completed')

    def adjust_learning_rate(self, optimizer):
        """Divide the learning rate by 10 unless it is already at the floor.

        ``args.lr`` is treated as the current learning rate. While it is
        above ``0.0001`` it is multiplied by ``0.1``, written to every
        parameter group of ``optimizer`` and stored back in ``args.lr``;
        otherwise only a message is printed.

        Args:
            optimizer: optimizer whose ``param_groups`` are updated.
        """
        if self.args.lr > 0.0001:
            lr = self.args.lr * 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('adjust optimizer learning rate, {:.5f} -> {:.5f}'.format(self.args.lr, lr))
            self.args.lr = lr
        else:
            print('learning rate is minimal value: {:.5f}'.format(self.args.lr))


def calculate_corrects(outputs, targets):
    """Return the batch accuracy as a percentage.

    The predicted class of each row is the index of its largest value along
    dimension 1; the result is the fraction of rows whose prediction equals
    the corresponding target, multiplied by 100.

    Args:
        outputs: model outputs, one row of class scores per sample.
        targets: target class indices, one per sample.

    Returns:
        A scalar tensor holding the accuracy in percent.
    """
    n_outputs, n_targets = outputs.shape[0], targets.shape[0]
    assert n_outputs == n_targets, 'target and output do not match'
    predicted = outputs.argmax(dim=1)
    hits = (predicted == targets).float()
    return hits.mean() * 100

你可能感兴趣的:(深度学习-算法)