MXNet (Gluon) 学习之路:MNIST 训练

训练流程

import mxnet as mx
from mxnet.gluon import loss as gloss, nn
import mxnet.gluon as gluon
from mxnet import autograd
import mxnet.ndarray as nd
import numpy as np
import mxnet.metric

class LeNet(gluon.nn.HybridBlock):
    """LeNet-style convolutional classifier assembled from Gluon layers.

    The forward pass runs two convolution + max-pooling stages, flattens the
    result and passes it through two dense layers. No activation is applied
    between ``dense1`` and ``dense2``, so ``dense1`` yields a linear feature
    vector with ``feature_size`` units.

    Parameters
    ----------
    classes : int
        Number of units of the final dense layer (default 10).
    feature_size : int
        Number of units of the first dense layer (default 120).
    **kwargs
        Forwarded unchanged to the ``HybridBlock`` constructor.
    """

    def __init__(self, classes=10,feature_size=120, **kwargs):
        super(LeNet,self).__init__(**kwargs)

        with self.name_scope():
            self.conv1 = nn.Conv2D(channels=20, kernel_size=5, activation='relu')
            self.conv2 = nn.Conv2D(channels=50, kernel_size=5, activation='relu')
            # A single pooling block is shared by both convolution stages.
            self.maxpool = nn.MaxPool2D(pool_size=2, strides=2)
            self.flat = nn.Flatten()
            self.dense1 = nn.Dense(feature_size)
            self.dense2 = nn.Dense(classes)

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Run the network on ``x`` and return the output of ``dense2``.

        ``F`` is supplied by Gluon; it is not used directly because every
        step is delegated to a child block. Extra positional and keyword
        arguments are accepted and ignored.
        """
        x = self.maxpool(self.conv1(x))
        x = self.maxpool(self.conv2(x))
        # Flatten explicitly with the layer declared in __init__ rather than
        # relying on Dense's default input flattening; the result is the same
        # and the declared layer is no longer dead.
        x = self.flat(x)
        ft = self.dense1(x)
        output = self.dense2(ft)

        return output


def transformer(data, label):
    """Convert one dataset sample into the form fed to the network.

    ``data`` is cast to float32, its axes are permuted with ``(2, 0, 1)``
    (the last axis moves to the front -- presumably HWC -> CHW, confirm
    against the dataset), converted to a NumPy array and divided by 255
    (presumably scaling 8-bit pixel values into [0, 1] -- TODO confirm).
    ``label`` is cast to int32.

    Returns a ``(data, label)`` tuple.
    """
    as_float = data.astype(np.float32)
    reordered = nd.transpose(as_float, (2, 0, 1))
    scaled = reordered.asnumpy() / 255
    int_label = label.astype(np.int32)
    return scaled, int_label


def try_gpu():
    """Return ``mx.gpu()`` if an array can be allocated on it, else ``mx.cpu()``.

    A one-element array is allocated on the GPU context as a probe. If MXNet
    reports an error for that allocation, the CPU context is returned
    instead.
    """
    try:
        ctx = mx.gpu()
        _ = nd.zeros((1,), ctx=ctx)
    except mx.base.MXNetError:
        # Catch only MXNet's own error so that KeyboardInterrupt/SystemExit
        # and unrelated bugs are not silently turned into a CPU fallback.
        ctx = mx.cpu()
    return ctx


if __name__ == '__main__':
    # Train LeNet on MNIST with SGD and report training accuracy during each
    # epoch and validation accuracy at the end of each epoch.
    net = LeNet()
    net.hybridize()

    # `transformer` is applied to every (image, label) sample. The training
    # loader shuffles and drops an incomplete final batch; the validation
    # loader keeps dataset order.
    train_data = gluon.data.DataLoader(
        gluon.data.vision.MNIST('./data', train=True, transform=transformer),
        batch_size=64, shuffle=True, last_batch='discard')
    val_data = gluon.data.DataLoader(
        gluon.data.vision.MNIST('./data', train=False, transform=transformer),
        batch_size=100, shuffle=False)

    ctx = try_gpu()
    print(ctx)
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(),
                            optimizer='sgd',
                            optimizer_params={'learning_rate': 0.01, 'wd': 5e-4})
    metric = mx.metric.Accuracy()
    val_metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    epochs = 10
    for epoch in range(epochs):
        # Training accuracy is accumulated per epoch.
        metric.reset()
        for i, (data, label) in enumerate(train_data):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data)
                L = loss(output, label)
            # Only the forward computation needs to be recorded; the backward
            # pass is launched after leaving the recording scope.
            L.backward()

            # Normalise the gradient by the number of samples in the batch.
            trainer.step(data.shape[0])
            metric.update([label], [output])

            if i % 100 == 0 and i > 0:
                name, acc = metric.get()
                print('[Epoch %d Batch %d] Training: %s=%f'%(epoch, i, name, acc))

        name, acc = metric.get()
        print('[Epoch %d] Training: %s=%f'%(epoch, name, acc))

        # Evaluate on the held-out set, which was previously loaded but never
        # used. No autograd recording is active here, so this is a plain
        # forward pass.
        val_metric.reset()
        for data, label in val_data:
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            val_metric.update([label], [net(data)])
        name, acc = val_metric.get()
        print('[Epoch %d] Validation: %s=%f'%(epoch, name, acc))












你可能感兴趣的:(mxnet-gluon之路)