import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# TensorBoard event files for this run are written under runs/CIFAR10_resnet18.
writer = SummaryWriter('runs/CIFAR10_resnet18')

# Per-channel normalization shared by the training and evaluation pipelines.
# These are the statistics torchvision documents for its ImageNet-pretrained models.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop resized to 224x224 plus random horizontal flip
# as augmentation, then tensor conversion and normalization.
trans_train = transforms.Compose([
    transforms.RandomResizedCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Evaluation pipeline: deterministic resize to 256 and centre crop to 224,
# followed by the same tensor conversion and normalization.
trans_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])
# CIFAR-10 is stored under ./data. Only the training split requests a download;
# the test split is then read from the same directory.
data_path = './data'
trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=trans_train, download=True)
testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=trans_test, download=False)

train_batch_size = 256
test_batch_size = 512

# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=train_batch_size, shuffle=True, num_workers=2)
# Evaluation does not benefit from shuffling: the aggregated loss/accuracy are
# order-independent, and a fixed order keeps evaluation reproducible.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=test_batch_size, shuffle=False, num_workers=2)

# Human-readable class names, in CIFAR-10 label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Start from torchvision's pretrained ResNet-18.
# NOTE: `pretrained=` is deprecated in newer torchvision releases; the commented
# line below is the replacement once the minimum supported version allows it.
model = models.resnet18(pretrained=True)
# model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# print(model)

# Freeze every pretrained parameter so that only the new head is trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head. The input width is read from the existing layer
# and the output width from the class list instead of hard-coding 512 and 10.
# A newly constructed nn.Linear has requires_grad=True, so it stays trainable.
model.fc = nn.Linear(model.fc.in_features, len(classes))
model.to(device)

# # ---------------------- show the number of weight ----------------------
# total_params = sum(p.numel() for p in model.parameters())
# total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print('total number of parameters:{}'.format(total_params))
# print('total number of trainable parameters:{}'.format(total_trainable_params))

criterion = nn.CrossEntropyLoss()
# Only the parameters of the new head are handed to the optimizer.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, weight_decay=0.001, momentum=0.9)
# ---------------------- model training ----------------------
epochs = 20
# Per-epoch history of training/validation loss and accuracy, stored as plain
# Python floats so nothing keeps tensors (or their autograd history) alive.
train_epoch_loss, test_epoch_loss, train_epoch_acc, test_epoch_acc = [], [], [], []

