SAM: Sharpness-Aware Minimization
SAM is not a new optimizer by itself; it is used on top of a standard base optimizer such as SGD or Adam.
Paper: Sharpness-Aware Minimization for Efficiently Improving Generalization (2020)
Paper link: https://arxiv.org/pdf/2010.01412v2.pdf
Project: GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)
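To make the point above concrete, here is a minimal sketch of a single SAM update wrapping Adam as the base optimizer (the tiny model, loss, and data below are placeholders of my own; the SAM constructor and the first_step/second_step calls follow the usage pattern of sam.py from the repository above):

import torch
from sam import SAM  # sam.py from the repository above

model = torch.nn.Linear(10, 2)           # placeholder model
loss_fn = torch.nn.CrossEntropyLoss()    # placeholder loss
inputs = torch.randn(4, 10)              # placeholder batch
targets = torch.randint(0, 2, (4,))

# SAM wraps a base optimizer class; extra keyword arguments (lr, ...) are forwarded to it
base_optimizer = torch.optim.Adam
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=1e-3)

# first forward-backward pass: gradients at the current weights, then perturb the weights
loss_fn(model(inputs), targets).backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass: gradients at the perturbed weights, then the actual Adam update
loss_fn(model(inputs), targets).backward()
optimizer.second_step(zero_grad=True)

The same two setup lines, with torch.optim.SGD instead of Adam, are exactly what the example train.py below uses.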
(As always, I recommend training on a GPU. An ordinary CPU can run it, but it is extremely slow: on my machine it managed to download the dataset, but never got as far as producing a result.)
Setup after downloading and unzipping is very simple: copy sam.py into the example folder and you can run train.py directly.
When you run it, the dataset is downloaded automatically and training proceeds batch by batch.
Training output: (I used a small number of epochs; increasing it gives better results.)
The important part of train.py is as follows:
import argparse
import torch
from model.wide_res_net import WideResNet  # the WideResNet network from the model package
from model.smooth_cross_entropy import smooth_crossentropy  # the label-smoothed cross-entropy loss
from data.cifar import Cifar  # the CIFAR-10 dataset wrapper
from utility.log import Log  # logging helper
from utility.initialize import initialize  # initialization helper (random seeds etc.)
from utility.step_lr import StepLR  # step learning-rate schedule helper
from utility.bypass_bn import enable_running_stats, disable_running_stats  # helpers that enable/disable BatchNorm running-statistics updates
import sys; sys.path.append("..")  # make the parent directory (which contains sam.py) importable
from sam import SAM  # import SAM
if __name__ == "__main__":
    # create the argument parser
    parser = argparse.ArgumentParser()
    # add the arguments
    parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.")
    parser.add_argument("--batch_size", default=12, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--epochs", default=2, type=int, help="Total number of epochs.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho", default=2.0, type=float, help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")
    # parse the arguments
    args = parser.parse_args()
    # initialization (random seeds etc.)
    initialize(args, seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # dataset
    dataset = Cifar(args.batch_size, args.threads)
    # logger
    log = Log(log_each=10)
    # model
    model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
    # base optimizer
    base_optimizer = torch.optim.SGD
    # the second optimizer, SAM, wraps the base optimizer
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    # the optimizer is passed to the scheduler; this utility StepLR is called once per epoch as scheduler(epoch)
    # and updates each param_group["lr"] of the optimizer, decaying the learning rate as training progresses
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step: gradients at the current weights,
            # then first_step() perturbs the weights toward the nearby "worst case"
            enable_running_stats(model)  # BatchNorm running statistics are updated only on this pass
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step: gradients at the perturbed weights,
            # then second_step() restores the weights and applies the actual base-optimizer update
            disable_running_stats(model)  # the second pass must not update BatchNorm statistics again
            smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()
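As an aside, sam.py also supports a closure-based calling style: optimizer.step() accepts a closure that re-runs the forward and backward pass, and internally performs first_step, calls the closure, and then second_step (see the repository README for the exact form). A rough sketch of that style (model, loss_function, inputs, and targets stand for whatever your project uses; note this variant does not include the BatchNorm running-statistics handling from the loop above):

def closure():
    # re-run the forward and backward pass at the perturbed weights
    loss = loss_function(model(inputs), targets)
    loss.backward()
    return loss

# first forward-backward pass at the current weights
loss = loss_function(model(inputs), targets)
loss.backward()
optimizer.step(closure)  # internally: first_step -> closure() -> second_step
optimizer.zero_grad()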
To use SAM in your own project: copy the utility folder into your project, copy sam.py to the project root, and import them in your train.py.
(My project consists of several independent networks; write these lines once for each network.)
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
# be sure to adjust these parameters (rho, learning rate, etc.) for your own project
...
# opt.zero_grad()  <- comment out the original zero_grad call
# add the helper functions from the SAM utility package
enable_running_stats(model_context)
enable_running_stats(model_body)
enable_running_stats(emotic_model)
# my project has three networks; with a single network you only write this once,
# i.e. enable_running_stats(model)
...
loss.backward()
opt.first_step(zero_grad=True)  # after the loss has been backpropagated, first call the optimizer's first_step
# again add the helper functions from the SAM utility package
disable_running_stats(model_context)
disable_running_stats(model_body)
disable_running_stats(emotic_model)
...
loss.backward()
opt.second_step(zero_grad=True)  # after the second backward pass, call the optimizer's second_step
# opt.step()  <- comment out the original step call
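Putting the pieces together for the common single-network case, here is a minimal self-contained sketch of what the modified training loop looks like (the model, loss_fn, and loader below are placeholders of my own; the structure mirrors the fragment above):

import torch
from sam import SAM  # sam.py copied to the project root
from utility.bypass_bn import enable_running_stats, disable_running_stats

# placeholder model, loss, and data, only to make the sketch runnable
model = torch.nn.Sequential(torch.nn.Linear(10, 16), torch.nn.BatchNorm1d(16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
loss_fn = torch.nn.CrossEntropyLoss()
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(5)]

base_optimizer = torch.optim.SGD
opt = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9, weight_decay=0.0005)

for inputs, targets in loader:
    # first forward-backward pass at the current weights
    enable_running_stats(model)      # BatchNorm running statistics are updated only on this pass
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    opt.first_step(zero_grad=True)   # the original opt.zero_grad() is no longer needed

    # second forward-backward pass at the perturbed weights
    disable_running_stats(model)     # the second pass must not update BatchNorm statistics again
    loss_fn(model(inputs), targets).backward()
    opt.second_step(zero_grad=True)  # replaces the original opt.step()

With several independent networks, as in my project, the enable_running_stats / disable_running_stats calls are simply repeated once per network, as shown in the fragment above.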
(Screenshots: the project before and after the modification, with the changed lines marked by yellow boxes.)
That's all. (This is entirely my own understanding; if anything is wrong, I'd appreciate corrections. Thanks.)