Building and training a network in PyTorch and saving the model

I have been learning PyTorch recently. Using the MNIST dataset, I built AlexNet, trained it, and saved the model; this post records the code.

For how to build the dataset, see the post "Building your own dataset in PyTorch (using MNIST as an example)".
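
The Dataset class used in the training script is DP.DataProcessingMnist(root_path, imgfile, labelfile, imgdir, transform). The actual implementation lives in that post; purely as a reminder of the interface, a minimal sketch might look like the following (the list-file format, one image name or one integer label per line, is an assumption here):

import os
import torch
from PIL import Image
from torch.utils.data import Dataset

class DataProcessingMnist(Dataset):
    def __init__(self, root_path, imgfile, labelfile, imgdir, transform=None):
        self.root_path = root_path
        self.imgdir = imgdir
        self.transform = transform
        # assumed format: one image file name per line / one integer label per line
        with open(os.path.join(root_path, imgfile)) as f:
            self.img_names = [line.strip() for line in f if line.strip()]
        with open(os.path.join(root_path, labelfile)) as f:
            self.labels = [int(line.strip()) for line in f if line.strip()]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        path = os.path.join(self.root_path, self.imgdir, self.img_names[idx])
        img = Image.open(path).convert('RGB')   # AlexNet expects 3-channel input
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.long)   # labels must be LongTensor
        return img, label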

For how to build the network, see the post "Building AlexNet in PyTorch (fine-tuning a pretrained model and building it by hand)".
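
The model builder used below is BM.BuildAlexNet(model_type, nclasses); with model_type = 'pre' it fine-tunes the torchvision pretrained AlexNet. The real code is in that post; just to illustrate the idea, a rough sketch of such a builder could be:

import torch.nn as nn
from torchvision import models

def BuildAlexNet(model_type, nclasses):
    if model_type == 'pre':
        # start from the ImageNet-pretrained AlexNet and replace the last fully connected layer
        model = models.alexnet(pretrained=True)
        model.classifier[6] = nn.Linear(4096, nclasses)
    else:
        # build AlexNet from scratch with the desired number of output classes
        model = models.alexnet(num_classes=nclasses)
    return model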

The training code is as follows:

import torch
import os
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import DataLoader
import DataProcessing as DP   # custom Dataset class from the dataset post
import BuildModel as BM       # AlexNet builder from the model post
import torch.nn as nn

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'   # make GPUs 0 and 1 visible to this process
    
    root_path = '/opt/Data/lixiang/ex./pytorch/Alexnet/data/'
    training_path = 'trainingset/'
    test_path = 'testset/'
    model_path = '/opt/Data/lixiang/ex./pytorch/Alexnet/model/'
    
    training_imgfile = training_path + 'trainingset_img.txt'
    training_labelfile = training_path + 'trainingset_label.txt'
    training_imgdata = training_path + 'img/'
    
    test_imgfile = test_path + 'testset_img.txt'
    test_labelfile = test_path + 'testset_label.txt'
    test_imgdata = test_path + 'img/'
    
    # hyperparameters
    batch_size = 128
    epochs = 20
    model_type = 'pre'   # 'pre': fine-tune the pretrained AlexNet (see the model-building post)
    nclasses = 10        # MNIST has 10 digit classes
    lr = 0.01
    use_gpu = torch.cuda.is_available()
    
    # preprocessing: resize/crop to the 224x224 input AlexNet expects, normalize with ImageNet statistics
    transformations = transforms.Compose(
            [transforms.Resize(256),   # transforms.Scale is deprecated; Resize is the current name
             transforms.CenterCrop(224),
             transforms.ToTensor(),
             transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
                    ])
    
    dataset_train = DP.DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata, transformations)
    dataset_test = DP.DataProcessingMnist(root_path, test_imgfile, test_labelfile, test_imgdata, transformations)
    
    num_train, num_test = len(dataset_train), len(dataset_test)
    
    # shuffle only the training set; increase num_workers to load batches in parallel if needed
    train_loader = DataLoader(dataset_train, batch_size = batch_size, shuffle = True, num_workers = 0)
    test_loader = DataLoader(dataset_test, batch_size = batch_size, shuffle = False, num_workers = 0)
    
    # build model and move it to the GPU before constructing the optimizer,
    # so the optimizer holds references to the GPU parameters
    model = BM.BuildAlexNet(model_type, nclasses)
    if use_gpu:
        model = model.cuda()
    optimizer = optim.SGD(model.parameters(), lr = lr)
    criterion = nn.CrossEntropyLoss()   # expects raw logits and LongTensor class indices as target
    for epoch in range(epochs):
        model.train()              # make sure dropout is active during training
        epoch_loss = 0.0
        correct_num = 0
        for i, traindata in enumerate(train_loader):
            x_train, y_train = traindata
            if use_gpu:
                x_train, y_train = x_train.cuda(), y_train.cuda()
            y_pre = model(x_train)                    # raw logits, shape (batch, nclasses)
            _, label_pre = torch.max(y_pre.data, 1)   # predicted class indices
            optimizer.zero_grad()
            loss = criterion(y_pre, y_train)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            batch_correct = torch.sum(label_pre == y_train).item()
            correct_num += batch_correct
            acc = batch_correct / len(y_train)
            print('batch loss: {} batch acc: {}'.format(loss.item(), acc))
        print('epoch: {} training loss: {}, training acc: {}'.format(epoch, epoch_loss, correct_num / num_train))
        if (epoch+1) % 5 == 0:
            model.eval()             # switch off dropout for evaluation
            test_loss = 0.0
            test_acc_num = 0
            with torch.no_grad():    # no gradients are needed at test time
                for j, testdata in enumerate(test_loader):
                    x_test, y_test = testdata
                    if use_gpu:
                        x_test, y_test = x_test.cuda(), y_test.cuda()
                    y_pre = model(x_test)
                    _, label_pre = torch.max(y_pre.data, 1)
                    loss = criterion(y_pre, y_test)
                    test_loss += loss.item()
                    test_acc_num += torch.sum(label_pre == y_test).item()
            print('epoch: {} test loss: {} test acc: {}'.format(epoch, test_loss, test_acc_num / num_test))
    # save only the learned parameters (state_dict); the architecture is rebuilt from code when loading
    torch.save(model.state_dict(), model_path + 'AlexNet_params.pkl')
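
To use the saved parameters later, rebuild the same architecture and load the state_dict back in, for example (this snippet is not part of the training script above; x stands for a batch preprocessed with the same transformations):

model = BM.BuildAlexNet(model_type, nclasses)
model.load_state_dict(torch.load(model_path + 'AlexNet_params.pkl'))
model.eval()                        # dropout off for inference
with torch.no_grad():
    scores = model(x)               # x: a preprocessed input batch
    _, predicted = torch.max(scores, 1)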

The main things to watch out for are data-type details: the labels must be LongTensor, the target passed to nn.CrossEntropyLoss() must be class indices rather than one-hot vectors, and when using a GPU the model and the input/label tensors have to be moved onto the GPU (the outputs y_pre and label_pre are then produced on the GPU automatically).
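
For example, nn.CrossEntropyLoss() is used like this: the input is the raw logits and the target is a 1-D LongTensor of class indices; no one-hot encoding and no softmax are needed, because the loss applies log-softmax internally:

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)            # raw network outputs, shape (batch, nclasses)
target = torch.tensor([3, 0, 9, 1])    # class indices, dtype torch.int64 (LongTensor)
loss = criterion(logits, target)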
