Pytorch 05-进阶训练技巧

5.1 自定义损失函数

PyTorch在torch.nn模块为我们提供了许多常用的损失函数,比如:MSELoss,L1Loss,BCELoss… 但是随着深度学习的发展,出现了越来越多的非官方提供的Loss,比如DiceLoss,HuberLoss,SobolevLoss… 这些Loss Function专门针对一些非通用的模型,PyTorch不能将他们全部添加到库中去,因此这些损失函数的实现则需要我们通过自定义损失函数来实现。另外,在科学研究中,我们往往会提出全新的损失函数来提升模型的表现,这时我们既无法使用PyTorch自带的损失函数,也没有相关的博客供参考,此时自己实现损失函数就显得更为重要了。

经过本节的学习,你将收获:

  • 掌握如何自定义损失函数

5.1.1 以函数方式定义

事实上,损失函数仅仅是一个函数而已,因此我们可以通过直接以函数定义的方式定义一个自己的函数,如下所示:

def my_loss(output, target):
    loss = torch.mean((output - target)**2)
    return loss

5.1.2 以类方式定义

虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss函数部分继承自_loss, 部分继承自_WeightedLoss, 而_WeightedLoss继承自_loss _loss继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类,在下面的例子中我们以DiceLoss为例向大家讲述。

Dice Loss是一种在分割领域常见的损失函数,定义如下:

D S C = 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ DSC = \frac{2|X∩Y|}{|X|+|Y|} DSC=X+Y2∣XY
实现代码如下:

import torch.nn as nn
class DiceLoss(nn.Module):
    def __init__(self,weight=None,size_average=True):
        super(DiceLoss,self).__init__()
        
    def forward(self,inputs,targets,smooth=1):
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()                   
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        return 1 - dice
# 使用方法    
criterion = DiceLoss()
loss = criterion(input,targets)

除此之外,常见的损失函数还有BCE-Dice Loss,Jaccard/Intersection over Union (IoU) Loss,Focal Loss…

class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()                     
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        
        return Dice_BCE
--------------------------------------------------------------------
    
class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection 
        
        IoU = (intersection + smooth)/(union + smooth)
                
        return 1 - IoU
--------------------------------------------------------------------
    
ALPHA = 0.8
GAMMA = 2

class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()

    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1-BCE_EXP)**gamma * BCE
                       
        return focal_loss

注:

在自定义损失函数时,涉及到数学运算时,我们最好全程使用PyTorch提供的张量计算接口,这样就不需要我们实现自动求导功能并且我们可以直接调用cuda,使用numpy或者scipy的数学运算时,操作会有些麻烦,大家可以自己下去进行探索。关于PyTorch使用Class定义损失函数的原因,可以参考PyTorch的讨论区(https://discuss.pytorch.org/t/should-i-define-my-custom-loss-function-as-a-class/89468 )

5.2 动态调整学习率

学习率的选择是深度学习中一个困扰人们许久的问题,学习速率设置过小,会极大降低收敛速度,增加训练时间;学习率太大,可能导致参数在最优解两侧来回振荡。但是当我们选定了一个合适的学习率后,经过许多轮的训练后,可能会出现准确率震荡或loss不再下降等情况,说明当前学习率已不能满足模型调优的需求。此时我们就可以通过一个适当的学习率衰减策略来改善这种现象,提高我们的精度。这种设置方式在PyTorch中被称为scheduler,也是我们本节所研究的对象。

经过本节的学习,你将收获:

  • 如何根据需要选取已有的学习率调整策略
  • 如何自定义设置学习调整策略并实现

5.2.1 使用官方scheduler

  • 了解官方提供的API

在训练神经网络的过程中,学习率是最重要的超参数之一,作为当前较为流行的深度学习框架,PyTorch已经在torch.optim.lr_scheduler为我们封装好了一些动态调整学习率的方法供我们使用,如下面列出的这些scheduler。

  • lr_scheduler.LambdaLR
  • lr_scheduler.MultiplicativeLR
  • lr_scheduler.StepLR
  • lr_scheduler.MultiStepLR
  • lr_scheduler.ExponentialLR
  • lr_scheduler.CosineAnnealingLR
  • lr_scheduler.ReduceLROnPlateau
  • lr_scheduler.CyclicLR
  • lr_scheduler.OneCycleLR
  • lr_scheduler.CosineAnnealingWarmRestarts
  • 使用官方API

关于如何使用这些动态调整学习率的策略,PyTorch官方也很人性化的给出了使用实例代码帮助大家理解,我们也将结合官方给出的代码来进行解释。

# 选择一种优化器
optimizer = torch.optim.Adam(...) 
# 选择上面提到的一种或多种动态调整学习率的方法
scheduler1 = torch.optim.lr_scheduler.... 
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # 需要在优化器参数更新之后再动态调整学习率
	scheduler1.step() 
	...
    schedulern.step()

我们在使用官方给出的torch.optim.lr_scheduler时,需要将scheduler.step()放在optimizer.step()后面进行使用。

5.2.2 自定义scheduler

虽然PyTorch官方给我们提供了许多的API,但是在实验中也有可能碰到需要我们自己定义学习率调整策略的情况,而我们的方法是自定义函数adjust_learning_rate来改变param_grouplr的值,在下面的叙述中会给出一个简单的实现。

假设我们现在正在做实验,需要学习率每30轮下降为原来的1/10,假设已有的官方API中没有符合我们需求的,那就需要自定义函数来实现学习率的改变。

def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

有了adjust_learning_rate函数的定义,在训练的过程就可以调用我们的函数来实现学习率的动态变化

def adjust_learning_rate(optimizer,...):
    ...
optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
    train(...)
    validate(...)
    adjust_learning_rate(optimizer,epoch)

5.3 模型微调 - timm

除了使用torchvision.models进行预训练以外,还有一个常见的预训练模型库,叫做timm,这个库是由来自加拿大温哥华Ross Wightman创建的。里面提供了许多计算机视觉的SOTA模型,可以当作是torchvision的扩充版本,并且里面的模型在准确度上也较高。在本章内容中,我们主要是针对这个库的预训练模型的使用做叙述,其他部分内容(数据扩增,优化器等)如果大家感兴趣,可以参考以下两个链接。

  • Github链接:https://github.com/rwightman/pytorch-image-models
  • 官网链接:https://fastai.github.io/timmdocs/
    https://rwightman.github.io/pytorch-image-models/

timm的安装

关于timm的安装,我们可以选择以下两种方式进行:

  1. 通过pip安装
pip install timm
  1. 通过git与pip进行安装
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

如何查看预训练模型种类

  1. 查看timm提供的预训练模型
    截止到2022.3.27日为止,timm提供的预训练模型已经达到了592个,我们可以通过timm.list_models()方法查看timm提供的预训练模型(注:本章测试代码均是在jupyter notebook上进行)
import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
716
  1. 查看特定模型的所有种类
    每一种系列可能对应着不同方案的模型,比如Resnet系列就包括了ResNet18,50,101等模型,我们可以在timm.list_models()传入想查询的模型名称(模糊查询),比如我们想查询densenet系列的所有模型。
all_densnet_models = timm.list_models("*densenet*")
all_densnet_models
['densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'tv_densenet121']

我们发现以列表的形式返回了所有densenet系列的所有模型。

  1. 查看模型的具体参数
    当我们想查看下模型的具体参数的时候,我们可以通过访问模型的default_cfg属性来进行查看,具体操作如下
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
model.default_cfg
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to C:\Users\Administrator/.cache\torch\hub\checkpoints\resnet34-43635321.pth
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bilinear',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1',
 'classifier': 'fc',
 'architecture': 'resnet34'}

除此之外,我们可以通过访问这个链接 查看提供的预训练模型的准确度等信息。

使用和修改预训练模型

在得到我们想要使用的预训练模型后,我们可以通过timm.create_model()的方法来进行模型的创建,我们可以通过传入参数pretrained=True,来使用预训练模型。同样的,我们也可以使用跟torchvision里面的模型一样的方法查看模型的参数,类型/

import timm
import torch
model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 1000])
  • 查看某一层模型参数(以第一层卷积为例)
