ResNet experiments on the MNIST dataset

Model:
    1. 30 Residual Blocks (batch normalization is used inside every Residual Block)
    2. Adam optimizer, with the learning rate decayed exponentially over epochs
    3. Every layer uses 256 channels
    4. No pooling is used

Results:
    Training is fast: with 8 GPUs, about 2 minutes of training reaches 98.3% test accuracy

Conclusion:
    GPU memory usage grows linearly with the batch size
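This linear scaling can be sanity-checked with a small probe like the sketch below. It is not part of the training script: peak_memory_mb is a hypothetical helper, it assumes a CUDA device and PyTorch >= 1.4 (for torch.cuda.reset_peak_memory_stats), and model is expected to be an instance of the Net class from the script that follows, already moved to the GPU.

import torch
import torch.nn.functional as F

def peak_memory_mb(model, batch_size):
    # One forward/backward pass on dummy MNIST-shaped data, then report the
    # peak GPU memory allocated so far, in MiB.
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(batch_size, 1, 28, 28, device = 'cuda')
    y = torch.randint(0, 10, (batch_size,), device = 'cuda')
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    model.zero_grad()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# Peak memory should grow roughly linearly with batch size (plus a constant
# offset for the parameters and their gradients):
# for bs in (64, 128, 256, 512):
#     print(bs, peak_memory_mb(net, bs))

The full training script follows.
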
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms


parser = argparse.ArgumentParser(description = 'PyTorch MNIST Example')
parser.add_argument('--batch-size', type = int, default = 512, help = 'batch size')
parser.add_argument('--lr', type = float, default = 1e-3, help = 'learning rate')

args = parser.parse_args()

print(args)

# Extra DataLoader options (passed to the train loader below).
kwargs = {'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train = True, download = True,
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        ),
    batch_size = args.batch_size,
    shuffle = True,
    **kwargs
)


test_data = datasets.MNIST(
        './data', train = False, download = True,
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        )

# Pre-load the first 2000 test images onto the GPU for quick checks during training,
# applying the same normalization as the training transform.
test_x = ((test_data.data[:2000].unsqueeze(1).float() / 255 - 0.1307) / 0.3081).cuda()
test_y = test_data.targets[:2000].cuda()
   
class Residual(nn.Module):
    # Bottleneck residual block: a 1x1 conv halves the channels, a 3x3 conv processes
    # them, and a 1x1 conv restores the output channel count; BatchNorm is optional.
    def __init__(self, use_bn = True, input_channels = 256, out_channels = 256):
        super(Residual, self).__init__()
        self.use_bn = use_bn
        self.out_channels   = out_channels
        self.input_channels = input_channels
        self.mid_channels   = input_channels // 2

        self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1)
        self.AcFunc       = nn.ReLU()
        if use_bn:
            self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels)
            self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels)
            self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels)

        self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = 3, padding = 1)

        self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size = 1)

        if input_channels != out_channels:
            self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1)
    
    def forward(self, inputs):
        x = self.down_channel(inputs)
        if self.use_bn:
            x = self.bn_0(x)
        x = self.AcFunc(x)

        x = self.conv(x)
        if self.use_bn:
            x = self.bn_1(x)
        x = self.AcFunc(x)

        x = self.up_channel(x)

        if self.input_channels != self.out_channels:
            x += self.trans(inputs)
        else:
            x += inputs

        if self.use_bn:
            x = self.bn_2(x)
        
        return self.AcFunc(x)

class Net(nn.Module):
    def __init__(self, residual_stack = 30, use_bn = True, pre_layer = 5):
        super(Net, self).__init__()
        self.preprecess = nn.Sequential(
            nn.Conv2d(1, 256, kernel_size = 1),
            nn.BatchNorm2d(num_features = 256),
            nn.ReLU()
        )

        for _ in range(pre_layer):
            self.preprecess.add_module(
                name = 'pre_layer' + str(_), module = nn.Conv2d(256, 256, kernel_size = 3, padding = 1)
            )

            if use_bn:
                self.preprecess.add_module(
                    name = 'bn_pre_layer' + str(_), module = nn.BatchNorm2d(num_features = 256)
                )

            self.preprecess.add_module(
                name = 'relu' + str(_), module = nn.ReLU()
            )
        
        self.residual_blocks = nn.Sequential()

        for _ in range(residual_stack):
            self.residual_blocks.add_module(
                name = 'residual' + str(_), module = Residual(use_bn = use_bn)
            )
        
        # No pooling is used, so the feature map stays at 28 x 28 with 256 channels.
        self.out_layer = nn.Sequential(
            nn.Linear(in_features = 28 * 28 * 256, out_features = 10)
        )
    
    def forward(self, inputs):
        x = self.preprecess(inputs)
        x = self.residual_blocks(x)
        # Flatten the 256 x 28 x 28 feature map before the final linear layer.
        x = self.out_layer(x.view(x.size(0), -1))
        return x

net = Net(residual_stack = 30)
# DataParallel splits each batch across all visible GPUs (the 8-GPU setup mentioned above).
net = torch.nn.DataParallel(net)
net.cuda()

optimizer = optim.Adam(net.parameters(), lr = args.lr)
loss_function = nn.CrossEntropyLoss()

def adjust_learning_rate(optimizer, epoch, lr):
    # Exponential decay: scale the base learning rate by 0.9 every two epochs.
    LR = lr * (0.9 ** (epoch / 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = LR
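# Note: an equivalent, more idiomatic alternative (just a sketch, not used in this
# script) would be torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.9 ** 0.5),
# stepped once per epoch, which reproduces lr * 0.9 ** (epoch / 2) without editing
# param_groups by hand.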

def train(epoch):
    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = net(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx > 0 and batch_idx % 10 == 0:
            # Quick check on the 2000 pre-loaded test images.
            net.eval()
            with torch.no_grad():
                pred_y = torch.max(net(test_x), 1)[1]
            accuracy = (pred_y == test_y).float().mean().item()
            net.train()
            tips = 'epoch:{}, batch_idx:{}, percentage:{:.4f}, accuracy:{:.4f}'.format(epoch, batch_idx, batch_idx / len(train_loader), accuracy)
            print(tips)

    adjust_learning_rate(optimizer, epoch, args.lr)

if __name__ == '__main__':
    for epoch in range(100):
        train(epoch)
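A side note on evaluation: the accuracy printed during training is measured only on the first 2000 test images that were pre-loaded above. A full test-set evaluation could look roughly like the sketch below; evaluate is a hypothetical helper that reuses the test_data dataset (and its transform) defined in the script, and it is not part of the original code.

def evaluate(model):
    # Run the model over all 10000 MNIST test images with the same transform as training.
    loader = torch.utils.data.DataLoader(test_data, batch_size = 1000, shuffle = False)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.cuda(), target.cuda()
            pred = model(data).argmax(dim = 1)
            correct += (pred == target).sum().item()
    model.train()
    return correct / len(test_data)

# Example: print('test accuracy:', evaluate(net)) after training finishes.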
