# 深度学习模型架构-基础版-cpu (Deep-learning model architecture: basic version, CPU)

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import torchvision

from matplotlib import pyplot

batch_size = 512
# 1 - data loading
# ToTensor converts each image to a float tensor in [0, 1]; Normalize then applies
# (x - mean) / std per channel. 0.1307 / 0.3081 are the mean / std values commonly
# used for MNIST.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
train_datasets = torchvision.datasets.MNIST('mnist', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
test_datasets = torchvision.datasets.MNIST('mnist', train=False, download=True, transform=transform)
# The test set is only used to count correct predictions, so batch order is irrelevant:
# skip the per-epoch shuffle and keep evaluation order deterministic.
test_loader = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=False)

# Sanity-check one training batch: tensor shapes and the value range after normalization.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())


# 2 - build the model
class Net(nn.Module):
    """Three-layer fully connected network: in_features -> hidden1 -> hidden2 -> num_classes.

    The default sizes (784 -> 128 -> 64 -> 10) are the original hard-coded
    architecture, so ``Net()`` builds exactly the same model as before. The
    input's last dimension must be ``in_features``; the training loop in this
    file flattens each image batch with ``x.view(...)`` before calling the model.
    Layer attribute names (fc1 / fc2 / fc3) are unchanged, so parameter names
    and saved state_dicts stay compatible.
    """

    def __init__(self, in_features=784, hidden1=128, hidden2=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1, bias=True)
        self.fc2 = nn.Linear(hidden1, hidden2, bias=True)
        self.fc3 = nn.Linear(hidden2, num_classes, bias=True)

    def forward(self, x):
        """Map x of shape (..., in_features) to raw class scores of shape (..., num_classes).

        ReLU follows the two hidden layers; the output layer has no activation.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# 3 - instantiate the model and its optimizer
net = Net()

# SGD with momentum 0.9 and learning rate 0.01, updating every parameter of `net`.
optimizer = optim.SGD(params=net.parameters(), lr=1e-2, momentum=0.9)


# Helper: one-hot encoding
def one_hot(label, depth=10):
    """Return a float one-hot encoding of integer class labels.

    Args:
        label: tensor of integer class indices (any integer dtype). It is
            flattened, so shape (N,) or (N, 1) both give N rows. Every value
            must lie in [0, depth); ``scatter_`` rejects out-of-range indices.
        depth: number of classes, i.e. the width of each one-hot row.

    Returns:
        A float32 tensor of shape (N, depth) on the same device as ``label``,
        holding 1 at each row's label position and 0 elsewhere.
    """
    # .long() accepts int32/uint8/... labels, which the legacy torch.LongTensor(...)
    # constructor rejected; reshape (unlike view) also handles non-contiguous input.
    idx = label.long().reshape(-1, 1)
    # Row count taken from the flattened index so it always matches idx; allocate on the
    # label's device so scatter_ never mixes devices.
    out = torch.zeros(idx.size(0), depth, device=label.device)
    # Row i gets a 1 in column idx[i].
    out.scatter_(dim=1, index=idx, value=1)
    return out


# 4 - train the model, evaluating test-set accuracy after every epoch
for epoch in range(10):
    # Training pass
    net.train()
    for x, y in train_loader:
        # Flatten each image batch to (N, C*H*W) for the fully connected input layer.
        x = x.view(x.size(0), -1)
        out = net(x)
        y_onehot = one_hot(y)
        # Clear gradients from the previous step; backward() accumulates into .grad otherwise.
        optimizer.zero_grad()
        # MSE between the raw scores and one-hot targets (the objective this script uses).
        loss = F.mse_loss(out, y_onehot)
        loss.backward()
        # SGD update: w = w - lr * grad (with the optimizer's momentum)
        optimizer.step()

    # Evaluation pass: no gradients are needed, so skip building the autograd graph.
    corrects = 0
    net.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1)
            pred = net(x).argmax(dim=1)
            corrects += pred.eq(y).sum().item()
    totalnum = len(test_loader.dataset)
    acc = corrects / totalnum
    print(acc)

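# 你可能感兴趣的:(深度学习算法与模型,深度学习,人工智能) - "You may also like" tags from the source page: deep-learning algorithms and models, deep learning, artificial intelligence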