import torch
import torch.nn as nn
import numpy as np
from tensorboardX import SummaryWriter
import os
class Manager(object):
    """Drive training, evaluation, checkpointing and LR decay for a classifier.

    ``args`` is an options object. The attributes read here are
    ``test_filename``, ``train_filename``, ``batch_size``, ``logdir``, ``cuda``,
    ``total_steps``, ``save_model_path``, ``load_model_path`` and ``lr``.
    ``args.lr`` is also updated in place by :meth:`adjust_learning_rate`.
    """

    def __init__(self, model, args, get_loader):
        """
        Args:
            model: the ``nn.Module`` to train and evaluate.
            args: options object (see the class docstring).
            get_loader: callable ``(filename, batch_size, shuffle=...)`` returning
                an iterable of ``(inputs, targets)`` batches.
        """
        self.model = model
        self.args = args
        self.get_loader = get_loader
        self.test_loader = get_loader(args.test_filename, args.batch_size, shuffle=False)
        self.criterion = nn.CrossEntropyLoss()
        self.writer = SummaryWriter(args.logdir)

    def eval(self, step):
        """Evaluate on the test loader and log loss/accuracy to TensorBoard.

        Loss and accuracy are averaged per target element: each batch is
        weighted by ``targets.numel()``, so a smaller final batch does not skew
        the result. The model's train/eval mode is restored afterwards.

        Args:
            step: global training step, used as the TensorBoard x coordinate.

        Returns:
            Mean test accuracy in percent (``nan`` if the loader is empty).
        """
        print('Eval ...')
        was_training = self.model.training
        self.model.eval()
        if self.args.cuda:
            self.model.cuda()
        total_loss = 0.0
        total_acc = 0.0
        n_items = 0
        try:
            with torch.no_grad():
                for index, (inputs, targets) in enumerate(self.test_loader):
                    if self.args.cuda:
                        inputs, targets = inputs.cuda(), targets.cuda()
                    outputs = self.model(inputs)
                    batch_loss = self.criterion(outputs, targets)
                    batch_accuracy = calculate_corrects(outputs, targets)
                    if index % 20 == 0:
                        print('batch {}, loss: {:.3f}, accuracy: {:.3f}%'.format(index, batch_loss.item(), batch_accuracy.item()))
                    # Both values are per-batch means; scale them back to sums so
                    # the final average is per element rather than per batch.
                    count = targets.numel()
                    total_loss += batch_loss.item() * count
                    total_acc += batch_accuracy.item() * count
                    n_items += count
        finally:
            # Put the model back in whatever mode the caller had it in.
            self.model.train(was_training)
        if n_items:
            average_loss = total_loss / n_items
            average_acc = total_acc / n_items
        else:
            average_loss = average_acc = float('nan')
        self.writer.add_scalar('Test/Loss', average_loss, step)
        self.writer.add_scalar('Test/Acc', average_acc, step)
        print('step: {}, eval loss: {:.3f}, accuracy: {:.3f}'.format(step, average_loss, average_acc))
        return average_acc

    def train(self, optimizer, best_accuracy=0.0):
        """Train until ``args.total_steps`` optimizer steps have been taken.

        Running train loss/accuracy are logged every 50 steps. Every 200 steps
        the model is evaluated and, if test accuracy improved, checkpointed to
        ``args.save_model_path``. The learning rate is decayed at steps 1000,
        5000, 10000 and 50000.

        Args:
            optimizer: optimizer over ``self.model``'s parameters.
            best_accuracy: best test accuracy seen so far (e.g. the value
                returned by :meth:`load_model`). Only improvements over it are
                saved.

        Returns:
            The best test accuracy reached. Callers may ignore it.

        Raises:
            RuntimeError: if the training loader yields no batches. Without this
                check the loop would never terminate.
        """
        print('Train ...')
        self.model.train()
        if self.args.cuda:
            self.model.cuda()
        step = 0
        loss = []
        accuracy = []
        while step < self.args.total_steps:
            # The loader is rebuilt every epoch (presumably to reshuffle).
            train_loader = self.get_loader(self.args.train_filename, self.args.batch_size, shuffle=True)
            epoch_start = step
            for inputs, targets in train_loader:
                step += 1
                if self.args.cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                optimizer.zero_grad()
                outputs = self.model(inputs)
                batch_loss = self.criterion(outputs, targets)
                batch_loss.backward()
                optimizer.step()
                batch_accuracy = calculate_corrects(outputs, targets)
                loss.append(batch_loss.item())
                accuracy.append(batch_accuracy.item())
                if step % 50 == 0:
                    n_batch_loss = np.array(loss).mean()
                    n_batch_acc = np.array(accuracy).mean()
                    self.writer.add_scalar('Train/Loss', n_batch_loss, step)
                    self.writer.add_scalar('Train/Acc', n_batch_acc, step)
                    print('step: {}, train loss: {:.3f}, accuracy: {:.3f}%'.format(step, n_batch_loss, n_batch_acc))
                    loss.clear()
                    accuracy.clear()
                if step % 200 == 0:
                    eval_acc = self.eval(step)
                    if eval_acc > best_accuracy:
                        # Store the accuracy of the weights being saved, not the
                        # previous best, so a resumed run compares against it.
                        self.save_model(step, eval_acc, self.args.save_model_path)
                        print('best model so far {:.3f} -> {:.3f}'.format(best_accuracy, eval_acc))
                        best_accuracy = eval_acc
                if step in (1000, 5000, 10000, 50000):
                    self.adjust_learning_rate(optimizer)
                if step >= self.args.total_steps:
                    # Stop mid-epoch so training never overshoots total_steps.
                    break
            else:
                # Runs only when the epoch finished without hitting total_steps.
                print('-------- epoch end --------')
                if step == epoch_start:
                    raise RuntimeError('training loader yielded no batches')
        return best_accuracy

    def save_model(self, step, best_accuracy, model_path):
        """Write a checkpoint dict to ``model_path``.

        The checkpoint has keys ``step``, ``best_accuracy`` and ``state_dict``.
        When ``args.cuda`` is set, the model is moved to CPU for the save so the
        stored tensors are CPU tensors. It is always moved back, even if saving
        fails. Missing parent directories of a filesystem path are created.
        """
        if isinstance(model_path, (str, bytes, os.PathLike)):
            directory = os.path.dirname(model_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
        if self.args.cuda:
            self.model.cpu()
        try:
            print('Save model at {}'.format(model_path))
            ckpt = {
                'step': step,
                'best_accuracy': best_accuracy,
                'state_dict': self.model.state_dict()
            }
            torch.save(ckpt, model_path)
        finally:
            if self.args.cuda:
                self.model.cuda()

    def load_model(self):
        """Restore weights from ``args.load_model_path``, or initialise afresh.

        If ``args.load_model_path`` is unset or does not exist,
        :meth:`init_model` is called instead.

        Returns:
            The ``best_accuracy`` stored in the checkpoint, or ``0.0`` when no
            checkpoint was loaded.
        """
        best_accuracy = 0.0
        path = self.args.load_model_path
        if path and os.path.exists(path):
            # torch.load unpickles the file; only load checkpoints you trust.
            # map_location='cpu' lets the file load even without a GPU;
            # load_state_dict then copies into the model's existing tensors.
            ckpt = torch.load(path, map_location='cpu')
            step = ckpt['step']
            best_accuracy = ckpt['best_accuracy']
            state_dict = ckpt['state_dict']
            self.model.load_state_dict(state_dict)
            print('Load {} at {}, best accuracy: {:.3f}%'.format(path, step, best_accuracy))
        else:
            self.init_model()
        return best_accuracy

    def init_model(self):
        """Kaiming-normal init for Conv2d/Linear weights; zero their biases."""
        print('Init model ...')
        for module in self.model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                # Layers built with bias=False have module.bias set to None.
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        print('Init completed')

    def adjust_learning_rate(self, optimizer):
        """Decay the learning rate by 10x unless ``args.lr`` is already <= 1e-4.

        The new value is computed from ``args.lr`` (not read back from the
        optimizer). It is written to every param group of ``optimizer`` and
        stored back into ``args.lr``.
        """
        # NOTE(review): repeated *0.1 accumulates float error, so e.g. 0.1
        # decayed three times may compare slightly above 0.0001 and be decayed
        # once more.
        if self.args.lr > 0.0001:
            lr = self.args.lr * 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('adjust optimizer learning rate, {:.5f} -> {:.5f}'.format(self.args.lr, lr))
            self.args.lr = lr
        else:
            print('learning rate is minimal value: {:.5f}'.format(self.args.lr))
def calculate_corrects(outputs, targets):
    """Return the accuracy of ``outputs`` against ``targets``, in percent.

    Predictions are ``argmax`` over dim 1 of ``outputs``. The result is a 0-dim
    float tensor equal to 100 * (fraction of predictions equal to ``targets``).

    Raises:
        ValueError: if the leading (batch) dimensions of the inputs differ.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which a size-1 side would silently broadcast
    # against the other and produce a meaningless accuracy.
    if outputs.shape[0] != targets.shape[0]:
        raise ValueError('target and output do not match')
    preds = torch.argmax(outputs, dim=1)
    accuracy = (preds == targets).float().mean() * 100
    return accuracy