model = timm.create_model('resnet34',pretrained=True)
list(dict(model.named_children())['conv1'].parameters())
[Parameter containing:
 tensor([[[[-2.9398e-02, -3.6421e-02, -2.8832e-02,  ..., -1.8349e-02,
            -6.9210e-03,  1.2127e-02],
           [-3.6199e-02, -6.0810e-02, -5.3891e-02,  ..., -4.2744e-02,
            -7.3169e-03, -1.1834e-02],
           [-5.5227e-02, -7.2939e-02, -7.2351e-02,  ..., -7.1478e-02,
            -3.4954e-02, -4.0736e-02],
           ...,
           [-2.3535e-02, -5.1401e-02, -6.0092e-02,  ..., -7.2196e-02,
            -4.7019e-02, -2.5696e-02],
           [ 7.5010e-04, -1.6182e-02, -2.8829e-02,  ..., -3.6269e-02,
             1.1061e-02,  1.2223e-02],
           [ 8.1931e-03, -2.2972e-02, -1.6335e-02,  ..., -3.1251e-02,
             1.0593e-02, -6.2252e-03]],
 
          [[ 1.5982e-02,  3.5244e-03,  3.9051e-03,  ...,  2.4932e-02,
             2.8518e-02,  2.5999e-02],
           [ 5.7787e-03, -1.7177e-02, -1.6884e-02,  ...,  4.3224e-04,
             3.5190e-02,  3.2971e-02],
           [-1.1372e-02, -3.4875e-02, -4.0418e-02,  ..., -3.1364e-02,
            -3.9620e-03,  3.5324e-03],
           ...,
           [ 1.8378e-02, -1.2310e-02, -2.3896e-02,  ..., -3.4648e-02,
            -5.9046e-03,  9.5815e-03],
           [ 3.3632e-02,  1.8952e-02, -3.9773e-03,  ..., -8.6194e-03,
             4.1143e-02,  3.8466e-02],
           [ 2.8176e-02,  1.4781e-02,  2.0247e-03,  ..., -5.2044e-03,
             1.5965e-02,  2.7605e-02]],
 
          [[ 5.9068e-03, -2.8958e-03,  1.7977e-03,  ...,  3.4210e-02,
             2.4736e-02,  3.4174e-02],
           [ 6.4512e-03, -1.8409e-02, -1.9689e-02,  ...,  5.4515e-03,
             2.0301e-02,  2.1734e-02],
           [-1.3722e-02, -3.8365e-02, -4.8514e-02,  ..., -3.8446e-02,
            -2.2309e-02, -1.9627e-02],
           ...,
           [ 2.4020e-02, -7.6510e-03, -2.7373e-02,  ..., -3.1334e-02,
            -2.3510e-02,  9.8293e-03],
           [ 3.6783e-02,  2.5813e-02, -7.5813e-04,  ..., -8.0455e-03,
             2.6314e-02,  3.6545e-02],
           [ 4.2100e-02,  1.3871e-02,  9.2876e-03,  ...,  2.4132e-03,
             2.3073e-02,  2.0096e-02]]],


[[[ 1.2699e-02, 7.8746e-02, 9.2036e-02, …, 6.0966e-02,
3.8728e-02, 1.9452e-02],
[ 6.9079e-02, 1.0348e-01, 1.2755e-01, …, 8.5614e-02,
6.7287e-02, 4.6632e-02],
[ 3.1290e-02, 4.5216e-02, 7.1451e-02, …, 6.3798e-02,
3.3150e-02, 1.2368e-02],
…,
[-4.6626e-02, -1.0566e-01, -1.0894e-01, …, -1.1296e-01,
-1.0397e-01, -1.0445e-01],
[-2.4946e-02, -1.1115e-01, -1.2051e-01, …, -1.0522e-01,
-7.1987e-02, -5.7973e-02],
[-3.6680e-02, -1.5061e-01, -1.4171e-01, …, -1.4660e-01,
-1.0837e-01, -6.9888e-02]],

          [[ 2.9085e-02,  1.0295e-01,  1.2334e-01,  ...,  9.6272e-02,
             8.8013e-02,  3.7989e-02],
           [ 7.4002e-02,  1.1823e-01,  1.4064e-01,  ...,  1.0589e-01,
             8.2465e-02,  4.4503e-02],
           [ 2.0373e-02,  4.3695e-02,  6.0397e-02,  ...,  6.2300e-02,
             3.5771e-02,  7.4404e-03],
           ...,
           [-7.1929e-02, -1.6460e-01, -1.6390e-01,  ..., -1.8566e-01,
            -1.4208e-01, -1.3558e-01],
           [-6.4060e-02, -1.5390e-01, -1.8820e-01,  ..., -1.6545e-01,
            -1.1246e-01, -8.1300e-02],
           [-1.8824e-02, -1.5407e-01, -1.5888e-01,  ..., -1.5950e-01,
            -1.2171e-01, -6.0570e-02]],
 
          [[-3.6927e-03,  4.2394e-02,  4.3013e-02,  ...,  5.4485e-02,
             3.4540e-02,  1.5897e-02],
           [ 6.9949e-02,  8.1008e-02,  9.5157e-02,  ...,  6.6470e-02,
             5.0810e-02,  3.4765e-02],
           [ 4.4855e-02,  2.6341e-02,  5.7328e-02,  ...,  3.4239e-02,
             2.3280e-02,  1.2214e-02],
           ...,
           [-3.6780e-02, -9.1447e-02, -1.0104e-01,  ..., -1.3580e-01,
            -1.0961e-01, -8.6838e-02],
           [-3.3737e-03, -9.1030e-02, -1.1993e-01,  ..., -8.7861e-02,
            -5.6091e-02, -8.7438e-03],
           [ 5.9034e-03, -9.0646e-02, -9.0353e-02,  ..., -9.2046e-02,
            -4.5530e-02,  2.1971e-03]]],


[[[-3.0788e-02, 3.7539e-02, 1.3240e-02, …, -6.9663e-02,
-2.4128e-02, -2.2348e-02],
[ 3.3298e-02, 7.0664e-02, 4.1210e-02, …, -8.3079e-02,
-9.5852e-02, -1.2252e-01],
[ 3.5788e-02, 8.9131e-02, 9.1836e-02, …, -9.1667e-02,
-1.1329e-01, -1.4652e-01],
…,
[ 2.3059e-02, 1.0482e-01, 7.3017e-02, …, -7.1013e-02,
-1.1391e-01, -1.3957e-01],
[ 1.0296e-02, 7.8055e-02, 5.8967e-02, …, -5.2540e-02,
-9.0790e-02, -1.1660e-01],
[-1.0729e-02, 5.2049e-02, 2.4772e-02, …, -4.6772e-02,
-6.7351e-02, -1.0711e-01]],

          [[-6.2012e-03,  5.1805e-02,  3.8967e-02,  ..., -1.1435e-01,
            -8.3199e-02, -5.3895e-02],
           [ 5.9550e-02,  1.0831e-01,  5.8051e-02,  ..., -1.4360e-01,
            -1.6896e-01, -1.6967e-01],
           [ 9.2306e-02,  1.4292e-01,  1.2316e-01,  ..., -1.4846e-01,
            -1.8625e-01, -1.9735e-01],
           ...,
           [ 8.1894e-02,  1.6054e-01,  1.2899e-01,  ..., -1.2706e-01,
            -1.7632e-01, -1.8648e-01],
           [ 4.7747e-02,  1.1020e-01,  7.7540e-02,  ..., -9.0749e-02,
            -1.5322e-01, -1.6463e-01],
           [ 2.6540e-02,  6.9399e-02,  5.2547e-02,  ..., -9.7936e-02,
            -1.2632e-01, -1.6084e-01]],
 
          [[-3.0269e-02,  3.2487e-02,  9.7011e-03,  ..., -7.3868e-02,
            -2.9728e-02, -5.6681e-03],
           [ 2.6645e-02,  6.3480e-02,  2.8077e-02,  ..., -1.0677e-01,
            -1.0693e-01, -1.0135e-01],
           [ 5.3862e-02,  1.0905e-01,  8.7459e-02,  ..., -1.2461e-01,
            -1.3154e-01, -1.2584e-01],
           ...,
           [ 4.9134e-02,  1.4032e-01,  9.9532e-02,  ..., -9.8972e-02,
            -1.2388e-01, -1.0447e-01],
           [ 2.0668e-02,  9.0405e-02,  5.4750e-02,  ..., -8.3718e-02,
            -8.8763e-02, -9.4110e-02],
           [ 1.7352e-03,  6.6561e-02,  4.5325e-02,  ..., -7.7592e-02,
            -6.7677e-02, -7.7212e-02]]],


…,


[[[-1.7410e-08, -2.3751e-08, -3.4891e-08, …, -3.7248e-08,
-2.1249e-08, -1.2071e-08],
[ 9.6440e-09, 1.1014e-09, 9.4535e-10, …, -1.0436e-08,
3.8801e-09, 1.2071e-08],
[ 3.2439e-09, -7.0286e-09, -2.1053e-10, …, -1.1040e-08,
-2.0537e-09, 1.5091e-08],
…,
[ 1.7211e-08, -3.7678e-09, -4.5318e-09, …, -1.8988e-08,
-3.6635e-09, 1.6287e-09],
[ 1.3862e-08, 4.3777e-09, -5.2030e-09, …, -1.1054e-08,
-3.7851e-09, 2.0465e-09],
[-2.4859e-09, 4.7488e-09, -7.5893e-12, …, -7.3165e-09,
-1.2239e-09, 8.7856e-09]],

          [[-5.6945e-08, -7.4338e-08, -8.0817e-08,  ..., -8.6512e-08,
            -8.0827e-08, -6.8603e-08],
           [-2.3565e-08, -4.2844e-08, -4.9689e-08,  ..., -5.7269e-08,
            -4.5902e-08, -2.6869e-08],
           [-2.3128e-08, -4.3358e-08, -5.2465e-08,  ..., -5.6927e-08,
            -5.1460e-08, -2.8609e-08],
           ...,
           [-2.2332e-08, -3.5878e-08, -4.2284e-08,  ..., -6.2352e-08,
            -5.2872e-08, -3.7298e-08],
           [-3.3191e-08, -3.2705e-08, -4.5063e-08,  ..., -5.3950e-08,
            -4.8288e-08, -3.6929e-08],
           [-3.7008e-08, -2.9574e-08, -3.4617e-08,  ..., -4.3226e-08,
            -3.4198e-08, -2.3976e-08]],
 
          [[-3.6700e-08, -5.3476e-08, -5.3741e-08,  ..., -7.8672e-08,
            -7.8789e-08, -6.6036e-08],
           [ 6.6692e-10, -5.9142e-09, -1.4588e-08,  ..., -2.0856e-08,
            -2.1511e-08, -1.0350e-08],
           [ 9.3269e-10, -1.2033e-08, -2.0487e-08,  ..., -2.0307e-08,
            -2.7283e-08, -1.2541e-08],
           ...,
           [-2.5785e-09, -4.0823e-09, -7.0379e-09,  ..., -3.1335e-08,
            -3.0376e-08, -1.8040e-08],
           [-2.0965e-08, -2.6964e-09, -1.5237e-08,  ..., -3.6272e-08,
            -3.9100e-08, -2.7551e-08],
           [-4.4247e-08, -1.6797e-08, -1.1698e-08,  ..., -2.8883e-08,
            -3.0110e-08, -2.3072e-08]]],


[[[ 2.4550e-02, -6.4176e-02, -6.2492e-02, …, -3.1555e-02,
2.1815e-02, 3.6154e-02],
[-4.2626e-03, -1.1115e-01, -1.1057e-01, …, -1.6824e-02,
7.1532e-02, 1.1135e-01],
[-9.8472e-03, -1.3956e-01, -1.2700e-01, …, 2.4138e-02,
9.2866e-02, 1.2774e-01],
…,
[-2.2453e-02, -1.2908e-01, -1.4458e-01, …, 2.4338e-02,
1.0059e-01, 1.3346e-01],
[-1.1554e-02, -1.1118e-01, -1.2784e-01, …, 2.9475e-02,
7.8353e-02, 9.2674e-02],
[-7.4159e-03, -7.9511e-02, -8.4516e-02, …, -6.8034e-03,
5.7650e-02, 5.3259e-02]],

          [[ 2.2340e-02, -6.0133e-02, -7.7689e-02,  ..., -3.1634e-02,
             2.8327e-02,  5.5094e-02],
           [-3.1156e-02, -1.2689e-01, -1.5196e-01,  ..., -2.0853e-02,
             7.9520e-02,  1.4480e-01],
           [-3.0083e-02, -1.3677e-01, -1.6207e-01,  ..., -3.5956e-03,
             1.1898e-01,  1.7122e-01],
           ...,
           [-9.5634e-05, -1.2835e-01, -1.7461e-01,  ...,  9.7906e-03,
             1.2607e-01,  1.7413e-01],
           [ 7.9960e-03, -1.1012e-01, -1.5902e-01,  ..., -1.8531e-02,
             7.6518e-02,  1.1004e-01],
           [ 1.0051e-02, -6.7243e-02, -1.0776e-01,  ..., -3.3451e-02,
             5.4970e-02,  8.3284e-02]],
 
          [[ 2.2734e-02, -4.2145e-03, -3.2995e-02,  ..., -3.5991e-02,
            -8.9189e-03,  3.8408e-03],
           [-7.3349e-03, -5.0639e-02, -7.0251e-02,  ..., -2.6339e-02,
             4.2522e-02,  7.6488e-02],
           [-1.2707e-02, -5.3588e-02, -8.8918e-02,  ..., -6.1876e-03,
             7.3334e-02,  1.2001e-01],
           ...,
           [ 6.4172e-03, -4.2672e-02, -1.0888e-01,  ..., -1.9132e-02,
             7.1622e-02,  1.1532e-01],
           [ 1.7066e-02, -2.3699e-02, -9.3930e-02,  ..., -3.5009e-02,
             3.8874e-02,  7.4976e-02],
           [ 3.4658e-02,  1.7698e-02, -5.0406e-02,  ..., -3.6199e-02,
             1.4155e-02,  3.7391e-02]]],


[[[ 1.8113e-02, 2.6494e-02, 3.4953e-02, …, -8.3096e-02,
-4.9551e-02, 3.1324e-02],
[ 3.6661e-02, 2.6457e-02, -1.5455e-02, …, -1.7165e-01,
-7.2320e-02, 6.3055e-02],
[ 2.8326e-02, 3.5265e-03, -2.2922e-02, …, -1.0111e-01,
-5.7018e-03, 8.2738e-02],
…,
[-3.0252e-02, -1.5725e-02, 1.0980e-02, …, 3.8780e-02,
2.3260e-02, -1.2615e-02],
[-2.5042e-02, -1.2567e-03, 3.2277e-03, …, 6.9876e-02,
1.5599e-02, -3.9949e-02],
[-2.4860e-02, -6.1092e-03, 1.4324e-02, …, 6.7449e-02,
1.7068e-03, -5.1163e-02]],

          [[ 1.8835e-02,  4.1821e-02,  5.6398e-02,  ..., -1.2320e-01,
            -8.1981e-02,  3.9670e-02],
           [ 5.9248e-02,  5.0029e-02, -2.1544e-02,  ..., -2.3480e-01,
            -1.1931e-01,  5.1401e-02],
           [ 4.0702e-02,  8.1622e-03, -3.7389e-02,  ..., -1.7560e-01,
            -4.1825e-02,  6.6342e-02],
           ...,
           [-1.9292e-02, -2.2177e-02, -3.1122e-04,  ...,  2.5202e-02,
             2.3837e-02, -7.8956e-03],
           [-1.7456e-02, -9.9589e-03,  9.8308e-03,  ...,  6.4597e-02,
             1.6308e-02, -1.5468e-02],
           [-2.4623e-02, -1.2885e-02,  2.2207e-02,  ...,  5.5527e-02,
             1.7614e-02, -3.1443e-02]],
 
          [[ 1.2254e-02,  5.2811e-02,  7.3135e-02,  ..., -4.8397e-02,
            -3.6057e-02,  5.1934e-02],
           [ 8.0631e-02,  4.3388e-02,  1.1057e-02,  ..., -1.3908e-01,
            -7.4715e-02,  4.1007e-02],
           [ 5.7951e-02,  1.8942e-02, -1.8015e-02,  ..., -1.2828e-01,
            -4.6458e-02,  4.1263e-02],
           ...,
           [-1.4472e-02, -1.7783e-02, -5.0794e-03,  ..., -6.7681e-03,
            -2.8409e-03, -1.0941e-02],
           [-1.4044e-02, -1.3162e-02, -4.8956e-03,  ...,  4.3570e-02,
             1.5312e-02, -1.2432e-02],
           [ 8.4563e-03, -1.7099e-02, -1.2176e-03,  ...,  7.0081e-02,
             2.9756e-02, -4.1400e-03]]]], requires_grad=True)]
  • 修改模型(将1000类改为10类输出)
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 10])
  • 改变输入通道数(比如我们传入的图片是单通道的,但是模型需要的是三通道图片)
    我们可以通过添加in_chans=1来改变
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
x = torch.randn(1,1,224,224)
output = model(x)

