《动手学深度学习(李沐)》笔记3

多层感知机 — 从零开始实现(mxnet)

from mxnet import gluon
from mxnet import ndarray as nd
from mxnet import autograd
def transform(data, label):
    """Prepare one sample for training.

    The data is cast to float32 and divided by 255; the label is cast to
    float32. Returns the pair ``(scaled_data, float_label)``.
    """
    scaled = data.astype('float32') / 255
    float_label = label.astype('float32')
    return scaled, float_label
def SGD(params, lr):
    """Take one gradient-descent step on every parameter, in place.

    Each element of ``params`` must expose a ``grad`` attribute; its
    contents are overwritten with ``param - lr * param.grad``.
    """
    for weight in params:
        step = lr * weight.grad
        # Slice assignment writes into the existing array instead of
        # rebinding the loop variable to a new object.
        weight[:] = weight - step
# Build the FashionMNIST train/test datasets; `transform` is applied to each sample.
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
batch_size = 256

# Mini-batch loaders: the training set is shuffled, the test set is not.
train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
num_inputs = 28*28  # number of inputs (net() flattens each sample to this length)
num_outputs = 10  # number of outputs
num_hidden = 256  # number of hidden units
weight_scale = .01  # passed as `scale` to the random-normal weight initialisation

# Layer 1 maps inputs -> hidden, layer 2 maps hidden -> outputs; biases start at zero.
W1 = nd.random_normal(shape=(num_inputs, num_hidden), scale=weight_scale)
b1 = nd.zeros(num_hidden)
W2 = nd.random_normal(shape=(num_hidden, num_outputs), scale=weight_scale)
b2 = nd.zeros(num_outputs)

params = [W1, b1, W2, b2]  # collect all parameters in one list

for param in params:  # allocate gradient storage for every parameter
    param.attach_grad()
def relu(X):
    """Activation function: element-wise maximum of ``X`` and 0."""
    rectified = nd.maximum(X, 0)
    return rectified
def net(X):
    """Forward pass of the two-layer perceptron.

    Uses the module-level parameters ``W1``, ``b1``, ``W2`` and ``b2``.
    Returns the output of the second affine layer; no softmax is applied
    here.
    """
    # -1 lets reshape infer the leading dimension; every row then holds
    # num_inputs values.
    flat = X.reshape((-1, num_inputs))
    # Hidden layer: affine map followed by the relu activation.
    hidden = relu(nd.dot(flat, W1) + b1)
    # Output layer: affine map only.
    return nd.dot(hidden, W2) + b2
from mxnet import gluon
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()  # cross-entropy loss used for training

from mxnet import autograd as autograd

# Base step size; the training loop passes learning_rate/batch_size to SGD.
learning_rate = .5
def accuracy(output, label):
    """Return, as a scalar, the mean of ``argmax(output, axis=1) == label``."""
    predicted = output.argmax(axis=1)
    hits = predicted == label
    return nd.mean(hits).asscalar()
def evaluate_accuracy(data_iterator, net):
    """Average the per-batch accuracy of ``net`` over ``data_iterator``.

    Each ``(data, label)`` batch is scored with ``accuracy`` and the scores
    are summed, then divided by ``len(data_iterator)``.
    """
    total = sum(accuracy(net(data), label) for data, label in data_iterator)
    return total / len(data_iterator)

# Train for 5 epochs and report loss / accuracy after each one.
for epoch in range(5):
    train_loss, train_acc = 0., 0.
    for data, label in train_data:
        # Record the forward pass so autograd can differentiate the loss.
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # NOTE(review): the step size is divided by batch_size, presumably
        # because the loss is not averaged over the batch - confirm.
        SGD(params, learning_rate / batch_size)

        # Accumulate per-batch statistics for the end-of-epoch report.
        train_loss = train_loss + nd.mean(loss).asscalar()
        train_acc = train_acc + accuracy(output, label)

    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch,
        train_loss / len(train_data),
        train_acc / len(train_data),
        test_acc))

《动手学深度学习(李沐)》笔记3_第1张图片


多层感知机 — 使用Gluon

from mxnet import ndarray as nd
from mxnet import gluon
from mxnet import autograd
def transform(data, label):
    """Cast a sample to float32, dividing the data (not the label) by 255."""
    as_float = 'float32'
    return data.astype(as_float) / 255, label.astype(as_float)


# Data loading: FashionMNIST datasets with `transform` applied to each sample,
# wrapped in mini-batch loaders (training set shuffled, test set not).
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
batch_size = 256
train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)

# Network definition: Flatten -> Dense(256, relu) -> Dense(10).

net = gluon.nn.Sequential()
with net.name_scope():
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(256, activation="relu"))
    net.add(gluon.nn.Dense(10))
net.initialize()

# Loss function
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
# Optimiser: an 'sgd' trainer over all network parameters
trainer = gluon.Trainer(net.collect_params(), 'sgd', {
     'learning_rate': 0.5})
def accuracy(output, label):
    """Mean of the element-wise match between ``output``'s axis-1 argmax and ``label``.

    The result is converted to a scalar with ``asscalar()``.
    """
    matches = output.argmax(axis=1) == label
    mean_match = nd.mean(matches)
    return mean_match.asscalar()


def evaluate_accuracy(test_data, net):
    """Return the mean per-batch accuracy of ``net`` over ``test_data``.

    Batch scores come from ``accuracy`` and are summed starting from 0.0,
    then divided by ``len(test_data)``.
    """
    batch_scores = (accuracy(net(data), label) for data, label in test_data)
    return sum(batch_scores, .0) / len(test_data)
# Train for 5 epochs and report loss / accuracy after each one.
for epoch in range(5):
    train_loss, train_acc = 0., 0.
    for data, label in train_data:
        # Record the forward pass so autograd can differentiate the loss.
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # Update the parameters.
        # NOTE(review): batch_size is presumably used by the trainer to
        # normalise the gradient - confirm.
        trainer.step(batch_size)

        # Accumulate per-batch statistics for the end-of-epoch report.
        train_loss = train_loss + nd.mean(loss).asscalar()
        train_acc = train_acc + accuracy(output, label)

    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch,
        train_loss / len(train_data),
        train_acc / len(train_data),
        test_acc))


《动手学深度学习(李沐)》笔记3_第2张图片

转载于:https://www.cnblogs.com/yifdu25/p/8360142.html

你可能感兴趣的:(人工智能)