卷积神经网络(CNN)模型训练代码

使用PyTorch框架搭建一个浅层的CNN网络，并训练该网络。训练时在每个epoch结束后，将训练集和测试集的损失（loss）、准确率（accuracy）、精确率（precision）和召回率（recall）写入tensorboardX的日志中。本文是一个CNN的入门（starter）项目，CNN训练代码是通用的，只需要重新设计网络并准备好数据（代码从全局变量X_train、y_train、X_validate、y_validate中读取数据），就可以适配到其他项目中。下面直接上代码。

import torch
from torch import nn
import torch.nn.functional as F
import os
import tensorboardX
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Run on the default CUDA device when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class VGGBaseSimpleS2(nn.Module):
    """Small VGG-style two-class CNN sized for single-channel 6x6 inputs.

    Spatial sizes through ``forward`` for a 6x6 input:
    6x6 -conv1-> 6x6 -max_pooling1-> 5x5 -conv2_1-> 5x5 -conv2_2-> 5x5
    -max_pooling2-> 2x2, flattened to 24 * 2 * 2 = 96 features for ``fc``.

    ``forward`` returns log-probabilities of shape (batch, 2). When this is
    fed to ``nn.CrossEntropyLoss`` (which applies log_softmax itself), the
    extra log_softmax is redundant but numerically harmless because
    log_softmax is idempotent. ``nn.NLLLoss`` is the conventional pairing.
    """

    @staticmethod
    def _conv_relu(in_channels, out_channels):
        # 3x3 conv with padding=1 keeps the spatial size; ReLU follows.
        # Kept as a Sequential so parameter names stay "<name>.0.weight/bias".
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        # Registration order matters: it fixes parameter init order and print(net).
        self.conv1 = self._conv_relu(1, 12)
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv2_1 = self._conv_relu(12, 24)
        # NOTE(review): this pooling layer is registered but never applied in
        # forward(); it has no parameters, so it does not affect checkpoints.
        self.max_pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv2_2 = self._conv_relu(24, 24)
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(24 * 2 * 2, 2)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, 2) log-probabilities."""
        n = x.shape[0]
        h = self.max_pooling1(self.conv1(x))
        h = self.max_pooling2(self.conv2_2(self.conv2_1(h)))
        logits = self.fc(h.view(n, -1))
        return F.log_softmax(logits, dim=1)


class TrainingDataSet(Dataset):
    """Training split served as (6x6 tensor, label) pairs.

    Each sample is the first 36 columns of row ``index`` of ``X``, reshaped to
    a 6x6 tensor, paired with ``y[index]``.

    Args:
        X: 2-D array-like supporting ``X[i, 0:36]`` indexing (e.g. a NumPy
            array). Defaults to the module-level ``X_train`` when omitted.
        y: Labels indexable by position; its length defines the dataset size.
            Defaults to the module-level ``y_train`` when omitted.

    NOTE(review): ``X_train``/``y_train`` are not defined in the visible code;
    they are presumably created at module level before this class is built.
    """

    def __init__(self, X=None, y=None):
        super().__init__()
        # The globals are looked up only when no data is passed, so the
        # original zero-argument call keeps working and explicit data does not
        # require the globals to exist.
        self.data_dict_X = X_train if X is None else X
        self.data_dict_y = y_train if y is None else y

    def __getitem__(self, index):
        row = self.data_dict_X[index, 0:36]
        return torch.tensor(row).view(6, 6), self.data_dict_y[index]

    def __len__(self):
        return len(self.data_dict_y)


class TestDataSet(Dataset):
    """Validation/test split served as (6x6 tensor, label) pairs.

    Each sample is the first 36 columns of row ``index`` of ``X``, reshaped to
    a 6x6 tensor, paired with ``y[index]``.

    Args:
        X: 2-D array-like supporting ``X[i, 0:36]`` indexing (e.g. a NumPy
            array). Defaults to the module-level ``X_validate`` when omitted.
        y: Labels indexable by position; its length defines the dataset size.
            Defaults to the module-level ``y_validate`` when omitted.

    NOTE(review): ``X_validate``/``y_validate`` are not defined in the visible
    code; they are presumably created at module level before this class is
    built.
    """

    def __init__(self, X=None, y=None):
        super().__init__()
        # The globals are looked up only when no data is passed, so the
        # original zero-argument call keeps working and explicit data does not
        # require the globals to exist.
        self.data_dict_X = X_validate if X is None else X
        self.data_dict_y = y_validate if y is None else y

    def __getitem__(self, index):
        row = self.data_dict_X[index, 0:36]
        return torch.tensor(row).view(6, 6), self.data_dict_y[index]

    def __len__(self):
        return len(self.data_dict_y)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero.

    Guards accuracy/precision/recall against ZeroDivisionError, e.g. when the
    model predicts no positives during an epoch.
    """
    return numerator / denominator if denominator else 0.0


