Pytorch中的一些训练技巧

冻结bn层

如果你使用了预训练模型,并且显卡不支持你使用很大的batch size,那么冰冻bn的参数就是很好的选择,因为在imageNet上预训练的模型,bn层会获得很好的running mean和running var。

   for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            m.weight.requires_grad = False
            m.bias.requires_grad = False

固定随机种子

如果想检查你设计的网络究竟正不正常,最好将每次调试都把参数初始化设置为一致的。

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

提升训练阶段forward的速度

如果你的网络没有控制流,我建议在import之后就加入以下两行代码

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

原因是:设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间。

warmUp预热+学习率呈余弦变化

余弦学习率的公式如下:
l r = b a s e _ l r × 0.5 × ( 1 + c o s ( g l o b a l _ s t e p m a x _ s t e p s ∗ P i ) lr = base\_lr \times 0.5 \times (1+ cos(\frac{global\_step}{max\_steps} * Pi) lr=base_lr×0.5×(1+cos(max_stepsglobal_stepPi)
WarmUp的公式如下:
l r = b a s e _ l r × 1.0 × g l o b a l _ s t e p w a r m u p _ t o t a l s t e p s i f g l o b a l _ s t e p < w a r m u p _ t o t a l _ s t e p s lr = base\_lr \times 1.0 \times \frac{global\_step}{warmup\_total_steps} if global\_step < warmup\_total\_steps lr=base_lr×1.0×warmup_totalstepsglobal_stepifglobal_step<warmup_total_steps

loss采用标签平滑

标签平滑的作用不再多说,各大比赛优秀方案几乎都采用标签平滑。下面附一段从旷世开源的shuffleNet系列中截取的一段,用于分类任务的标签平滑的loss。

class CrossEntropyLabelSmooth(nn.Module):

	def __init__(self, num_classes, epsilon):
		super(CrossEntropyLabelSmooth, self).__init__()
		self.num_classes = num_classes
		self.epsilon = epsilon
		self.logsoftmax = nn.LogSoftmax(dim=1)

	def forward(self, inputs, targets):
		log_probs = self.logsoftmax(inputs)
		targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
		targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
		loss = (-targets * log_probs).mean(0).sum()
		return loss

criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

你可能感兴趣的:(Pytorch)