for epoch in range(epochs):
    # -------------- train --------------
    # NOTE(review): train() also puts the frozen backbone's BatchNorm layers in
    # training mode, so their running statistics keep being updated even though
    # their weights are frozen. Confirm this is intended for feature extraction.
    model.train()
    train_loss, train_correct = 0.0, 0
    for step, (train_img, train_label) in enumerate(trainloader):
        train_img, train_label = train_img.to(device), train_label.to(device)
        output = model(train_img)
        loss = criterion(output, train_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert to Python numbers immediately instead of accumulating tensors.
        batch_loss = loss.item()
        correct_num = (torch.argmax(output, dim=1) == train_label).sum().item()
        train_loss += batch_loss
        train_correct += correct_num

        # `step` restarts at 0 every epoch; use a run-wide step so that later
        # epochs do not overwrite earlier points in TensorBoard.
        global_step = epoch * len(trainloader) + step
        writer.add_scalar('train_loss_batch', batch_loss, global_step)
        # Divide by the real batch size: the last batch may be smaller than
        # train_batch_size.
        writer.add_scalar('train_accurate_batch', correct_num / train_label.size(0), global_step)

    # Epoch loss is the mean of the per-batch mean losses; accuracy is over all samples.
    train_loss_mean = train_loss / len(trainloader)
    train_acc = train_correct / len(trainset)
    train_epoch_loss.append(train_loss_mean)
    train_epoch_acc.append(train_acc)
    writer.add_scalar('train_loss_epoch', train_loss_mean, epoch)
    writer.add_scalar('train_accurate_epoch', train_acc, epoch)

    # -------------- valid --------------
    model.eval()
    test_loss, test_correct = 0.0, 0
    # No gradients are needed for evaluation; skipping autograd saves memory and time.
    with torch.no_grad():
        for test_img, test_label in testloader:
            test_img, test_label = test_img.to(device), test_label.to(device)
            output = model(test_img)
            test_loss += criterion(output, test_label).item()
            test_correct += (torch.argmax(output, dim=1) == test_label).sum().item()

    test_loss_mean = test_loss / len(testloader)
    test_acc = test_correct / len(testset)
    test_epoch_loss.append(test_loss_mean)
    test_epoch_acc.append(test_acc)
    # Log the validation metrics (previously the training values were written
    # under these tags).
    writer.add_scalar('test_loss_epoch', test_loss_mean, epoch)
    writer.add_scalar('test_accurate_epoch', test_acc, epoch)

    print('epoch{}, train_loss={}, train_acc={}'.format(epoch, train_loss_mean, train_acc))
    print('epoch{}, valid_loss={}, valid_acc={}'.format(epoch, test_loss_mean, test_acc))
    print('\n')

# The history lists already hold Python floats; copy them under the names used
# by the plotting and saving code below.
train_loss_array = list(train_epoch_loss)
train_acc_array = list(train_epoch_acc)
test_loss_array = list(test_epoch_loss)
test_acc_array = list(test_epoch_acc)
# ------------- plot the result -------------
# Draw the loss curves and then the accuracy curves, each in its own figure.
# Both figures share the same styling and now both carry a legend (the loss
# figure previously had none, so its two curves could not be told apart).
epoch_axis = np.arange(epochs)
curve_specs = (
    # (subplot position, title, training curve, validation curve, legend location)
    (221, 'loss', train_loss_array, test_loss_array, 'upper right'),
    (222, 'accurate', train_acc_array, test_acc_array, 'lower right'),
)
for subplot_pos, title, train_curve, valid_curve, legend_loc in curve_specs:
    plt.figure(figsize=(20, 10))
    plt.subplot(subplot_pos)
    plt.title(title)
    plt.plot(epoch_axis, train_curve)
    plt.plot(epoch_axis, valid_curve)
    plt.grid(True, which='both', axis='both', color='y', linestyle='--', linewidth=1)
    # Legend entries follow the order in which the curves were plotted.
    plt.legend(["train", "validation"], loc=legend_loc)
    plt.show()
# -------------- save the result -------------
# Bundle the four per-epoch curves and write them to ./result_dict.npy.
# np.save stores a dict as a pickled object array, so reading it back requires
# np.load('./result_dict.npy', allow_pickle=True).item().
result_dict = dict(
    train_loss_array=train_loss_array,
    train_acc_array=train_acc_array,
    test_loss_array=test_loss_array,
    test_acc_array=test_acc_array,
)
np.save('./result_dict.npy', result_dict)
Output:
epoch0, train_loss=1.8071383237838745, train_acc=0.3887999951839447
epoch0, valid_loss=1.2278122901916504, valid_acc=0.6430000066757202
epoch1, train_loss=1.4005506038665771, train_acc=0.5360999703407288
epoch1, valid_loss=1.030735969543457, valid_acc=0.6850999593734741
epoch2, train_loss=1.2940409183502197, train_acc=0.5644800066947937
epoch2, valid_loss=0.9407730102539062, valid_acc=0.7059999704360962
epoch3, train_loss=1.2393066883087158, train_acc=0.578819990158081
epoch3, valid_loss=0.8911893963813782, valid_acc=0.715499997138977
epoch4, train_loss=1.2145596742630005, train_acc=0.5823799967765808
epoch4, valid_loss=0.8617193102836609, valid_acc=0.7218999862670898
epoch5, train_loss=1.1909451484680176, train_acc=0.5880199670791626
epoch5, valid_loss=0.8370893597602844, valid_acc=0.7269999980926514
epoch6, train_loss=1.182749629020691, train_acc=0.5904200077056885
epoch6, valid_loss=0.8229374289512634, valid_acc=0.7293999791145325
epoch7, train_loss=1.1616133451461792, train_acc=0.5995399951934814
epoch7, valid_loss=0.8094478845596313, valid_acc=0.7342000007629395
epoch8, train_loss=1.1525970697402954, train_acc=0.6015200018882751
epoch8, valid_loss=0.8026527762413025, valid_acc=0.7366999983787537
epoch9, train_loss=1.144952416419983, train_acc=0.6024999618530273
epoch9, valid_loss=0.7950977683067322, valid_acc=0.7354999780654907
epoch10, train_loss=1.140042781829834, train_acc=0.6040599942207336
epoch10, valid_loss=0.7850207686424255, valid_acc=0.7365999817848206
epoch11, train_loss=1.1367998123168945, train_acc=0.6043599843978882
epoch11, valid_loss=0.7832964658737183, valid_acc=0.7390999794006348
epoch12, train_loss=1.1333338022232056, train_acc=0.6078799962997437
epoch12, valid_loss=0.7704198956489563, valid_acc=0.7419999837875366
epoch13, train_loss=1.1298826932907104, train_acc=0.6068999767303467
epoch13, valid_loss=0.767668604850769, valid_acc=0.7426999807357788
epoch14, train_loss=1.1242992877960205, train_acc=0.6079999804496765
epoch14, valid_loss=0.773628830909729, valid_acc=0.7387999892234802
epoch15, train_loss=1.118688941001892, train_acc=0.6112200021743774
epoch15, valid_loss=0.757527232170105, valid_acc=0.7443000078201294
epoch16, train_loss=1.1208925247192383, train_acc=0.6098399758338928
epoch16, valid_loss=0.7577210068702698, valid_acc=0.7436999678611755
epoch17, train_loss=1.1159234046936035, train_acc=0.6102199554443359
epoch17, valid_loss=0.7527276873588562, valid_acc=0.746399998664856
epoch18, train_loss=1.1142677068710327, train_acc=0.6092199683189392
epoch18, valid_loss=0.7553915977478027, valid_acc=0.7448999881744385
epoch19, train_loss=1.1068326234817505, train_acc=0.6119199991226196
epoch19, valid_loss=0.7486104369163513, valid_acc=0.7450000047683716