模型的保存

timm库所创建的模型是torch.model的子类,我们可以直接使用torch库中内置的模型参数保存和加载的方法,具体操作如下方代码所示

torch.save(model.state_dict(),'./checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))

5.3 模型微调-torchvision

随着深度学习的发展,模型的参数越来越大,许多开源模型都是在较大数据集上进行训练的,比如Imagenet-1k,Imagenet-11k,甚至是ImageNet-21k等。但在实际应用中,我们的数据集可能只有几千张,这时从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。

假设我们想从图像中识别出不同种类的椅⼦,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1000张不同⻆度的图像,然后在收集到的图像数据集上训练一个分类模型。这个椅子数据集虽然可能比Fashion-MNIST数据集要庞⼤,但样本数仍然不及ImageNet数据集中样本数的十分之⼀。这可能会导致适用于ImageNet数据集的复杂模型在这个椅⼦数据集上过拟合。同时,因为数据量有限,最终训练得到的模型的精度也可能达不到实用的要求。

为了应对上述问题,一个显⽽易⻅的解决办法是收集更多的数据。然而,收集和标注数据会花费大量的时间和资⾦。例如,为了收集ImageNet数据集,研究人员花费了数百万美元的研究经费。虽然目前的数据采集成本已降低了不少,但其成本仍然不可忽略。

另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

迁移学习的一大应用场景是模型微调(finetune)。简单来说,就是我们先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,通过训练调整一下参数。 在PyTorch中提供了许多预训练好的网络模型(VGG,ResNet系列,mobilenet系列…),这些模型都是PyTorch官方在相应的大型数据集训练好的。学习如何进行模型微调,可以方便我们快速使用预训练模型完成自己的任务。

经过本节的学习,你将收获:

  • 掌握模型微调的流程
  • 了解PyTorch提供的常用model
  • 掌握如何指定训练模型的部分层

5.3.1 模型微调的流程

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

5.3.2 使用已有模型结构

这里我们以torchvision中的常见模型为例,列出了如何在图像分类任务中使用PyTorch提供的常见模型结构和参数。对于其他任务和网络结构,使用方式是类似的:

  • 实例化网络
import torchvision.models as models
resnet18 = models.resnet18()
# resnet18 = models.resnet18(pretrained=False)  等价于与上面的表达式
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
D:\anaconda\lib\site-packages\torchvision\models\inception.py:43: FutureWarning: The default weight initialization of inception_v3 will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
  warnings.warn(
D:\anaconda\lib\site-packages\torchvision\models\googlenet.py:47: FutureWarning: The default weight initialization of GoogleNet will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
  warnings.warn(
  • 传递pretrained参数

通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False,意味着我们不使用预训练得到的权重,当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)

注意事项:

  1. 通常PyTorch模型的扩展为.pt.pth,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。

  2. 一般情况下预训练模型的下载会比较慢,我们可以直接通过迅雷或者其他方式去 这里 查看自己的模型里面model_urls,然后手动下载,预训练模型的权重在LinuxMac的默认下载路径是用户根目录下的.cache文件夹。在Windows下就是C:\Users\\.cache\torch\hub\checkpoint。我们可以通过使用 torch.utils.model_zoo.load_url()设置权重的下载地址。

  3. 如果觉得麻烦,还可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
  1. 如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。

5.3.3 训练特定层

在默认情况下,参数的属性.requires_grad = True,如果我们从头开始训练或微调不需要注意这里。但如果我们正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变。那我们就需要通过设置requires_grad = False来冻结部分层。在PyTorch官方中提供了这样一个例程。

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

在下面我们仍旧使用resnet18为例的将1000类改为4类,但是仅改变最后一层的模型参数,不改变特征提取的模型参数;注意我们先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改,这样修改后的全连接层的参数就是可计算梯度的。

import torchvision.models as models
# 冻结参数的梯度
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# 修改模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)

之后在训练过程中,model仍会进行梯度回传,但是参数更新则只会发生在fc层。通过设定参数的requires_grad属性,我们完成了指定训练模型的特定层的目标,这对实现模型微调非常重要。

5.4 半精度训练

我们提到PyTorch时候,总会想到要用硬件设备GPU的支持,也就是“卡”。GPU的性能主要分为两部分:算力和显存,前者决定了显卡计算的速度,后者则决定了显卡可以同时放入多少数据用于计算。在可以使用的显存数量一定的情况下,每次训练能够加载的数据更多(也就是batch size更大),则也可以提高训练效率。另外,有时候数据本身也比较大(比如3D图像、视频等),显存较小的情况下可能甚至batch size为1的情况都无法实现。因此,合理使用显存也就显得十分重要。

我们观察PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性,但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度”:

显然半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算。本节会介绍如何在PyTorch中设置使用半精度计算。

经过本节的学习,你将收获:

  • 如何在PyTorch中设置半精度训练
  • 使用半精度训练的注意事项

5.4.1 半精度训练的设置

在PyTorch中使用autocast配置半精度训练,同时需要在下面三处加以设置:

  • import autocast
from torch.cuda.amp import autocast
  • 模型设置

在模型定义中,使用python的装饰器方法,用autocast装饰模型中的forward函数。关于装饰器的使用,可以参考这里:

@autocast()   
def forward(self, x):
    ...
    return x
  • 训练过程

在训练过程中,只需在将数据输入模型及其之后的部分放入“with autocast():“即可:

 for x in train_loader:
	x = x.cuda()
	with autocast():
        output = model(x)
        ...

注意:

半精度训练主要适用于数据本身的size比较大(比如说3D图像、视频等)。当数据本身的size并不大时(比如手写数字MNIST数据集的图片尺寸只有28*28),使用半精度训练则可能不会带来显著的提升。

5.5 数据增强-imgaug

深度学习最重要的是数据。我们需要大量数据才能避免模型的过度拟合。但是我们在许多场景无法获得大量数据,例如医学图像分析。数据增强技术的存在是为了解决这个问题,这是针对有限数据问题的解决方案。数据增强一套技术,可提高训练数据集的大小和质量,以便我们可以使用它们来构建更好的深度学习模型。
在计算视觉领域,生成增强图像相对容易。即使引入噪声或裁剪图像的一部分,模型仍可以对图像进行分类,数据增强有一系列简单有效的方法可供选择,有一些机器学习库来进行计算视觉领域的数据增强,比如:imgaug 官网它封装了很多数据增强算法,给开发者提供了方便。通过本章内容,您将学会以下内容:

  • imgaug的简介和安装
  • 使用imgaug对数据进行增强

5.5.1 imgaug简介和安装

5.5.1.1 imgaug简介

imgaug是计算机视觉任务中常用的一个数据增强的包,相比于torchvision.transforms,它提供了更多的数据增强方法,因此在各种竞赛中,人们广泛使用imgaug来对数据进行增强操作。除此之外,imgaug官方还提供了许多例程让我们学习,本章内容仅是简介,希望起到抛砖引玉的功能。

  1. Github地址:imgaug
  2. Readthedocs:imgaug
  3. 官方提供notebook例程:notebook

5.5.1.2 imgaug的安装

imgaug的安装方法和其他的Python包类似,我们可以通过以下两种方式进行安装

conda

conda config --add channels conda-forge
conda install imgaug

pip

```shell
# install imgaug either via pypi

pip install imgaug

# install the latest version directly from github

pip install git+https://github.com/aleju/imgaug.git
```

5.5.2 imgaug的使用

imgaug仅仅提供了图像增强的一些方法,但是并未提供图像的IO操作,因此我们需要使用一些库来对图像进行导入,建议使用imageio进行读入,如果使用的是opencv进行文件读取的时候,需要进行手动改变通道,将读取的BGR图像转换为RGB图像。除此以外,当我们用PIL.Image进行读取时,因为读取的图片没有shape的属性,所以我们需要将读取到的img转换为np.array()的形式再进行处理。因此官方的例程中也是使用imageio进行图片读取。

单张图片处理

m在该单元,我们仅以几种数据增强操作为例,主要目的是教会大家如何使用imgaug来对数据进行增强操作。

import imageio
import imgaug as ia
%matplotlib inline
# 图片的读取
img = imageio.imread("./Lenna.jpg")

# 使用Image进行读取
# img = Image.open("./Lenna.jpg")
# image = np.array(img)
# ia.imshow(image)
# 可视化图片
ia.imshow(img)

现在我们已经得到了需要处理的图片,imgaug包含了许多从Augmenter继承的数据增强的操作。在这里我们以Affine为例子。

from imgaug import augmenters as iaa

# 设置随机数种子
ia.seed(4)

# 实例化方法
rotate = iaa.Affine(rotate=(-4,45))
img_aug = rotate(image=img)
ia.imshow(img_aug)

这是对一张图片进行一种操作方式,但实际情况下,我们可能对一张图片做多种数据增强处理。这种情况下,我们就需要利用imgaug.augmenters.Sequential()来构造我们数据增强的pipline,该方法与torchvison.transforms.Compose()相类似。

iaa.Sequential(children=None, # Augmenter集合
               random_order=False, # 是否对每个batch使用不同顺序的Augmenter list
               name=None,
               deterministic=False,
               random_state=None)
# 构建处理序列
aug_seq = iaa.Sequential([
    iaa.Affine(rotate=(-25,25)),
    iaa.AdditiveGaussianNoise(scale=(10,60)),
    iaa.Crop(percent=(0,0.2))
])
# 对图片进行处理,image不可以省略,也不能写成images
image_aug = aug_seq(image=img)
ia.imshow(image_aug)

总的来说,对单张图片处理的方式基本相同,我们可以根据实际需求,选择合适的数据增强方法来对数据进行处理。

对批次图片进行处理

在实际使用中,我们通常需要处理更多份的图像数据。此时,可以将图形数据按照NHWC的形式或者由列表组成的HWC的形式对批量的图像进行处理。主要分为以下两部分,对批次的图片以同一种方式处理和对批次的图片进行分部分处理。

对批次的图片以同一种方式处理

对一批次的图片进行处理时,我们只需要将待处理的图片放在一个list中,并将image改为image即可进行数据增强操作,具体实际操作如下:

images = [img,img,img,img,]
images_aug = rotate(images=images)
ia.imshow(np.hstack(images_aug))

我们就可以得到如下的展示效果:

在上述的例子中,我们仅仅对图片进行了仿射变换,同样的,我们也可以对批次的图片使用多种增强方法,与单张图片的方法类似,我们同样需要借助Sequential来构造数据增强的pipline。

aug_seq = iaa.Sequential([
    iaa.Affine(rotate=(-25, 25)),
    iaa.AdditiveGaussianNoise(scale=(10, 60)),
    iaa.Crop(percent=(0, 0.2))
])
# 传入时需要指明是images参数
images_aug = aug_seq.augment_images(images = images)
#images_aug = aug_seq(images = images) 
ia.imshow(np.hstack(images_aug))
#### 对批次的图片分部分处理
imgaug相较于其他的数据增强的库,有一个很有意思的特性,即就是我们可以通过`imgaug.augmenters.Sometimes()`对batch中的一部分图片应用一部分Augmenters,剩下的图片应用另外的Augmenters。
iaa.Sometimes(p=0.5,  # 代表划分比例
              then_list=None,  # Augmenter集合。p概率的图片进行变换的Augmenters。
              else_list=None,  #1-p概率的图片会被进行变换的Augmenters。注意变换的图片应用的Augmenter只能是then_list或者else_list中的一个。
              name=None,
              deterministic=False,
              random_state=None)

对不同大小的图片进行处理

上面提到的图片都是基于相同的图像。以下的示例具有不同图像大小的情况,我们从维基百科加载三张图片,将它们作为一个批次进行扩充,然后一张一张地显示每张图片。具体的操作跟单张的图片都是十分相似,因此不做过多赘述。

# 构建pipline
seq = iaa.Sequential([
    iaa.CropAndPad(percent=(-0.2, 0.2), pad_mode="edge"),  # crop and pad images
    iaa.AddToHueAndSaturation((-60, 60)),  # change their color
    iaa.ElasticTransformation(alpha=90, sigma=9),  # water-like effect
    iaa.Cutout()  # replace one squared area within the image by a constant intensity value
], random_order=True)
# 加载不同大小的图片
images_different_sizes = [
    imageio.imread("https://upload.wikimedia.org/wikipedia/commons/e/ed/BRACHYLAGUS_IDAHOENSIS.jpg"),
    imageio.imread("https://upload.wikimedia.org/wikipedia/commons/c/c9/Southern_swamp_rabbit_baby.jpg"),
    imageio.imread("https://upload.wikimedia.org/wikipedia/commons/9/9f/Lower_Keys_marsh_rabbit.jpg")
]
# 对图片进行增强
images_aug = seq(images=images_different_sizes)
# 可视化结果
print("Image 0 (input shape: %s, output shape: %s)" % (images_different_sizes[0].shape, images_aug[0].shape))
ia.imshow(np.hstack([images_different_sizes[0], images_aug[0]]))

print("Image 1 (input shape: %s, output shape: %s)" % (images_different_sizes[1].shape, images_aug[1].shape))
ia.imshow(np.hstack([images_different_sizes[1], images_aug[1]]))

print("Image 2 (input shape: %s, output shape: %s)" % (images_different_sizes[2].shape, images_aug[2].shape))
ia.imshow(np.hstack([images_different_sizes[2], images_aug[2]]))

5.5.3 imgaug在PyTorch的应用

关于PyTorch中如何使用imgaug每一个人的模板是不一样的,我在这里也仅仅给出imgaug的issue里面提出的一种解决方案,大家可以根据自己的实际需求进行改变。
具体链接:how to use imgaug with pytorch

import numpy as np
from imgaug import augmenters as iaa
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# 构建pipline
tfs = transforms.Compose([
    iaa.Sequential([
        iaa.flip.Fliplr(p=0.5),
        iaa.flip.Flipud(p=0.5),
        iaa.GaussianBlur(sigma=(0.0, 0.1)),
        iaa.MultiplyBrightness(mul=(0.65, 1.35)),
    ]).augment_image,
    # 不要忘记了使用ToTensor()
    transforms.ToTensor()
])
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self, n_images, n_classes, transform=None):
		# 图片的读取,建议使用imageio
        self.images = np.random.randint(0, 255,
                                        (n_images, 224, 224, 3),
                                        dtype=np.uint8)
        self.targets = np.random.randn(n_images, n_classes)
        self.transform = transform

    def __getitem__(self, item):
        image = self.images[item]
        target = self.targets[item]

        if self.transform:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.images)


