paddle模型训练

1 使用高层API进行训练

import paddle

# 指定在 CPU 上训练
paddle.device.set_device('cpu')

# 指定在 GPU 第 0 号卡上训练
# paddle.device.set_device('gpu:0')

##模型的训练和推理

###方法1:使用高层API进行训练评估和推理

from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 加载 MNIST 训练集和测试集
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# 模型组网,构建并初始化一个模型 mnist
mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(1, -1),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)
print(type(mnist))
# 使用paddle.Model封装模型
model=paddle.Model(mnist)
# print(type(model))
# print(dir(model))

# 使用 Model.prepare 配置训练准备参数
# 为模型训练做准备,设置优化器及其学习率,并将网络的参数传入优化器,设置损失函数和精度计算方式
model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())
##使用 Model.fit 训练模型
model.fit(train_dataset,
          epochs=5,
          batch_size=64,
          verbose=1)

# 用 evaluate 在测试集上对模型进行验证
eval_result = model.evaluate(test_dataset, verbose=1)
print(eval_result)

2、基础API进行训练

import paddle

# 指定在 CPU 上训练
paddle.device.set_device('cpu')

from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 加载 MNIST 训练集和测试集
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(1, -1),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)

# dataset与mnist的定义与使用高层API的内容一致
# 用 DataLoader 实现数据加载
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 将mnist模型及其所有子层设置为训练模式。这只会影响某些模块,如Dropout和BatchNorm。
mnist.train()

# 设置迭代次数
epochs = 5

# 设置优化器
optim = paddle.optimizer.Adam(parameters=mnist.parameters())
# 设置损失函数
loss_fn = paddle.nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader()):

        x_data = data[0]  # 训练数据
        y_data = data[1]  # 训练数据标签
        predicts = mnist(x_data)  # 预测结果

        # 计算损失 等价于 prepare 中loss的设置
        loss = loss_fn(predicts, y_data)

        # 计算准确率 等价于 prepare 中metrics的设置
        acc = paddle.metric.accuracy(predicts, y_data)

        # 下面的反向传播、打印训练信息、更新参数、梯度清零都被封装到 Model.fit() 中
        # 反向传播
        loss.backward()

        if (batch_id + 1) % 900 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id + 1, loss.numpy(),
                                                                            acc.numpy()))
        # 更新参数
        optim.step()
        # 梯度清零
        optim.clear_grad()

# 加载测试数据集
test_loader = paddle.io.DataLoader(test_dataset, batch_size=64, drop_last=True)
# 设置损失函数
loss_fn = paddle.nn.CrossEntropyLoss()
# 将该模型及其所有子层设置为预测模式。这只会影响某些模块,如Dropout和BatchNorm
mnist.eval()
# 禁用动态图梯度计算
for batch_id, data in enumerate(test_loader()):

    x_data = data[0]  # 测试数据
    y_data = data[1]  # 测试数据标签
    predicts = mnist(x_data)  # 预测结果

    # 计算损失与精度
    loss = loss_fn(predicts, y_data)
    acc = paddle.metric.accuracy(predicts, y_data)

    # 打印信息
    if (batch_id + 1) % 30 == 0:
        print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id + 1, loss.numpy(), acc.numpy()))


# 加载测试数据集
test_loader = paddle.io.DataLoader(test_dataset, batch_size=64, drop_last=True)
# 将该模型及其所有子层设置为预测模式
mnist.eval()
for batch_id, data in enumerate(test_loader()):
    # 取出测试数据
    x_data = data[0]
    # 获取预测结果
    predicts = mnist(x_data)
print("predict finished")

# 从测试集中取出一组数据
img, label = test_loader().next()

# 执行推理并打印结果
pred_label = mnist(img)[0].argmax()
print('true label: {}, pred label: {}'.format(label[0].item(), pred_label[0].item()))
# 可视化图片
from matplotlib import pyplot as plt
plt.imshow(img[0][0])




你可能感兴趣的:(paddle,python,深度学习)