pytorch中梯度修剪,防止梯度爆炸的方法

def clip_gradient(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

 

你可能感兴趣的:(pytorch,深度学习,人工智能)