网络模型(MLP-全连接神经网络)

概念

多层感知机,用于特征融合。
网络模型(MLP-全连接神经网络)_第1张图片
h = wx + b(w:权重,b:偏置)
相邻两层的每一对神经元之间都有各自的权重,因此参数量很大;这些参数在反向传播时更新。

实验(手写数字识别)

数据集:MNIST。
网络结构:全连接 + 标准化(BN) + 激活(ReLU)。
优化器:Adam。
损失函数:交叉熵(CrossEntropyLoss),内部已包含 log-softmax,并直接接受类别索引形式的标签,因此无需手动对标签做 one-hot 编码,也无需在网络末尾再加 softmax。
输出:长度为 10 的得分向量(每个位置对应一个类别,形式上与 one-hot 编码一一对应),预测结果为最大值所在的索引。

网络

import torch
from torch import nn


class MyNet(nn.Module):
    """Fully connected classifier: 784 inputs -> 128 -> 64 -> 10 outputs.

    Each hidden layer is Linear + BatchNorm1d + ReLU. The final Linear layer
    has no activation, so the network returns raw, unnormalized scores.
    """

    def __init__(self):
        super().__init__()
        # Fully connected + batch normalization (BN) + activation (ReLU)
        self.mlp = nn.Sequential(
            nn.Linear(1 * 28 * 28, 128), nn.BatchNorm1d(128), nn.ReLU(inplace=True),
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(inplace=True),
            # Output layer: 10 raw scores, one per class
            nn.Linear(64, 10)
        )

    def forward(self, x):
        """Return the 10 class scores for every sample in ``x``.

        ``x`` may already be flattened, shape (N, 784), or image shaped,
        e.g. (N, 1, 28, 28). Every dimension after the batch dimension is
        flattened before the first Linear layer, so 2-D input passes through
        unchanged.
        """
        return self.mlp(torch.flatten(x, 1))

训练

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from PIL import Image, ImageDraw, ImageFont
from matplotlib import pyplot as plt

from net import MyNet


# Run configuration.
batch_size = 100                 # samples per mini-batch
net_path = r"modules/mynet.pth"  # file the network is loaded from / saved to

# True: work on the training split; False: work on the test split.
train_flag = False

# Dataset
# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# train_flag selects the split and, at the same time, whether downloading
# and shuffling are enabled (all three are on for training, off for testing).
dataset = datasets.MNIST(r"data", train=train_flag, transform=transform, download=train_flag)
dataloader = DataLoader(dataset, batch_size, shuffle=train_flag)


if __name__ == '__main__':
    # Entry point: load (or create) the network, then either train it
    # indefinitely or step through the data showing one prediction per batch,
    # depending on ``train_flag``.

    # Resume from a previously saved network when one exists.
    # NOTE(review): torch.load unpickles a whole module object, which can run
    # arbitrary code -- only load checkpoint files from a trusted source.
    if os.path.isfile(net_path):
        net = torch.load(net_path)
    else:
        net = MyNet()
    opt = torch.optim.Adam(net.parameters())
    loss_fn = nn.CrossEntropyLoss()

    if train_flag:
        # Training
        # Create the checkpoint directory up front so torch.save cannot fail
        # on a missing folder after a full epoch of work.
        save_dir = os.path.dirname(net_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        net.train()
        # Runs until interrupted; the network is saved after every epoch.
        while True:
            for i, (x, y) in enumerate(dataloader):
                # Flatten with the batch's real size: the last batch of an
                # epoch may hold fewer than batch_size samples.
                x = x.reshape(x.size(0), -1)
                out = net(x)
                loss = loss_fn(out, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                # The predicted class is the index of the largest output.
                result = torch.argmax(out, 1)
                acc = torch.mean(torch.eq(result, y).float())
                print("i:{},loss:{:.5},acc:{:.3}".format(i, loss.item(), acc.item()))
            # Save the network
            torch.save(net, net_path)
    else:
        # Testing
        net.eval()
        try:
            font = ImageFont.truetype(r"arial.ttf", size=10)
        except OSError:
            # arial.ttf is not installed on every platform; fall back to the
            # font bundled with PIL instead of aborting.
            font = ImageFont.load_default()
        plt.ion()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for x, y in dataloader:
                # [n,c,h,w] -> [h,w]: first image of the batch, first channel.
                # NOTE(review): x was already normalized by `transform`, so
                # x * 255 is not confined to [0, 255]; confirm this is
                # acceptable for display.
                img_array = x[0][0] * 255
                img = Image.fromarray(img_array.numpy())
                draw = ImageDraw.ImageDraw(img)

                # Flatten with the batch's real size (see training loop).
                x = x.reshape(x.size(0), -1)
                out = net(x)
                result = torch.argmax(out, 1)
                # Stamp the prediction for the displayed image in its corner.
                draw.text((0, 0), str(result[0].item()), 255, font)

                plt.imshow(img)
                plt.pause(0.5)
        plt.ioff()

你可能感兴趣的:(AI)