[chapter 24][PyTorch][多分类问题实战]

前言

     这里面结合手写数字识别的例子,回顾一下多分类问题.

前面有讲过同一个模型,不同的训练方法差异很大主要有以下几个原因:

        epoch: 迭代次数

       learning_rate: 学习率

       网络深度

       权重系数的初始化

    

这里结合一个三层的神经网络实例,讲解一下

在训练过程中,发现通过torch.nn.init.kaiming_normal_ 初始化方法,相对

正太分布的初始化,训练的模型有了质的提升


一 手写数字识别例子

       

# -*- coding: utf-8 -*-
"""
Created on Wed Mar 29 16:46:03 2023

@author: chengxf2
"""

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision



class classify:
    
    '''
    参数初始化
    [200,784],[200]
    '''
    def init_para(self):
        w1,b1= torch.rand(200,784, requires_grad =True),\
               torch.zeros(200, requires_grad= True)
        
        w2,b2= torch.rand(200,200, requires_grad =True),\
               torch.zeros(200, requires_grad= True)
               
        w3,b3= torch.rand(10,200, requires_grad =True),\
               torch.zeros(10, requires_grad= True)
               
        
        return  w1,b1,w2,b2,w3,b3
               
    
    def forward(self,x,w1,b1,w2,b2,w3,b3):
        
        #线性层1
        a1 = torch.matmul(x, w1.T)+b1
        o1 = F.relu(a1)
        
        #线性层2
        a2 = torch.matmul(o1, w2.T)+b2
        o2 = F.relu(a2)
        
        #线性层3
        a3 = torch.matmul(o2, w3.T)+b3
        o3 = F.relu(a3) #最后一层可以不用,后面要SOftmax +llr
        
        return o3
    
    '''
    torch.utils.data.DataLoader 
   
        作用:主要是对数据进行 batch 的划分
               数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
               在训练模型时使用到此函数,用来 把训练数据分成多个小组 ,
               此函数 每次抛出一组数据 。直至把所有的数据都抛出。就是做一个数据的初始化
        优点:
              快速迭代数据。 
        
        参数:
              dataset=torch_dataset,       # torch TensorDataset format
              batch_size=BATCH_SIZE,       # mini batch size
              shuffle=True,                # 要不要打乱数据 (打乱比较好)
              num_workers=2,               # 多线程来读数据
         
           
    
    args:
        
    torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。以下是torchvision的构成:
    1  torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
    2 torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
    3 torchvision.utils: 其他的一些有用的方法。
    '''
    def load_data(self,batch_size_train=10 ,batch_size_test = 10):
        
        # 数据集
        train_loader = torch.utils.data.DataLoader(
                             torchvision.datasets.MNIST('./data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),batch_size=batch_size_train, shuffle=True)
        
        
        test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),batch_size=batch_size_test, shuffle=True)
        
        return train_loader, test_loader
    
    
    '''
    训练过程
    '''
    
    def train(self,w1,b1,w2,b2,w3,b3,train_loader, epoch):
        
         
         optimizer = torch.optim.SGD([w1,b1,w2,b2,w3,b3],lr =self.learning_rate)
         criteon = nn.CrossEntropyLoss()

             
         for batch_idx, (data,target) in enumerate(train_loader):
                 
                 data = data.view(-1,28*28)
                 logits = self.forward(data, w1, b1, w2, b2, w3, b3)
                 
                 loss = criteon(logits, target)
                 
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step() #更新梯度
                 
                 if batch_idx %1000 ==0:
                     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                         epoch, batch_idx * len(data), len(train_loader.dataset),
                         100. * batch_idx / len(train_loader), loss.item()))
                     
     
    '''
    测试一下loss
    args
       test_loader, 测试数据集
    '''
    def test(self, test_loader):
         
            test_loss = 0
            correct = 0
            criteon = nn.CrossEntropyLoss()
            for data, target in test_loader:
                        data = data.view(-1,28*28)
                        logtis = self.forward(data)
                        
                        test_loss += criteon(logtis, target).item()
                        
                        pred = logtis.data.max(1, keepdim=True)[1]
                        correct += pred.eq(target.data).sum()
                        
            test_loss /= len(test_loader.dataset)
          
            print('nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)n'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))
         
         
        
        
    def run(self):
            
            #torch.nn.init.kaiming_normal_(w1)
            #torch.nn.init.kaiming_normal_(w2)
            #torch.nn.init.kaiming_normal_(w3)
            
            print("\n step1: loadData")
            train_loader, test_loader = self.load_data()
            
            print("\n step2: init paramater")
            w1,b1,w2,b2,w3,b3 = self.init_para() 
            #torch.nn.init.kaiming_normal_(w1)
            #torch.nn.init.kaiming_normal_(w2)
            #torch.nn.init.kaiming_normal_(w3)
            
            print("\n step3 train")
            
            
            for epoch in range(self.epochs):
                
                self.train(w1, b1, w2, b2, w3, b3, train_loader, epoch)

           
        
        
        
    
    def __init__(self):

        self.batch_size = 5
        self.learning_rate = 1e-3
        self.epochs = 10
        
        

if __name__ =="__main__":
    
     model = classify()
     model.run()
        

参考:

《课时13 手写数字识别初体验-5_哔哩哔哩_bilibili》

你可能感兴趣的:(pytorch,分类,深度学习)