使用 PyTorch 框架搭建一个浅层 CNN 网络并进行训练。训练时，每个 epoch 结束后，会把训练集和测试集的损失（loss）、准确率（accuracy）、精确率（precision）和召回率（recall）写入 tensorboardX 的日志。本文是一个 CNN 入门（starter）项目，训练代码是通用的。只需重新设计网络并准备好数据，就可以适配到其他项目中。脚本默认从全局变量 X_train、y_train、X_validate、y_validate 读取数据，任务为二分类。下面直接给出代码。
import torch
from torch import nn
import torch.nn.functional as F
import os
import tensorboardX
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Run on the default CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class VGGBaseSimpleS2(nn.Module):
    """Shallow VGG-style CNN: one input channel in, two-class log-probabilities out.

    ``forward`` returns ``log_softmax`` scores of shape (N, 2).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so that, under a given random
        # seed, weight initialisation is reproducible.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 12, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv2_1 = nn.Sequential(
            nn.Conv2d(12, 24, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        # NOTE(review): max_pooling2_1 is constructed but never applied in
        # forward(); it is kept so the module's attributes stay unchanged.
        self.max_pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(24, 24, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 24 channels on a 2x2 map after the last pooling (for 6x6 inputs).
        self.fc = nn.Linear(24 * 2 * 2, 2)

    def forward(self, x):
        """Map a (N, 1, H, W) batch to (N, 2) per-class log-probabilities."""
        # For a 6x6 input the spatial size goes:
        #   conv1 -> 6, pool(k2, s1) -> 5, conv2_1 -> 5, conv2_2 -> 5,
        #   pool(k2, s2) -> 2.
        feats = x
        for stage in (self.conv1, self.max_pooling1, self.conv2_1,
                      self.conv2_2, self.max_pooling2):
            feats = stage(feats)
        logits = self.fc(feats.view(x.size(0), -1))
        return F.log_softmax(logits, dim=1)
class TrainingDataSet(Dataset):
    """Map-style dataset over the training split.

    Each item is ``(grid, label)``: ``grid`` is the first 36 features of a
    row reshaped into a 6x6 tensor, and ``label`` is the matching entry of
    the label sequence, returned unchanged.

    Args:
        X: 2-D indexable feature array (``X[index, 0:36]`` must work).
            Defaults to the module-level ``X_train``.
        y: Label sequence aligned with ``X``. Defaults to the module-level
            ``y_train``.
    """

    def __init__(self, X=None, y=None):
        super().__init__()
        # Fall back to the module-level globals so existing zero-argument
        # callers keep working; the global is only looked up when needed.
        self.data_dict_X = X_train if X is None else X
        self.data_dict_y = y_train if y is None else y

    def __getitem__(self, index):
        t = self.data_dict_X[index, 0:36]
        t = torch.tensor(t).view(6, 6)
        return t, self.data_dict_y[index]

    def __len__(self):
        # The dataset length is defined by the labels.
        return len(self.data_dict_y)
class TestDataSet(Dataset):
    """Map-style dataset over the validation/test split.

    Each item is ``(grid, label)``: ``grid`` is the first 36 features of a
    row reshaped into a 6x6 tensor, and ``label`` is the matching entry of
    the label sequence, returned unchanged.

    Args:
        X: 2-D indexable feature array (``X[index, 0:36]`` must work).
            Defaults to the module-level ``X_validate``.
        y: Label sequence aligned with ``X``. Defaults to the module-level
            ``y_validate``.
    """

    def __init__(self, X=None, y=None):
        super().__init__()
        # Fall back to the module-level globals so existing zero-argument
        # callers keep working; the global is only looked up when needed.
        self.data_dict_X = X_validate if X is None else X
        self.data_dict_y = y_validate if y is None else y

    def __getitem__(self, index):
        t = self.data_dict_X[index, 0:36]
        t = torch.tensor(t).view(6, 6)
        return t, self.data_dict_y[index]

    def __len__(self):
        # The dataset length is defined by the labels.
        return len(self.data_dict_y)
def _confusion_counts(pred, labels):
    """Return ``(tp, fp, fn)`` for one batch, treating class 1 as positive."""
    pred_pos, pred_neg = pred == 1, pred == 0
    true_pos, true_neg = labels == 1, labels == 0
    tp = (true_pos & pred_pos).sum().item()
    fp = (true_neg & pred_pos).sum().item()
    fn = (true_pos & pred_neg).sum().item()
    return tp, fp, fn


def _ratio(num, den):
    """Return ``num / den`` as a float, or 0.0 when ``den`` is 0."""
    return num / den if den else 0.0


def _run_epoch(net, loader, loss_func, optimizer=None):
    """Make one pass over ``loader`` and return epoch-level metrics.

    When ``optimizer`` is given, the network is trained on every batch and
    the metrics come from the outputs seen during training. Otherwise it is
    evaluated with gradients disabled.

    Returns:
        dict with keys ``loss`` (mean per-sample loss), ``correct``
        (accuracy), ``precision`` and ``recall`` (class 1 as positive).
    """
    training = optimizer is not None
    net.train(training)
    loss_sum = 0.0
    n_correct = n_seen = 0
    tp = fp = fn = 0
    # Evaluation does not need autograd graphs; skipping them saves memory.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            # Samples arrive as (N, 6, 6); add the single channel axis.
            inputs = inputs.unsqueeze(1).to(device=device, dtype=torch.float32)
            labels = labels.long().to(device)
            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            pred = outputs.detach().argmax(dim=1)
            batch_n = labels.size(0)
            # The loss is a batch mean; weight it by batch size so a short
            # final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_n
            n_seen += batch_n
            n_correct += pred.eq(labels).sum().item()
            b_tp, b_fp, b_fn = _confusion_counts(pred, labels)
            tp += b_tp
            fp += b_fp
            fn += b_fn
    # Divide by the real sample count (not len(loader) * batch_size) and
    # guard against empty denominators (e.g. no positive predictions yet).
    return {
        "loss": _ratio(loss_sum, n_seen),
        "correct": _ratio(n_correct, n_seen),
        "precision": _ratio(tp, tp + fp),
        "recall": _ratio(tp, tp + fn),
    }


def _log_metrics(writer, split, metrics, step):
    """Write each metric as the scalar ``"<split> <name>"`` at ``step``."""
    for name in ("loss", "correct", "precision", "recall"):
        writer.add_scalar("{} {}".format(split, name), metrics[name], global_step=step)


def cnn_classification(batch_size=256, epoch_num=200, lr=0.001,
                       log_dir="logCNN", model_dir="models_aug_CNN"):
    """Train VGGBaseSimpleS2 and log per-epoch train/test metrics.

    Each epoch does the following:
    - runs one training pass;
    - saves a checkpoint to ``<model_dir>/<epoch + 1>.pth``;
    - steps the LR scheduler;
    - evaluates on the test set.

    Loss, accuracy ("correct"), precision and recall for both splits are
    written to a tensorboardX log in ``log_dir`` and printed.

    Args:
        batch_size: Mini-batch size for both loaders.
        epoch_num: Number of training epochs.
        lr: Initial Adam learning rate (multiplied by 0.9 every 5 epochs).
        log_dir: Directory for the tensorboardX event files.
        model_dir: Directory for per-epoch ``state_dict`` checkpoints.
    """
    # NOTE(review): the training loader does not shuffle; consider
    # shuffle=True if the training data is ordered (e.g. by label).
    train_loader = DataLoader(TrainingDataSet(), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(TestDataSet(), batch_size=batch_size, shuffle=False)
    net = VGGBaseSimpleS2().to(device)
    print(net)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_dir)
    try:
        for epoch in range(epoch_num):
            train_m = _run_epoch(net, train_loader, loss_func, optimizer)
            # Both splits are logged at the same step so their curves align.
            _log_metrics(writer, "train", train_m, epoch)
            torch.save(net.state_dict(), "{}/{}.pth".format(model_dir, epoch + 1))
            scheduler.step()
            test_m = _run_epoch(net, test_loader, loss_func)
            _log_metrics(writer, "test", test_m, epoch)
            print("epoch is", epoch, "train loss", train_m["loss"],
                  "train correct", train_m["correct"],
                  "test loss is ", test_m["loss"],
                  "test correct is: ", test_m["correct"],
                  "train_precision: ", train_m["precision"],
                  "test_precision: ", test_m["precision"],
                  "train_recall: ", train_m["recall"],
                  "test_recall: ", test_m["recall"])
    finally:
        # Flush and close the event files even if training is interrupted.
        writer.close()