❀精度优化❀优化策略1:网络+SAM优化器

一:SAM优化器介绍:

SAM:Sharpness Awareness Minimization锐度感知最小化

SAM不是一个新的优化器,它与其他常见的优化器一起使用,比如SGD/Adam

论文:2020 Sharpness-Aware Minimization for Efficiently Improving Generalization

❀精度优化❀优化策略1:网络+SAM优化器_第1张图片

论文地址:https://arxiv.org/pdf/2010.01412v2.pdf

项目地址:GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)

(依旧建议大家使用GPU去训练,一般电脑cpu可以运行,但是非常卡,能卡出数据集,但是没卡出结果。)

下载解压后非常简单,把sam.py文件直接复制到example文件夹下就可以直接跑train.py.

❀精度优化❀优化策略1:网络+SAM优化器_第2张图片

运行后会自动下载数据集,会进行批次训练。

运行结果:(我改的epochs比较小,改大效果更好)

❀精度优化❀优化策略1:网络+SAM优化器_第3张图片

重要部分如下train.py:

import argparse
import torch

from model.wide_res_net import WideResNet#导入模型中的wide_res_net网络
from model.smooth_cross_entropy import smooth_crossentropy#导入损失函数
from data.cifar import Cifar#导入数据集
from utility.log import Log#导入工具类日志文件
from utility.initialize import initialize#导入工具类初始化
from utility.step_lr import StepLR#导入工具类阶梯学习率
from utility.bypass_bn import enable_running_stats, disable_running_stats#导入工具类绕过BN,启用运行统计,禁用运行统计

import sys; sys.path.append("..")#导入sys.path中需要用到的XXX包,然后加载
from sam import SAM#引入SAM


if __name__ == "__main__":
    #创建解析器(arg对象)
    parser = argparse.ArgumentParser()
    #添加参数
    parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.")
    parser.add_argument("--batch_size", default=12, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--epochs", default=2, type=int, help="Total number of epochs.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho", default=2.0, type=int, help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")
    #解析参数
    args = parser.parse_args()
    #初始化
    initialize(args, seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #定义数据集
    dataset = Cifar(args.batch_size, args.threads)
    #记录日志
    log = Log(log_each=10)
    #定义模型
    model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
    #定义基础优化器
    base_optimizer = torch.optim.SGD
    #定义第二个优化器SAM
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    #将optimizer作为参数传递给scheduler,每次通过调用scheduler.step()就会更新optimizer中每一个param_group[‘lr’],每过固定个epoch,学习率会按照gamma倍率进行衰减。
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
    
    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()

二:把SAM应用到自己的项目上:

step1:把SAM的工具文件复制到自己的项目下

把utility文件夹复制到自己的项目下,

❀精度优化❀优化策略1:网络+SAM优化器_第4张图片

把sam.py复制到项目根目录,在train.py里导入包。

step2:把数据集改为自己的数据集

step3:把网络改为自己的网络

(我的项目是多个独立的网络,几个网络就写几遍)

step4:添加基础优化器和SAM

    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
#一定要根据自己的项目去改相关参数等

step5:把损失函数改为自己原本的损失函数添加SAM工具类

            ...

            #opt.zero_grad()注释掉原本的

            #添加SAM工具类里的函数
            enable_running_stats(model_context)
            enable_running_stats(model_body)
            enable_running_stats(emotic_model)
            #我的项目是三个网络。如果是一个网络的话,写一次
            #类似于enable_running_stats(model)

            ...
 
            loss.backward()
            opt.first_step(zero_grad=True)#在项目的loss反向传播后先用优化器first step

             #添加SAM工具类里的函数
            disable_running_stats(model_context)
            disable_running_stats(model_body)
            disable_running_stats(emotic_model)

            ...

            loss.backward()
            opt.second_step(zero_grad=True)#在项目的loss反向传播后再用优化器second step

            # opt.step()注释掉原本的

step6:添加SAM所需的超参数(可选,不改也不会出错)


原项目:

❀精度优化❀优化策略1:网络+SAM优化器_第5张图片

修改后:

❀精度优化❀优化策略1:网络+SAM优化器_第6张图片

❀精度优化❀优化策略1:网络+SAM优化器_第7张图片

原项目:

❀精度优化❀优化策略1:网络+SAM优化器_第8张图片

 

 修改后:

❀精度优化❀优化策略1:网络+SAM优化器_第9张图片

#黄色框为修改位置

 加入SAM优化器后,比原来精度提高了将近3%。

 

 

以上。(全是自己的理解,不正确望指正,感谢。)

你可能感兴趣的:(-,精度优化,-,-,编码理解,-,深度学习,pytorch,计算机视觉)