【MXNet】(二十五):实现ResNet

论文原文:He 等人的《Deep Residual Learning for Image Recognition》(CVPR 2016)。

from mxnet import gluon, init, nd
from mxnet.gluon import nn

class Residual(nn.Block):
    """Residual block: two 3x3 conv + batch-norm layers with a shortcut.

    The main path is conv1 -> bn1 -> ReLU -> conv2 -> bn2. The shortcut is
    either the unmodified input or, when ``use_1x1conv`` is true, the input
    passed through a 1x1 convolution. The two paths are summed and a final
    ReLU is applied.

    Args:
        num_channels: Number of output channels of every convolution in the
            block.
        use_1x1conv: If true, apply a 1x1 convolution (with stride
            ``strides``) to the shortcut before the addition.
        strides: Stride of the first 3x3 convolution and of the optional
            1x1 shortcut convolution.
        **kwargs: Forwarded unchanged to ``nn.Block.__init__``.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        # Only the first convolution uses `strides`; the second keeps the
        # default stride.
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1, strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # Shortcut projection using the same stride as conv1.
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def forward(self, X):
        """Return ``relu(main_path(X) + shortcut(X))``.

        NOTE(review): without the 1x1 shortcut, ``X`` is added to the main
        path output as-is, so the caller is presumably expected to pick
        ``num_channels``/``strides`` such that both shapes agree -- confirm.
        """
        Y = nd.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: the optional layer is detected by identity,
        # not by whatever truth value a Block object happens to have.
        if self.conv3 is not None:
            X = self.conv3(X)
        return nd.relu(Y + X)
# Network stem: 7x7 stride-2 convolution with 64 channels, batch norm, ReLU,
# then a 3x3 stride-2 max pooling layer. Layers are appended one by one in
# execution order.
net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3))
net.add(nn.BatchNorm())
net.add(nn.Activation('relu'))
net.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))
def resnet_block(num_channels, num_residuals, first_block=False):
    """Build one stage made of ``num_residuals`` stacked ``Residual`` units.

    Args:
        num_channels: Channel count passed to every ``Residual`` unit.
        num_residuals: How many ``Residual`` units the stage contains.
        first_block: When true, no unit in the stage uses the stride-2
            1x1-shortcut variant; otherwise the first unit does.

    Returns:
        An ``nn.Sequential`` holding the units in order.
    """
    stage = nn.Sequential()
    for idx in range(num_residuals):
        # Only the leading unit of a non-first stage uses stride 2 together
        # with a 1x1 convolution on the shortcut.
        downsample = idx == 0 and not first_block
        if downsample:
            unit = Residual(num_channels, use_1x1conv=True, strides=2)
        else:
            unit = Residual(num_channels)
        stage.add(unit)
    return stage
# Four residual stages of two units each. The first stage is built with
# first_block=True, so it contains only plain Residual units; each later
# stage starts with a stride-2 unit.
net.add(resnet_block(64, 2, first_block=True))
net.add(resnet_block(128, 2))
net.add(resnet_block(256, 2))
net.add(resnet_block(512, 2))
# Head: global average pooling followed by a dense layer with 10 outputs.
net.add(nn.GlobalAvgPool2D())
net.add(nn.Dense(10))

构造一个形状为 (1, 1, 224, 224) 的随机输入,逐层做前向计算,查看每一层输出的形状。

# Sanity check: push one random input of shape (1, 1, 224, 224) through the
# network and print the output shape after every top-level layer.
X = nd.random.uniform(shape=(1, 1, 224, 224))
net.initialize()
for layer in net:
    # Feed the previous layer's output into the next layer.
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)

【MXNet】(二十五):实现ResNet_第1张图片

仍然使用 Fashion-MNIST 数据集。try_gpu()、load_data_fashion_mnist()、train() 这几个函数的实现请参考:【MXNet】(二十):实现AlexNet。

显存有限,因此把 batch_size 设置得比较小。

# Training configuration. try_gpu(), load_data_fashion_mnist() and train()
# are helpers defined outside this snippet.
lr = 0.05
num_epochs = 5
batch_size = 64
ctx = try_gpu()  # presumably picks a GPU context when available -- confirm

# Re-initialize the parameters on the chosen context with Xavier
# initialization; force_reinit=True is passed because net.initialize() was
# already called above.
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())

# Plain SGD over all network parameters.
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})

# NOTE(review): resize=224 presumably rescales the images to 224x224 to match
# the input size used in the shape check -- confirm against the helper.
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

【MXNet】(二十五):实现ResNet_第2张图片

你可能感兴趣的:(深度学习,MXNet)