def worker_init_fn(worker_id):
    imgaug.seed(np.random.get_state()[1][0] + worker_id)


custom_ds = CustomDataset(n_images=50, n_classes=10, transform=tfs)
custom_dl = DataLoader(custom_ds, batch_size=64,
                       num_workers=4, pin_memory=True, 
                       worker_init_fn=worker_init_fn)

关于num_workers在Windows系统上只能设置成0,但是当我们使用Linux远程服务器时,可能使用不同的num_workers的数量,这是我们就需要注意worker_init_fn()函数的作用了。它保证了我们使用的数据增强在num_workers>0时是对数据的增强是随机的。

5.5.4 总结

数据扩充是我们需要掌握的基本技能,除了imgaug以外,我们还可以去学习其他的数据增强库,包括但不局限于Albumentations,Augmentor。除去imgaug以外,我还强烈建议大家学下Albumentations,因为Albumentations跟imgaug都有着丰富的教程资源,大家可以有需求访问Albumentations教程。

5.6 使用argparse进行调参

引言

在深度学习中时,超参数的修改和保存是非常重要的一步,尤其是当我们在服务器上跑我们的模型时,如何更方便的修改超参数是我们需要考虑的一个问题。这时候,要是有一个库或者函数可以解析我们输入的命令行参数再传入模型的超参数中该多好。到底有没有这样的一种方法呢?答案是肯定的,这个就是 Python 标准库的一部分:Argparse。那么下面让我们看看他是多么方便。通过本节课,您将会收获以下内容

  • argparse的简介
  • argparse的使用
  • 如何使用argparse修改超参数

