PyTorch (6): Model Tuning Tricks

  • 1. Regularization
    • 1.1 L1 Regularization
    • 1.2 L2 Regularization
  • 2. Momentum and Learning Rate Decay
    • 2.1 Momentum
    • 2.2 Learning Rate Tuning
  • 3. Early Stopping
  • 4. Dropout

1. Regularization

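1.1 L1 Regularization

L1 regularization adds λ·Σ|w| to the task loss. This is the sum of the absolute values of all parameters, scaled by a coefficient λ, and it tends to drive many weights to exactly zero.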

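import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
)
MLP.to(device)
loss_classify = nn.CrossEntropyLoss().to(device)

# Dummy batch for illustration: 16 samples with 128 features, labels in [0, 10)
x = torch.randn(16, 128, device=device)
y = torch.randint(0, 10, (16,), device=device)

# L1 norm of all parameters, weighted by a regularization coefficient
l1_lambda = 1e-4  # illustrative value; tune for your task
l1_loss = 0
for param in MLP.parameters():
    l1_loss += torch.sum(torch.abs(param))
loss = loss_classify(MLP(x), y) + l1_lambda * l1_loss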
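The snippet above stops once the loss is built. Below is a minimal sketch of one full training step with the L1 penalty. It runs on the CPU with a random dummy batch. The value `l1_lambda = 1e-4` and the helper `l1_penalty` are illustrative assumptions, not part of the original.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same architecture as above, kept on the CPU so the sketch runs anywhere
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(),
                      nn.Linear(64, 32), nn.ReLU(),
                      nn.Linear(32, 10))
criterion = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
l1_lambda = 1e-4  # assumed penalty strength


def l1_penalty(module: nn.Module) -> torch.Tensor:
    """Hypothetical helper: sum of |p| over every parameter (weights and biases)."""
    return sum(p.abs().sum() for p in module.parameters())


# Dummy batch standing in for real data: 16 samples, 10 classes
x = torch.randn(16, 128)
y = torch.randint(0, 10, (16,))

opt.zero_grad()
ce = criterion(model(x), y)
penalty = l1_penalty(model)
loss = ce + l1_lambda * penalty
loss.backward()  # the L1 term adds l1_lambda * sign(w) to each parameter's gradient
opt.step()
```

Because the penalty is part of the loss, the optimizer needs no special configuration. This is the main difference from the L2 case, which is usually handled by the optimizer.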

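1.2 L2 Regularization

L2 regularization penalizes the squared parameter values. With torch.optim.SGD it is usually applied through the weight_decay argument instead of an explicit loss term: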

import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
)
MLP.to(device)

# L2 norm: weight_decay adds weight_decay * w to each gradient,
# which is the gradient of (weight_decay / 2) * ||w||^2
opt = torch.optim.SGD(MLP.parameters(), lr=0.001, weight_decay=0.1)
criterion = nn.CrossEntropyLoss().to(device)
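To see why weight_decay counts as L2 regularization, the sketch below runs an illustrative check on a single nn.Linear layer with a random batch (both are assumptions). It takes one SGD step with weight_decay=0.1, and one plain SGD step with 0.05 * Σw² added to the loss. The resulting parameters should match up to floating-point rounding.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
wd, lr = 0.1, 0.001

model_a = nn.Linear(128, 10)
model_b = copy.deepcopy(model_a)  # identical starting weights
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 128)          # dummy batch
y = torch.randint(0, 10, (16,))

# (a) L2 through the optimizer's weight_decay argument
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
criterion(model_a(x), y).backward()
opt_a.step()

# (b) L2 written into the loss: (wd / 2) * sum of squared parameters
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
l2 = sum((p ** 2).sum() for p in model_b.parameters())
(criterion(model_b(x), y) + 0.5 * wd * l2).backward()
opt_b.step()

same = all(torch.allclose(pa, pb, atol=1e-7)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```

This equivalence holds for plain SGD. For adaptive optimizers such as Adam, the weight_decay argument is no longer identical to a decoupled weight decay, which is why torch.optim.AdamW exists.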

2. Momentum and Learning Rate Decay

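2.1 Momentum

Momentum is set through the momentum argument of torch.optim.SGD and can be combined with weight_decay:

opt = torch.optim.SGD(MLP.parameters(), lr=0.001, momentum=0.78, weight_decay=0.1)  # MLP as defined above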
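The code below sketches what the momentum argument does. It compares torch.optim.SGD against a hand-written update on a toy one-parameter loss w²/2; the loss and the starting value are assumptions. With the default dampening=0 and nesterov=False, PyTorch keeps a buffer buf = momentum * buf + grad and updates w -= lr * buf. weight_decay is left at 0 here so the comparison stays simple.

```python
import torch

lr, mu = 0.001, 0.78   # lr and momentum as in the snippet above

w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=mu)

w_manual, buf = 1.0, 0.0
for step in range(5):
    grad = w_manual                 # analytic gradient of the toy loss w**2 / 2
    opt.zero_grad()
    (w ** 2 / 2).sum().backward()   # autograd computes the same gradient for w
    opt.step()
    # Hand-written version of PyTorch's rule (dampening=0, nesterov=False)
    buf = mu * buf + grad
    w_manual -= lr * buf

print(w.item(), w_manual)  # agree up to float32 rounding
```

Because the buffer accumulates raw gradients, a gradient that keeps the same sign produces steps that grow toward lr * grad / (1 - momentum). That is how momentum speeds up progress along consistent directions.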

2.2 Learning Rate Tuning

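  • torch.optim.lr_scheduler.ReduceLROnPlateau(): lowers the learning rate when a monitored metric (such as the loss) stops improving
  • torch.optim.lr_scheduler.StepLR(): multiplies the learning rate by gamma every step_size epochs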
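# net is the model and train(...) runs one epoch and returns its loss (both placeholders)
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode="min", factor=0.1, patience=10)
for epoch in range(1000):
    loss_val = train(...)
    lr_scheduler.step(loss_val)  # monitors the loss: LR *= 0.1 after more than 10 epochs without improvement

opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)
for epoch in range(1000):
    train(...)
    lr_scheduler.step()  # ignores the loss: LR *= 0.1 every 30 epochs; call it after the epoch's optimizer steps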
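The two schedulers can be compared without any real training by driving them with a dummy optimizer. The single zero-gradient parameter and the constant loss value 1.0 below are assumptions made only for illustration. Reading the learning rate from opt.param_groups shows two behaviors. StepLR drops the rate by 10x every 30 epochs. ReduceLROnPlateau drops it once, after the loss has failed to improve for more than patience=10 epochs.

```python
import torch

def make_opt(lr=1.0):
    # One dummy parameter with a zero gradient: opt.step() runs but changes nothing
    p = torch.nn.Parameter(torch.zeros(1))
    p.grad = torch.zeros(1)
    return torch.optim.SGD([p], lr=lr)

# StepLR: multiply the LR by gamma=0.1 every step_size=30 epochs
opt = make_opt()
step_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)
step_lrs = []
for epoch in range(90):
    opt.step()                      # stands in for one epoch of training
    step_sched.step()
    step_lrs.append(opt.param_groups[0]["lr"])

# ReduceLROnPlateau: a constant loss never improves after the first epoch
opt = make_opt()
plateau_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="min", factor=0.1, patience=10)
plateau_lrs = []
for epoch in range(20):
    plateau_sched.step(1.0)         # constant "validation loss"
    plateau_lrs.append(opt.param_groups[0]["lr"])

print(step_lrs[28], step_lrs[29], step_lrs[59])  # about 1.0, 0.1, 0.01
print(plateau_lrs[0], plateau_lrs[-1])           # about 1.0, 0.1
```

StepLR follows a fixed timetable. ReduceLROnPlateau reacts to the metric it is given, so in practice it is normally fed the validation loss rather than the training loss.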

3. Early Stopping

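Early stopping tracks the best validation loss and stops training once that loss has not improved for patience consecutive epochs. The sketch below is a minimal version of this idea. The EarlyStopping class, its parameters, and the sample loss values are hypothetical and are not taken from any library.

```python
import copy
import math


class EarlyStopping:
    """Hypothetical helper: signal a stop when the monitored loss has not
    improved by more than min_delta for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_epochs = 0
        self.best_state = None

    def step(self, val_loss, model=None):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
            if model is not None:
                # keep a copy of the best weights so they can be restored later
                self.best_state = copy.deepcopy(model.state_dict())
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: improvement stops after epoch 2
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.74, 0.73]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(losses):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 5 0.7
```

In a real loop, call stopper.step(val_loss, model) once per epoch. After breaking out of the loop, restore the best weights with model.load_state_dict(stopper.best_state).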

4. Dropout

import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),  # zeroes each activation with probability 0.5 during training
    nn.ReLU(),
)
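Dropout only acts in training mode, so the model must be switched with model.train() before training and model.eval() before evaluation. The sketch below applies a standalone nn.Dropout to a tensor of ones, a setup chosen only for illustration. In training mode each element is either zeroed or scaled by 1/(1-p) = 2. In evaluation mode the input passes through unchanged.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()   # training mode: zero each element with prob p, scale survivors by 1/(1-p)
y_train = drop(x)

drop.eval()    # evaluation mode: dropout is the identity
y_eval = drop(x)

print(sorted(set(y_train.tolist())))  # [0.0, 2.0]
print(torch.equal(y_eval, x))         # True
```

The 1/(1-p) scaling keeps the expected activation the same in both modes, so no rescaling is needed at inference time.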

by CyrusMay 2022 07 03
