深度学习_12_softmax_图片识别优化版代码

因为图片识别很多代码都包装在d2l库里了,直接调用就行了

完整代码:

import torch
from torch import nn
from d2l import torch as d2l

"获取训练集&获取检测集"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) # nn.Flatten()将28*28展平成784

"初始化w,b后者不操作默认初始化"
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std = 0.01)

net.apply(init_weights) # 给到所有模型

loss = nn.CrossEntropyLoss()

trainer = torch.optim.SGD(net.parameters(), lr=0.1) # net.parameters()将net中数据整合w,b给SGD

if __name__ == '__main__':
    num_epochs = 10
    cnt = 1
    for i in range(num_epochs):
        X, Y = d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        print("训练次数: " + str(cnt))
        cnt += 1
        print("训练损失: {:.4f}".format(X))
        print("训练精度: {:.4f}".format(Y))
        print(".................................")

画图功能不兼容pycharm,所以还是朴素的用输出函数吧

深度学习_12_softmax_图片识别优化版代码_第1张图片

你可能感兴趣的:(深度学习)