import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
from torch.autograd import Variable
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.models import AlexNet
from torchvision import transforms
from cub200_dataset import CUB_200
import os
# Dataset locations (machine-specific; edit for your environment).
mnist_root = "/home/zwx/Works/January/pytorch_tutorial/data"
cifar_root = '/home/zwx/Works/January/data/cifar_data'
cub200_root = '/home/zwx/workspace/DATASETS/CUB_200_2011'
batch_size = 32

# Standard ImageNet channel statistics used by torchvision models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transforms_ = {
    # Training: random crop + horizontal flip for data augmentation.
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    # Testing: deterministic center crop, no augmentation.
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
}
# BUG FIX: the original passed transforms_['train'] to both splits, so the
# test set was evaluated with random crops and random flips. Each split now
# gets its own transform via transforms_[d].
datasets_cub200 = {d: CUB_200(root=cub200_root, train=(d == 'train'), transform=transforms_[d])
                   for d in ['train', 'test']}
# Shuffle only the training split; the test split is iterated in order.
dataloader = {d: DataLoader(dataset=datasets_cub200[d], batch_size=batch_size, shuffle=(d == 'train'), num_workers=4)
              for d in ['train', 'test']}
def main():
    """Finetune an AlexNet (200-way head) on CUB-200 and checkpoint progress."""
    args = {
        'epochs': 30,
        'lr': 1e-3,
        'weight_decay': 1e-6,
        'use_gpu': torch.cuda.is_available(),
        'train_data_loader': dataloader['train'],
        'test_data_loader': dataloader['test'],
        'save_prefix': '../checkpoints'
    }
    # BUG FIX: torch.save() does not create missing directories, so the first
    # checkpoint write would fail if '../checkpoints' did not already exist.
    os.makedirs(args['save_prefix'], exist_ok=True)
    model = AlexNet(num_classes=200)
    manager = Manager(args, model)
    optimizer = optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    manager.train(optimizer, savename=os.path.join(args['save_prefix'], 'model_{}'.format(args['epochs'])))
class Manager():
    """Drives training, evaluation and checkpointing of a classification model.

    ``args`` is a dict providing 'use_gpu', 'epochs', 'train_data_loader' and
    'test_data_loader'; ``model`` is any ``nn.Module`` classifier whose output
    is (batch, num_classes) scores.
    """

    def __init__(self, args, model):
        self.args = args
        self.model = model
        self.cuda = args['use_gpu']
        self.epochs = args['epochs']
        self.train_data_loader = args['train_data_loader']
        self.test_data_loader = args['test_data_loader']
        self.criterion = nn.CrossEntropyLoss()

    def eval(self):
        """Evaluate on the test loader; returns the list of top-k error rates."""
        self.model.eval()
        if self.cuda:
            self.model.cuda()
        error_meter = None
        print('Performing eval...')
        topk = [1]
        # BUG FIX: the original wrapped batches in Variable(volatile=False),
        # which builds the autograd graph during evaluation and wastes memory.
        # torch.no_grad() disables gradient tracking for the whole pass.
        with torch.no_grad():
            for batch, label in tqdm(self.test_data_loader, desc='Eval'):
                if self.cuda:
                    batch = batch.cuda()
                scores = self.model(batch)
                outputs = scores.data.view(-1, scores.size(1))
                label = label.view(-1)
                if error_meter is None:
                    # Also track top-2 error when there are more than 2 classes.
                    if outputs.size(1) > 2:
                        topk.append(2)
                    error_meter = tnt.meter.ClassErrorMeter(topk=topk)
                error_meter.add(outputs, label)
        error = error_meter.value()
        print(', '.join('@%s=%.2f' % t for t in zip(topk, error)))
        return error

    def do_epoch(self, epoch_idx, dataloader, optimizer):
        """Run one full optimization pass over ``dataloader``."""
        for batch, label in tqdm(dataloader, desc='Epoch: {} '.format(epoch_idx)):
            self.do_batch(batch, label, optimizer)

    def do_batch(self, batch, label, optimizer):
        """Forward, backward and one optimizer step on a single mini-batch."""
        if self.cuda:
            batch = batch.cuda()
            label = label.cuda()
        batch, label = Variable(batch), Variable(label)
        # Clear gradients through the optimizer so exactly the optimized
        # parameters are zeroed (equivalent here, but the idiomatic pairing).
        optimizer.zero_grad()
        scores = self.model(batch)
        loss = self.criterion(scores, label)
        loss.backward()
        optimizer.step()

    def train(self, optimizer, savename='', best_accuracy=0.0):
        """Train for ``self.epochs`` epochs, checkpointing on every improvement.

        ``best_accuracy`` seeds the best-so-far value (e.g. when resuming).
        """
        print('Performing training...')
        if self.cuda:
            self.model.cuda()
        for i in range(self.epochs):
            epoch_idx = i + 1
            print('Epoch: {}'.format(epoch_idx))
            # BUG FIX: eval() leaves the model in eval mode, so training mode
            # must be restored every epoch. The original called model.train()
            # once before the loop and effectively trained with dropout
            # disabled from epoch 2 onward.
            self.model.train()
            self.do_epoch(epoch_idx, self.train_data_loader, optimizer)
            errors = self.eval()
            accuracy = 100 - errors[0]
            if accuracy >= best_accuracy:
                print('Best model so far, Accuracy: %0.2f%% -> %0.2f%%' % (best_accuracy, accuracy))
                # BUG FIX: the original never updated best_accuracy, so every
                # epoch was saved as "best" and the final summary printed 0.
                best_accuracy = accuracy
                self.save_model(epoch_idx, best_accuracy, errors, savename)
        print('Finished finetuning...')
        print('Best error/accuracy: %0.2f%%, %0.2f%%' %
              (100 - best_accuracy, best_accuracy))
        print('-' * 16)

    def save_model(self, epoch, best_accuracy, errors, savename):
        """Saves model to file (as CPU tensors), then restores GPU placement."""
        self.model.cpu()
        ckpt = {
            'args': self.args,
            'epoch': epoch,
            'accuracy': best_accuracy,
            'errors': errors,
            'state_dict': self.model.state_dict(),
        }
        # BUG FIX: state_dict() holds live references to the parameters, so
        # the checkpoint must be written BEFORE moving the model back to the
        # GPU; the original called model.cuda() first and saved CUDA tensors
        # despite the cpu() call above.
        torch.save(ckpt, savename + '.pt')
        if self.cuda:
            self.model.cuda()
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()