5.6.1 argparse简介

argsparse是python的命令行解析的标准模块,内置于python,不需要安装。这个库可以让我们直接在命令行中就可以向程序中传入参数。我们可以使用python file.py来运行python文件。而argparse的作用就是将命令行传入的其他参数进行解析、保存和使用。在使用argparse后,我们在命令行输入的参数就可以以这种形式python file.py --lr 1e-4 --batch_size 32来完成对常见超参数的设置。

5.6.2 argparse的使用

总的来说,我们可以将argparse的使用归纳为以下三个步骤。

  • 创建ArgumentParser()对象
  • 调用add_argument()方法添加参数
  • 使用parse_args()解析参数
    在接下来的内容中,我们将以实际操作来学习argparse的使用方法。
import argparse

# 创建ArgumentParser()对象
parser = argparse.ArgumentParser()
# 添加参数
parser.add_argument('-o', '--output', action='store_true', 
    help="shows output")
# action = `store_true` 会将output参数记录为True
# type 规定了参数的格式
# default 规定了默认值
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3') 

parser.add_argument('--batch_size', type=int, required=True, help='input batch size')  
# 使用parse_args()解析函数
args = parser.parse_args()

if args.output:
    print("This is some output")
    print(f"learning rate:{args.lr} ")

我们在命令行使用python demo.py --lr 3e-4 --batch_size 32,就可以看到以下的输出

