Model:
1. 30 residual blocks (batch normalization in every block)
2. Adam optimizer with an exponentially decaying learning rate
3. 256 channels in every layer
4. No pooling
Results:
Training is fast: on 8 GPUs the model reaches 98.3% test accuracy in about 2 minutes.
Conclusion:
The network's GPU memory usage grows linearly with batch size.
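The linear-memory claim can be checked directly: record peak GPU memory after one forward/backward pass at several batch sizes and watch it grow. A minimal sketch (measure_memory and the batch sizes are illustrative only; net refers to the model built in the script below, and a CUDA device is assumed):

import torch

def measure_memory(model, batch_size):
    # reset the peak-allocation counter, run one forward/backward pass,
    # then report the peak GPU memory in MiB
    torch.cuda.reset_peak_memory_stats()
    x = torch.zeros(batch_size, 1, 28, 28, device = 'cuda')
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 2 ** 20

# for bs in (64, 128, 256, 512):
#     print(bs, measure_memory(net, bs))  # expect roughly proportional growth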
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
parser = argparse.ArgumentParser(description = 'PyTorch MNIST Example')
parser.add_argument('--batch-size', type = int, default = 512, help = 'batch size')
parser.add_argument('--lr', type = float, default = 1e-3, help = 'learning rate')
args = parser.parse_args()
print(args)
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train = True, download = True,
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    ),
    batch_size = args.batch_size,
    shuffle = True,
    **kwargs
)
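# Note: 0.1307 and 0.3081 are the mean and std of the MNIST training pixels,
# so ToTensor (scaling to [0, 1]) followed by Normalize yields inputs with
# roughly zero mean and unit variance.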
test_data = datasets.MNIST(
    './data', train = False, download = True,
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
)
# Build a fixed 2000-image evaluation batch, applying the same mean/std
# normalization as the training data.
test_x = ((test_data.data[:2000].unsqueeze(1).float() / 255 - 0.1307) / 0.3081).cuda()
test_y = test_data.targets[:2000].cuda()
class Residual(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 conv -> 1x1 expand, plus skip."""
    def __init__(self, use_bn = True, input_channels = 256, out_channels = 256):
        super(Residual, self).__init__()
        self.use_bn = use_bn
        self.out_channels = out_channels
        self.input_channels = input_channels
        self.mid_channels = input_channels // 2
        self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1)
        self.AcFunc = nn.ReLU()
        if use_bn:
            self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels)
            self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels)
            self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels)
        self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = 3, padding = 1)
        self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size = 1)
        # 1x1 projection for the skip connection when channel counts differ
        if input_channels != out_channels:
            self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1)

    def forward(self, inputs):
        x = self.down_channel(inputs)
        if self.use_bn:
            x = self.bn_0(x)
        x = self.AcFunc(x)
        x = self.conv(x)
        if self.use_bn:
            x = self.bn_1(x)
        x = self.AcFunc(x)
        x = self.up_channel(x)
        # skip connection (projected if channel counts differ)
        if self.input_channels != self.out_channels:
            x += self.trans(inputs)
        else:
            x += inputs
        if self.use_bn:
            x = self.bn_2(x)
        return self.AcFunc(x)
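# Example (sketch, not executed here): with default arguments a block maps
# (N, 256, 28, 28) -> (N, 256, 28, 28), so blocks can be stacked freely:
#   block = Residual()
#   y = block(torch.zeros(2, 256, 28, 28))  # y.shape == (2, 256, 28, 28)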
class Net(nn.Module):
    def __init__(self, residual_stack = 30, use_bn = True, pre_layer = 5):
        super(Net, self).__init__()
        # Stem: lift the 1-channel input to 256 channels, then a few 3x3 conv layers
        self.preprocess = nn.Sequential(
            nn.Conv2d(1, 256, kernel_size = 1),
            nn.BatchNorm2d(num_features = 256),
            nn.ReLU()
        )
        for i in range(pre_layer):
            self.preprocess.add_module(
                name = 'pre_layer' + str(i), module = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)
            )
            if use_bn:
                self.preprocess.add_module(
                    name = 'bn_pre_layer' + str(i), module = nn.BatchNorm2d(num_features = 256)
                )
            self.preprocess.add_module(
                name = 'relu' + str(i), module = nn.ReLU()
            )
        # Stack of residual blocks; with no pooling the spatial size stays 28x28
        self.residual_blocks = nn.Sequential()
        for i in range(residual_stack):
            self.residual_blocks.add_module(
                name = 'residual' + str(i), module = Residual(use_bn = use_bn)
            )
        self.out_layer = nn.Sequential(
            nn.Linear(in_features = 28 * 28 * 256, out_features = 10)
        )

    def forward(self, inputs):
        x = self.preprocess(inputs)
        x = self.residual_blocks(x)
        x = self.out_layer(x.view(x.size(0), -1))  # flatten to (N, 28 * 28 * 256)
        return x
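# Shape check (sketch, not executed here): with no pooling the feature map
# stays 28x28, so the flattened vector has 28 * 28 * 256 = 200704 features:
#   logits = Net(residual_stack = 30)(torch.zeros(2, 1, 28, 28))
#   assert logits.shape == (2, 10)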
net = Net(residual_stack = 30)
net = torch.nn.DataParallel(net)
net.cuda()
optimizer = optim.Adam(net.parameters(), lr = args.lr)
loss_function = nn.CrossEntropyLoss()
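# DataParallel splits each input batch across all visible GPUs;
# CrossEntropyLoss applies log-softmax internally, so Net returns raw logits.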
def adjust_learning_rate(optimizer, epoch, lr):
    # continuous exponential decay: the lr shrinks by a factor of 0.9 every two epochs
    LR = lr * (0.9 ** (epoch / 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = LR
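# Worked example: with the default lr = 1e-3 the schedule gives
# epoch 0 -> 1.00e-3, epoch 2 -> 9.00e-4, epoch 4 -> 8.10e-4, ...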
def train(epoch):
    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = net(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
        # every 10 batches, report accuracy on the fixed 2000-image test batch
        if batch_idx > 0 and batch_idx % 10 == 0:
            net.eval()
            with torch.no_grad():
                predict_output = net(test_x)
            pred_y = torch.max(predict_output, 1)[1].squeeze()
            accuracy = (pred_y == test_y).float().mean().item()
            print('epoch:{}, batch_idx:{}, percentage:{:.4f}, accuracy:{}'.format(
                epoch, batch_idx, batch_idx / len(train_loader), accuracy))
            net.train()
    adjust_learning_rate(optimizer, epoch, args.lr)
if __name__ == '__main__':
    for epoch in range(100):
        train(epoch)
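Usage (the filename is arbitrary, assuming the script is saved as mnist_resnet.py):

python mnist_resnet.py --batch-size 512 --lr 1e-3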