SWA,全程为“Stochastic Weight Averaging”(随机权重平均)。它是一种深度学习中提高模型泛化能力的一种常用技巧。
其思路为:对于模型的权重,不直接使用最后的权重,而是将之前的权重做个平均。
该方法适用于深度学习,不限领域、不限Optimzer,可以和多种技巧同时使用。
我们的模型参数记为: θ = { w 0 , w 1 , w 2 , ⋯ , w n } \theta=\{w_0, w_1, w_2, \cdots, w_n\} θ={w0,w1,w2,⋯,wn}, n n n 为模型总参数量。
对于模型的训练,会在epoch结束后保存一个副本,第 t t t 个epoch的模型参数记为 θ t \theta_t θt。
则我们模型的最终参数为:
θ ˉ = 1 T ∑ t = 1 T θ t \bar{\theta} = \frac{1}{T} \sum^T_{t=1}\theta_t θˉ=T1t=1∑Tθt
其中 T T T 表示我们有 T T T 个不同个模型参数的副本。
该公式的意思就是将前面t个模型的权重取平均,然后作为最终的模型参数。
注意事项:
通常我们在使用SWA时会有如下的超参数:
本节展示一下Pytorch Lightning中对SWA的实现,以便更清晰的认识SWA。
在开始看代码前,明确几个在Pytorch Lightning实现中的几个重要的概念:
class StochasticWeightAveraging(Callback):
def __init__(
self,
swa_lrs: Union[float, List[float]], # swa的学习率
# swa_epoch_start: 从第0.8位置的epoch开始,例如一共100个epoch,那就从第81个epoch开始swa。
# 若指定整数,则会从指定的epoch开始swa。
swa_epoch_start: Union[int, float] = 0.8,
annealing_epochs: int = 10, # 模拟退火的epoch数。SWALR学习策略用的参数
annealing_strategy: str = "cos", # 模拟退火策略。SWALR学习策略用的参数
avg_fn: Optional[_AVG_FN] = None, # 平局函数,做模型参数平均时使用的函数,通常不需要指定。会使用默认的。
device: Optional[Union[torch.device, str]] = torch.device("cpu"), # 平均后的model存在哪个device上
):
...
def on_train_epoch_start(self, ...): # 在每个epoch开始前执行
if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
# 初始化SWA,在整个SWA过程中只执行一遍
self._initialized = True
...
# 使用原来的optimizer
optimizer = trainer.optimizers[0]
...
# 使用SWALR学习率策略(SWA Learning Scheduler),后面会讲
self._swa_scheduler = cast(
LRScheduler,
SWALR(
optimizer,
swa_lr=self._swa_lrs, # type: ignore[arg-type]
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
),
)
# end if, 初始化代码结束。
# 接下来是SWA在epoch开始前的处理逻辑
if (self.swa_start <= trainer.current_epoch <= self.swa_end):
# 在SWA期间,每个epoch开始前将当前的模型参数更新到“平均模型”上。
self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
if trainer.current_epoch == self.swa_end + 1:
# 到最后结束的时候,将平均模型的参数迁移到模型上。
self.transfer_weights(self._average_model, pl_module)
@staticmethod
def update_parameters(
average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
) -> None:
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
p_swa_ = p_swa.detach()
p_model_ = p_model.detach().to(device)
src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
p_swa_.copy_(src)
n_averaged += 1
@staticmethod
def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
从上述Pytorch Lightning对SWA实现的源码中我们可以获得以下信息:
SWA学习率
和从哪个epoch
开始这两个最重要的参数。在上面我们提到了Pytorch Lightning实现中,在SWA期间使用的是SWALR。
SWALR使用的是“模拟退火”策略,简单来说就是:学习率是从原本的学习率逐渐过度到SWA学习率的。例如,原本你使用的学习率是0.1,指定的SWA学习率为0.01,从第20个epoch开始进行SWA。那么并不是到第20个epoch后学习率立刻从0.1变到0.01,而是从0.1逐渐过度到0.01,过度的epoch数就是指定的annealing_epochs
参数,而过度时减小的策略就是annealing_strategy
参数。
这里不使用难以理解的源码或数学,而是来通过几组实验来直观的观察一下SWALR策略下的学习率的变化来进行解释:
上述实验为:模型训练过程中学习率随epoch的变化,横坐标为epoch,纵坐标为这个epoch使用的学习率。其中图上的几个参数分别为:
例如对于图一意思就是:模型一开始在Optimizer上指定的学习率是0.1,SWA学习率为0.001,从第2个epoch开始进行SWA,总共进行10(annealing_epochs
) 个epoch将学习率从0.1逐渐过度到0.001,学习率调整使用cos策略。
从上述图中很容易得出以下结论:
实验环境与代码如下:
lightning==2.0.1
pytorch==1.13.0
实验代码如下:
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import StochasticWeightAveraging
from matplotlib import pyplot as plt
import numpy as np
def plot_swa_lr_curve(model_lr, # 模型的学习率
swa_lr, # swa的学习率
swa_epoch_start=2, # 从哪个epoch开始swa
annealing_epochs=10, # 模拟退火的epoch数
annealing_strategy='cos' # 模拟退火策略
):
lrs = []
# 定义一个简单的模型,用于测试
class SimpleModel(pl.LightningModule):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(1, 1)
def training_step(self, batch, batch_idx, *args, **kwargs):
return nn.functional.mse_loss(self.linear(torch.rand(4, 1)), torch.rand(4, 1))
def configure_optimizers(self):
# 使用model_lr作为测试模型的学习率
return torch.optim.SGD(self.parameters(), lr=model_lr)
# 重写一下StochasticWeightAveraging,用于记录学习率变化
class MyStochasticWeightAveraging(StochasticWeightAveraging):
def on_train_epoch_start(self, *args, **kwargs):
super().on_train_epoch_start(*args, **kwargs)
if hasattr(self._swa_scheduler, "_last_lr"):
# 记录lr的变化
lrs.append(self._swa_scheduler._last_lr[0])
else:
lrs.append(model_lr)
# 定义trainer进行训练
trainer = pl.Trainer(
callbacks=[MyStochasticWeightAveraging(swa_lrs=swa_lr, swa_epoch_start=swa_epoch_start,
annealing_epochs=annealing_epochs,
annealing_strategy=annealing_strategy)],
max_epochs=20,
num_sanity_val_steps=0,
enable_progress_bar=False, # Use custom progress bar
accelerator='cpu',
)
# 训练模型
trainer.fit(SimpleModel(), train_dataloaders=range(10))
plt.plot(np.arange(1, len(lrs)+1).astype(dtype=np.str), lrs)
plt.xlabel("epoch")
plt.ylabel("learning rate")
plt.text(0.7, 0.9, "model_lr: %s" % model_lr, fontsize=11, transform=plt.gca().transAxes)
plt.text(0.7, 0.8, "swa_lr: %s" % swa_lr, fontsize=11, transform=plt.gca().transAxes)
plt.text(0.6, 0.7, "swa_epoch_start: %s" % swa_epoch_start, fontsize=11, transform=plt.gca().transAxes)
plt.text(0.6, 0.6, "annealing_epochs: %s" % annealing_epochs, fontsize=11, transform=plt.gca().transAxes)
plt.text(0.6, 0.5, "annealing_strategy: %s" % annealing_strategy, fontsize=11, transform=plt.gca().transAxes)
plt.show()
print("lrs:", lrs) # 输出lr的变化
return lrs
plot_swa_lr_curve(0.1, 0.001)
Averaging Weights Leads to Wider Optima and Better Generalization(原论文): https://arxiv.org/abs/1803.05407
PyTorch 1.6 now includes Stochastic Weight Averaging: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/