Ptroch训练技巧精简版

Pytorch训练技巧

1.自定义损失函数

pytorch在torch.nn模块里提供了很多损失函数,如MSELoss,L1Loss等,但同时也可以自定义损失函数:

两种定义损失函数方法:

  • 以函数方式

    直接定义一个函数即可

    def my_loss(output, target):
        loss = torch.mean((output - target)**2)
        return loss
    
  • 以类方式

    损失函数继承自_Loss类,_WeightedLoss类,这两个都继承了nn.Module,故自定义类损失应该继承nn.Module。列如

    class DiceLoss(nn.Module):
        def __init__(self,weight=None,size_average=True):
            super(DiceLoss,self).__init__()
            
        def forward(self,inputs,targets,smooth=1):
            inputs = F.sigmoid(inputs)       
            inputs = inputs.view(-1)
            targets = targets.view(-1)
            intersection = (inputs * targets).sum()                   
            dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
            return 1 - dice
    
    # 使用方法    
    criterion = DiceLoss()
    loss = criterion(input,targets)
    

2.动态调整学习率

  • 可使用官方api

    # 选择一种优化器
    optimizer = torch.optim.Adam(...) 
    # 选择上面提到的一种或多种动态调整学习率的方法
    scheduler1 = torch.optim.lr_scheduler.... 
    scheduler2 = torch.optim.lr_scheduler....
    ...
    schedulern = torch.optim.lr_scheduler....
    # 进行训练
    for epoch in range(100):
        train(...)
        validate(...)
        optimizer.step()
        # 需要在优化器参数更新之后再动态调整学习率
    # scheduler的优化是在每一轮后面进行的
    scheduler1.step() 
    ...
    schedulern.step()
    
  • 也可以自定义scheduler

    自定义函数adjust_learning_rate来改变param_grouplr的值。

    def adjust_learning_rate(optimizer, epoch):
        lr = args.lr * (0.1 ** (epoch // 30))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    def adjust_learning_rate(optimizer,...):
        ...
    optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
    for epoch in range(10):
        train(...)
        validate(...)
        adjust_learning_rate(optimizer,epoch)
    

3.模型微调

通过timm.create_model()的方法来进行模型的创建,我们可以通过传入参数pretrained=True,来使用预训练模型。

import timm
import torch

model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape

4.半精度训练

将默认的单精度浮点数torch.float32改为torch.float16,可以节约运行内存

你可能感兴趣的:(深度学习,python,pytorch)