【Transfer Learning】Pretrained ResNet-18 on CIFAR-10

import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter('runs/CIFAR10_resnet18')
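# The curves logged below can be viewed with: tensorboard --logdir runs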

# CIFAR-10 images are 32x32; they are resized/cropped to 224x224 and normalized with the
# ImageNet mean/std so they match what the pretrained ResNet-18 backbone expects.
trans_train = transforms.Compose([transforms.RandomResizedCrop((224, 224)),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

trans_test = transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

data_path = './data'
trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=trans_train, download=True)
testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=trans_test, download=False)

train_batch_size = 256
test_batch_size = 512
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=2)  # no need to shuffle the evaluation set

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

model = models.resnet18(pretrained=True)
# 'pretrained' is deprecated since torchvision 0.13; on newer versions use:
# model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# print(model)

# Freeze all pretrained backbone parameters; only the new classifier head will be trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the 1000-class ImageNet head with a new 10-class layer for CIFAR-10
# (parameters of newly created modules require gradients by default).
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)

# ---------------------- show the number of parameters ----------------------
# total_params = sum(p.numel() for p in model.parameters())
# total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print('total number of parameters: {}'.format(total_params))
# print('total number of trainable parameters: {}'.format(total_trainable_params))
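
# Sanity check (a minimal sketch, not part of the original script): after freezing the
# backbone and replacing fc, the only trainable parameters should be the new layer's
# weight and bias (512*10 + 10 = 5130 values):
# for name, p in model.named_parameters():
#     if p.requires_grad:
#         print(name, p.numel())   # expected: fc.weight 5120, fc.bias 10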

criterion = nn.CrossEntropyLoss()
# Only the new classification head is passed to the optimizer; the frozen backbone has no gradients.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, weight_decay=0.001, momentum=0.9)
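
# An equivalent, slightly more general alternative (a sketch, not used above) is to hand
# the optimizer every parameter that still requires gradients; this keeps working if
# more layers are unfrozen later:
# optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
#                             lr=0.001, weight_decay=0.001, momentum=0.9)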

# ---------------------- model training ----------------------

epochs = 20

train_epoch_loss, test_epoch_loss, train_epoch_acc, test_epoch_acc = [], [], [], []  # per-epoch training/validation loss and accuracy

for epoch in range(epochs):

    # -------------- train --------------
    model.train()
    train_loss, train_correct = 0, 0
    for step, (train_img, train_label) in enumerate(trainloader):
        train_img, train_label = train_img.to(device), train_label.to(device)
        output = model(train_img)
        loss = criterion(output, train_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct_num = torch.sum(torch.argmax(output, dim=1) == train_label).item()
        train_correct += correct_num
        train_loss += loss.item()

        # use a global step so batch-level curves are not overwritten every epoch
        global_step = epoch * len(trainloader) + step
        writer.add_scalar('train_loss_batch', loss.item(), global_step)
        accurate_rate = correct_num / train_label.size(0)  # the last batch can be smaller than train_batch_size
        writer.add_scalar('train_accurate_batch', accurate_rate, global_step)

    train_epoch_loss.append(train_loss / len(trainloader))
    train_epoch_acc.append(train_correct / len(trainset))
    writer.add_scalar('train_loss_epoch', train_loss / len(trainloader), epoch)
    writer.add_scalar('train_accurate_epoch', train_correct / len(trainset), epoch)

    # -------------- valid --------------

    # gradients are not needed during evaluation, so run the whole pass under no_grad()
    model.eval()
    test_loss, test_correct = 0, 0
    with torch.no_grad():
        for test_img, test_label in testloader:
            test_img, test_label = test_img.to(device), test_label.to(device)
            output = model(test_img)
            loss = criterion(output, test_label)

            correct_num = torch.sum(torch.argmax(output, dim=1) == test_label).item()
            test_correct += correct_num
            test_loss += loss.item()

    test_epoch_loss.append(test_loss / len(testloader))
    test_epoch_acc.append(test_correct / len(testset))
    writer.add_scalar('test_loss_epoch', test_loss / len(testloader), epoch)
    writer.add_scalar('test_accurate_epoch', test_correct / len(testset), epoch)

    print('epoch{}, train_loss={}, train_acc={}'.format(epoch, train_loss/len(trainloader), train_correct/len(trainset)))
    print('epoch{}, valid_loss={}, valid_acc={}'.format(epoch, test_loss/len(testloader), test_correct/len(testset)))
    print()

writer.close()  # make sure the TensorBoard logs are flushed to disk

# -------------  plot the result  -------------

# the per-epoch lists already hold Python floats, so they can be used directly
train_loss_array = train_epoch_loss
train_acc_array = train_epoch_acc
test_loss_array = test_epoch_loss
test_acc_array = test_epoch_acc

plt.figure(figsize=(20, 10))

plt.subplot(121)
plt.title('loss')
plt.plot(np.arange(epochs), train_loss_array)
plt.plot(np.arange(epochs), test_loss_array)
plt.grid(True, which='both', axis='both', color='y', linestyle='--', linewidth=1)
plt.legend(['train', 'validation'], loc='upper right')

plt.subplot(122)
plt.title('accuracy')
plt.plot(np.arange(epochs), train_acc_array)
plt.plot(np.arange(epochs), test_acc_array)
plt.grid(True, which='both', axis='both', color='y', linestyle='--', linewidth=1)
plt.legend(['train', 'validation'], loc='lower right')

plt.show()

# -------------- save the result  -------------
result_dict = {'train_loss_array': train_loss_array,
               'train_acc_array': train_acc_array,
               'test_loss_array': test_loss_array,
               'test_acc_array': test_acc_array}

np.save('./result_dict.npy', result_dict)
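
# To load the results back later (not part of the original script): np.save pickles the
# dict into a 0-d object array, so it has to be read with allow_pickle=True and
# unwrapped with .item():
# loaded = np.load('./result_dict.npy', allow_pickle=True).item()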

Output:

epoch0, train_loss=1.8071383237838745, train_acc=0.3887999951839447
epoch0, valid_loss=1.2278122901916504, valid_acc=0.6430000066757202

epoch1, train_loss=1.4005506038665771, train_acc=0.5360999703407288
epoch1, valid_loss=1.030735969543457, valid_acc=0.6850999593734741

epoch2, train_loss=1.2940409183502197, train_acc=0.5644800066947937
epoch2, valid_loss=0.9407730102539062, valid_acc=0.7059999704360962

epoch3, train_loss=1.2393066883087158, train_acc=0.578819990158081
epoch3, valid_loss=0.8911893963813782, valid_acc=0.715499997138977

epoch4, train_loss=1.2145596742630005, train_acc=0.5823799967765808
epoch4, valid_loss=0.8617193102836609, valid_acc=0.7218999862670898

epoch5, train_loss=1.1909451484680176, train_acc=0.5880199670791626
epoch5, valid_loss=0.8370893597602844, valid_acc=0.7269999980926514

epoch6, train_loss=1.182749629020691, train_acc=0.5904200077056885
epoch6, valid_loss=0.8229374289512634, valid_acc=0.7293999791145325

epoch7, train_loss=1.1616133451461792, train_acc=0.5995399951934814
epoch7, valid_loss=0.8094478845596313, valid_acc=0.7342000007629395

epoch8, train_loss=1.1525970697402954, train_acc=0.6015200018882751
epoch8, valid_loss=0.8026527762413025, valid_acc=0.7366999983787537

epoch9, train_loss=1.144952416419983, train_acc=0.6024999618530273
epoch9, valid_loss=0.7950977683067322, valid_acc=0.7354999780654907

epoch10, train_loss=1.140042781829834, train_acc=0.6040599942207336
epoch10, valid_loss=0.7850207686424255, valid_acc=0.7365999817848206

epoch11, train_loss=1.1367998123168945, train_acc=0.6043599843978882
epoch11, valid_loss=0.7832964658737183, valid_acc=0.7390999794006348

epoch12, train_loss=1.1333338022232056, train_acc=0.6078799962997437
epoch12, valid_loss=0.7704198956489563, valid_acc=0.7419999837875366

epoch13, train_loss=1.1298826932907104, train_acc=0.6068999767303467
epoch13, valid_loss=0.767668604850769, valid_acc=0.7426999807357788

epoch14, train_loss=1.1242992877960205, train_acc=0.6079999804496765
epoch14, valid_loss=0.773628830909729, valid_acc=0.7387999892234802

epoch15, train_loss=1.118688941001892, train_acc=0.6112200021743774
epoch15, valid_loss=0.757527232170105, valid_acc=0.7443000078201294

epoch16, train_loss=1.1208925247192383, train_acc=0.6098399758338928
epoch16, valid_loss=0.7577210068702698, valid_acc=0.7436999678611755

epoch17, train_loss=1.1159234046936035, train_acc=0.6102199554443359
epoch17, valid_loss=0.7527276873588562, valid_acc=0.746399998664856

epoch18, train_loss=1.1142677068710327, train_acc=0.6092199683189392
epoch18, valid_loss=0.7553915977478027, valid_acc=0.7448999881744385

epoch19, train_loss=1.1068326234817505, train_acc=0.6119199991226196
epoch19, valid_loss=0.7486104369163513, valid_acc=0.7450000047683716
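
With only the new fully connected layer trained, validation accuracy levels off around 74-75% after 20 epochs, while training accuracy stays near 61%, likely because the training pipeline applies aggressive random cropping whereas evaluation uses a fixed center crop.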
