Our self-built Pokemon dataset contains only a little over 1,000 images.
That is a very small dataset, and the ResNet we train on it is a fairly powerful model, so overfitting is very likely.
How can we deal with this?
Pokemon images and ImageNet images are both natural images, so they share some common knowledge.
Can we therefore use a model already trained on ImageNet to boost classification performance on the Pokemon dataset?
Here we no longer use the ResNet18 we wrote ourselves, but instead load an already-trained ResNet from torchvision.
The line model = ResNet18(5).to(device) is replaced by an nn.Sequential().
torchvision's resnet18 is 18 layers deep; we take all of its child modules except the last one.
In other words, we unpack the learned knowledge and keep everything up to the second-to-last layer, dropping the final 1000-class fully-connected head.
A Flatten() then flattens the resulting [b, 512, 1, 1] feature map into [b, 512], and a new nn.Linear(512, 5) maps it to our 5 Pokemon classes.
Conceptually, resnet18 is unpacked into common knowledge A and new knowledge B.
Common knowledge A is the parameters of an already-trained network (the feature extractor); A and B are then packed together into a new network A + B, as sketched below.
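As a quick illustration, here is a minimal sketch (assuming torchvision's pretrained resnet18 and 224x224 inputs, as used in the script below) verifying that the backbone obtained by dropping the last child module outputs a [b, 512, 1, 1] feature map, which Flatten() then turns into [b, 512]:

import torch
from torch import nn
from torchvision.models import resnet18

trained_model = resnet18(pretrained=True)                       # knowledge A: ImageNet-pretrained weights
backbone = nn.Sequential(*list(trained_model.children())[:-1])  # drop the final fc (1000-class) layer

x = torch.randn(2, 3, 224, 224)   # a fake batch: [b, 3, 224, 224]
print(backbone(x).shape)          # torch.Size([2, 512, 1, 1])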
train_transfer.py
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pokemon import Pokemon
# from resnet import ResNet18
from torchvision.models import resnet18
from utils import Flatten

batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Pokemon('dataset/pokemon', 224, mode='train')
val_db = Pokemon('dataset/pokemon', 224, mode='val')
test_db = Pokemon('dataset/pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, shuffle=True, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, shuffle=True, num_workers=2)

viz = visdom.Visdom()


def evalute(model, loader):
    # classification accuracy over a whole dataset
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():
    # model = ResNet18(5).to(device)
    train_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(train_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),                           # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()  # takes raw logits, not probabilities

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:
            # validate once per epoch and keep the best checkpoint
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
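Note that the script assumes a running Visdom server: start it in another terminal with python -m visdom.server before launching train_transfer.py, and the loss and val_acc curves will be drawn in the browser (by default at http://localhost:8097).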
utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn


class Flatten(nn.Module):
    # flatten all dimensions except the batch dimension
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


def plot_image(img, label, name):
    # plot the first 6 images of a batch together with their labels
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
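A quick sanity check of Flatten (a small sketch, assuming the utils.py above is on the import path):

import torch
from utils import Flatten

x = torch.randn(2, 512, 1, 1)   # same shape as the backbone output
print(Flatten()(x).shape)       # torch.Size([2, 512])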
As you can see, training our own ResNet18 from scratch previously reached an accuracy of around 0.8+, whereas with the pretrained model we now reach 0.94.
So transfer learning works very well here.