def _run_epoch(net, loader, loss_func, optimizer=None):
    """Make one full pass over ``loader`` and return epoch-level metrics.

    With ``optimizer`` the network is put in train mode and updated after
    every batch. Without it the network is put in eval mode and run with
    gradients disabled, so no autograd graph is built during evaluation.

    Inputs are unsqueezed to (batch, 1, ...) float32 and labels cast to int64,
    both moved to the module-level ``device``. Label 1 is treated as the
    positive class for precision/recall.

    Returns:
        Tuple ``(mean_loss, accuracy, precision, recall)``. ``mean_loss`` is
        averaged over batches. ``accuracy`` is averaged over samples
        (``len(loader.dataset)``), so a partial last batch is counted correctly.
    """
    training = optimizer is not None
    net.train(training)

    sum_loss = 0.0
    correct = tp = fp = fn = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.unsqueeze(1).to(device=device, dtype=torch.float32)
            labels = labels.long().to(device)

            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = outputs.argmax(dim=1)
            correct += pred.eq(labels).sum().item()

            pos_label, pos_pred = labels == 1, pred == 1
            neg_label, neg_pred = labels == 0, pred == 0
            tp += (pos_label & pos_pred).sum().item()
            fp += (neg_label & pos_pred).sum().item()
            fn += (pos_label & neg_pred).sum().item()

            sum_loss += loss.item()

    mean_loss = _safe_ratio(sum_loss, len(loader))
    accuracy = _safe_ratio(correct, len(loader.dataset))
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    return mean_loss, accuracy, precision, recall


def cnn_classification():
    """Train VGGBaseSimpleS2 and log per-epoch metrics with tensorboardX.

    For each of ``epoch_num`` epochs this function:
    1. trains on ``TrainingDataSet()``;
    2. saves a checkpoint to ``models_aug_CNN/<epoch + 1>.pth``;
    3. steps the LR scheduler;
    4. evaluates on ``TestDataSet()``;
    5. writes loss / accuracy ("correct") / precision / recall scalars to
       ``logCNN``;
    6. prints a summary line.

    The writer is always closed, even if training raises.
    """
    batch_size = 256
    # NOTE(review): the training data is not shuffled; consider shuffle=True
    # unless sample order is meaningful for this data.
    train_loader = DataLoader(TrainingDataSet(), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(TestDataSet(), batch_size=batch_size, shuffle=False)

    epoch_num = 200
    lr = 0.001
    net = VGGBaseSimpleS2().to(device)
    print(net)

    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # Alternative: torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    # Multiply the learning rate by 0.9 every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

    os.makedirs("logCNN", exist_ok=True)
    os.makedirs("models_aug_CNN", exist_ok=True)
    writer = tensorboardX.SummaryWriter("logCNN")
    try:
        for epoch in range(epoch_num):
            train_loss, train_correct, train_precision, train_recall = _run_epoch(
                net, train_loader, loss_func, optimizer
            )
            writer.add_scalar("train loss", train_loss, global_step=epoch)
            writer.add_scalar("train correct", train_correct, global_step=epoch)
            writer.add_scalar("train precision", train_precision, global_step=epoch)
            writer.add_scalar("train recall", train_recall, global_step=epoch)

            torch.save(net.state_dict(), "models_aug_CNN/{}.pth".format(epoch + 1))
            scheduler.step()

            test_loss, test_correct, test_precision, test_recall = _run_epoch(
                net, test_loader, loss_func
            )
            # Test metrics are taken after epoch + 1 completed training epochs,
            # matching the checkpoint number above.
            writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
            writer.add_scalar("test correct", test_correct, global_step=epoch + 1)
            writer.add_scalar("test precision", test_precision, global_step=epoch + 1)
            writer.add_scalar("test recall", test_recall, global_step=epoch + 1)

            print("epoch is", epoch, "train loss", train_loss, "train correct", train_correct, "test loss is ",
                  test_loss, "test correct is: ", test_correct, "train_precision: ", train_precision, "test_precision: ",
                  test_precision, "train_recall: ", train_recall, "test_recall: ", test_recall)
    finally:
        writer.close()

你可能感兴趣的:(机器学习,cnn,深度学习,pytorch,模型训练,starter)