argparse的参数主要可以分为可选参数和必选参数。可选参数就跟我们的lr参数相类似,未输入的情况下会设置为默认值。必选参数就跟我们的batch_size参数相类似,当我们给参数设置required =True后,我们就必须传入该参数,否则就会报错。看到我们的输入格式后,我们可能会有这样一个疑问,我输入参数的时候不使用–可以吗?答案是肯定的,不过我们需要在设置上做出一些改变。

import argparse

# 位置参数
parser = argparse.ArgumentParser()

parser.add_argument('name')
parser.add_argument('age')

args = parser.parse_args()

print(f'{args.name} is {args.age} years old')

当我们不实用–后,将会严格按照参数位置进行解析。

$ positional_arg.py Peter 23
Peter is 23 years old

总的来说,argparse的使用很简单,以上这些操作就可以帮助我们进行参数的修改,在下面的部分,我将会分享我是如何在模型训练中使用argparse进行超参数的修改。

5.6.3 更加高效使用argparse修改超参数

每个人都有着不同的超参数管理方式,在这里我将分享我使用argparse管理超参数的方式,希望可以对大家有一些借鉴意义。通常情况下,为了使代码更加简洁和模块化,我一般会将有关超参数的操作写在config.py,然后在train.py或者其他文件导入就可以。具体的config.py可以参考如下内容。

