前言:
本文结合手写数字识别的训练与验证过程,简单介绍 PyTorch 中常用的 API 函数。
与 NumPy 不同,PyTorch 的许多 API 支持自动求导(自动计算微分),
并可配合优化器完成梯度更新等操作。
参考:
CSDN
一 实现效果
1.1 梯度更新过程
1.2 测试集验证结果
test acc: 0.912 total_num: 10000
1.3 预测图像显示
图中下方为实际的手写数字图像,上方为模型给出的预测值。
二 训练验证过程
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 28 16:09:06 2022
@author: chengxf2
"""
import torch
import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from down_data import load_pic
from util import one_hot
from util import plot_curve
from util import plot_image
class Net(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 images.

    Layer sizes are 784 -> 256 -> 64 -> 10, with a ReLU after each of the
    first two linear layers.
    """

    def __init__(self):
        super().__init__()
        # Each nn.Linear computes x @ W^T + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        # Not read by any code visible here; kept so that existing
        # attribute access on the model keeps working.
        self.maxIter = 2

    def forward(self, x):
        """Run the forward pass and return raw class scores (logits).

        Args:
            x: Tensor of shape [m, 28*28], where m is the number of images.

        Returns:
            Tensor of shape [m, 10] holding unnormalised scores per class.

        The scores are deliberately NOT passed through softmax:
        F.cross_entropy applies log-softmax internally, so normalising here
        as well would apply softmax twice and flatten the gradients.
        argmax over the logits selects the same class as argmax over the
        softmax output, so prediction code is unaffected. Use
        F.softmax(logits, dim=1) when actual probabilities are required.
        """
        # Hidden layer 1: affine transform followed by ReLU.
        a1 = F.relu(self.fc1(x))
        # Hidden layer 2: affine transform followed by ReLU.
        a2 = F.relu(self.fc2(a1))
        # Output layer: affine transform only (logits).
        return self.fc3(a2)
def train(train_loader, max_iter=20, lr=0.01):
    """Train a freshly constructed Net with plain SGD and return it.

    Args:
        train_loader: Iterable yielding (x, y) batches. Each x is reshaped
            to [m, 28*28] before being fed to the network; y is passed to
            one_hot() to build the loss target.
        max_iter: Number of full passes over train_loader. Defaults to 20,
            the value that used to be hard-coded.
        lr: SGD learning rate. Defaults to 0.01, the value that used to be
            hard-coded.

    Returns:
        The trained Net instance.

    Side effects:
        Prints the loss every 1000 batches and calls plot_curve() with the
        per-batch loss history once training has finished.
    """
    # Three linear layers: parameters are w1, b1, w2, b2, w3, b3.
    net = Net()
    optimizer = optim.SGD(net.parameters(), lr=lr)
    train_loss = []

    for epoch in range(max_iter):  # iterate over the whole data set
        for batch_idx, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()

            # The loader yields 4-D image batches (e.g. [m, 1, 28, 28]);
            # flatten every image into a 784-element row.
            x = x.view(x.size(0), 28 * 28)
            out = net(x)

            # NOTE(review): F.cross_entropy applies log-softmax to `out`
            # itself, so it expects unnormalised scores from the network.
            # The target is whatever one_hot() returns -- presumably a
            # float tensor shaped like `out`; confirm against util.one_hot.
            y_onehot = one_hot(y)
            loss = F.cross_entropy(out, y_onehot)

            loss.backward()   # compute gradients
            optimizer.step()  # parameter update: w = w - lr * grad

            # Read the scalar once; it is used for both the history and
            # the progress print below.
            loss_value = loss.item()
            train_loss.append(loss_value)  # record the loss (not the gradient)

            if batch_idx % 1000 == 0:
                print("epoch:%d batch_idx: %d loss: %7.4f" % (epoch, batch_idx, loss_value))

    plot_curve(train_loss)
    return net
# Evaluate the network on the test set.
def verify(test_loader, net):
    """Compute and print the classification accuracy of `net`.

    Args:
        test_loader: Iterable yielding (x, y) batches that also exposes a
            `dataset` attribute supporting len(). Each x is reshaped to
            [N, 28*28]; y is compared element-wise with the predictions.
        net: Callable mapping a [N, 28*28] tensor to per-class scores of
            shape [N, num_classes].

    Returns:
        The accuracy as a float: correct predictions divided by
        len(test_loader.dataset), or 0.0 when the data set is empty.
    """
    total_correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        for x, y in test_loader:
            batch_size = x.size(0)            # number of samples in the batch
            flat = x.view(batch_size, 28 * 28)
            out = net(flat)                   # [batch_size, num_classes]
            pred = out.argmax(dim=1)          # [batch_size]
            total_correct += pred.eq(y).sum().item()

    total_num = len(test_loader.dataset)
    # Guard against an empty data set instead of raising ZeroDivisionError.
    acc = total_correct / total_num if total_num else 0.0
    print("test acc: %7.3f"%acc,"\t total_num: ",total_num)
    return acc
# Visual check of the predictions for one batch.
def verify_show(test_loader, net):
    """Predict a single batch from `test_loader` and plot it.

    Args:
        test_loader: Iterable yielding (x, y) batches; only the first
            batch is consumed and its labels are ignored.
        net: Callable mapping a [N, 28*28] tensor to per-class scores of
            shape [N, num_classes].

    Raises:
        StopIteration: If `test_loader` yields no batches.
    """
    x, _ = next(iter(test_loader))  # take just one batch
    N = x.size(0)                   # N depends on the loader's batch size
    # x.type() must be *called* to obtain the tensor type string; printing
    # the bare attribute only shows a bound-method repr.
    print("\n N: %d" % N, x.shape, x.type())

    flat = x.view(N, 28 * 28)
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        out = net(flat)             # [N, num_classes]
    pred = out.argmax(dim=1)        # [N]

    # plot_image receives the original (unflattened) batch and the
    # predicted labels.
    plot_image(x, pred, 'predict: ')
if __name__ == "__main__":
    # Load the data loaders, fit the network, then report test accuracy.
    train_data, test_data = load_pic()
    model = train(train_data)
    verify(test_data, model)
    # To plot one batch of predictions as well, call
    # verify_show(test_data, model) here.