import argparse  
  
def get_options(parser=argparse.ArgumentParser()):  
  
    parser.add_argument('--workers', type=int, default=0,  
                        help='number of data loading workers, you had better put it '  
                              '4 times of your gpu')  
  
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')  
  
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')  
  
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')  
  
    parser.add_argument('--seed', type=int, default=118, help="random seed")  
  
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')  
    parser.add_argument('--checkpoint_path',type=str,default='',  
                        help='Path to load a previous trained model if not empty (default empty)')  
    parser.add_argument('--output',action='store_true',default=True,help="shows output")  
  
    opt = parser.parse_args()  
  
    if opt.output:  
        print(f'num_workers: {opt.workers}')  
        print(f'batch_size: {opt.batch_size}')  
        print(f'epochs (niters) : {opt.niter}')  
        print(f'learning rate : {opt.lr}')  
        print(f'manual_seed: {opt.seed}')  
        print(f'cuda enable: {opt.cuda}')  
        print(f'checkpoint_path: {opt.checkpoint_path}')  
  
    return opt  

opt = get_options()

随后在train.py等其他文件,我们就可以使用下面的这样的结构来调用参数。

# 导入必要库
import config

opt = config.get_options()

manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path

# 随机数的设置,保证复现结果
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...


if __name__ == '__main__':
	set_seed(manual_seed)
	for epoch in range(niters):
		train(model,lr,batch_size,num_workers,checkpoint_path)
		val(model,lr,batch_size,num_workers,checkpoint_path)

总结

argparse给我们提供了一种新的更加便捷的方式,在后面我们将结合其他Python标准库(pickle,json,logging)实现参数的保存和模型输出的记录。如果大家还想进一步的了解argparse的使用,大家可以点击下面提供的连接进行更深的学习和了解。

  1. Python argparse 教程
  2. argparse 官方教程

你可能感兴趣的:(Pytorch,pytorch,深度学习,人工智能)