用户在训练网络时往往会遇到loss不收敛或者其他精度问题,而造成精度问题的主要原因之一,是梯度更新过程中梯度值过大或过小,导致参数更新不准确。针对梯度值在更新过程中过大或过小所造成的精度问题,在此介绍一种梯度裁剪方法:顾名思义,梯度裁剪(gradient clip)是指当梯度小于或大于某个阈值时,强制调整梯度,将其限制在阈值范围内的技术。
...
...
# Clipping mode consumed by ClipGradients: 0 -> clip by value, 1 -> clip by norm;
# any other value makes ClipGradients return the gradients untouched.
GRADIENT_CLIP_TYPE: int = 0
# Clipping threshold handed to ClipGradients as `clip_value`.
GRADIENT_CLIP_VALUE: float = 1.0
class ClipGradients(nn.Cell):
    """
    Clip gradients by value or by norm.

    Inputs:
        grads (tuple[Tensor]): Gradients.
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
            Any other value disables clipping and ``grads`` is returned as is.
        clip_value (float): Specifies how much to clip. With ``clip_type == 0``
            every gradient is clamped element-wise to
            ``[-clip_value, clip_value]``; with ``clip_type == 1`` it is the
            threshold passed to ``nn.ClipByNorm``.

    Outputs:
        tuple[Tensor], clipped gradients.
    """

    def __init__(self):
        super(ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self, grads, clip_type, clip_value):
        """Return a tuple with each gradient in `grads` clipped per `clip_type`."""
        if clip_type not in (0, 1):
            return grads
        new_grads = ()
        for grad in grads:
            dt = self.dtype(grad)
            # Threshold as a tensor with the same dtype as the gradient.
            clip_max = self.cast(F.tuple_to_array((clip_value,)), dt)
            if clip_type == 0:
                # clip_by_value takes both a lower and an upper bound. Passing a
                # single bound is either rejected or (in newer MindSpore) used as
                # the *lower* bound, which would push every gradient up to at
                # least +clip_value instead of limiting its magnitude.
                clip_min = self.cast(F.tuple_to_array((-clip_value,)), dt)
                t = ops.clip_by_value(grad, clip_min, clip_max)
            else:
                t = self.clip_by_norm(grad, clip_max)
            new_grads = new_grads + (t,)
        return new_grads
class TrainOneStepCellV2(TrainOneStepCell):
    """Build train network whose gradients are clipped by ClipGradients."""

    def __init__(self, network, optimizer, sens=1.0):
        # Forward the caller's `sens`; hard-coding 1.0 here would silently
        # ignore any value passed to this constructor.
        super(TrainOneStepCellV2, self).__init__(network, optimizer, sens=sens)
        self.clip_gradients = ClipGradients()

    def construct(self, *inputs):
        ...  # elided in this excerpt; must define `weights` and `sens` used below
        grads = self.grad(self.network, weights)(*inputs, sens)
        # Clip with the module-level mode/threshold before the gradients are used.
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        ...  # elided in this excerpt
...
...
# Wrap the model in TrainOneStepCellV2, whose construct() clips the gradients.
model_constructed = TrainOneStepCellV2(model_constructed, opt)

# Run training with the wrapped model.
train_net(
    model_constructed,
    net,
    loss_function,
    CHECKPOINT_MAX,
    EPOCH_MAX,
    TRAIN_PATH,
    VAL_PATH,
    TRAIN_BATCH_SIZE,
    VAL_BATCH_SIZE,
    REPEAT_SIZE,
)
...
# Clipping mode for ClipGradients: 0 -> clip by value, 1 -> clip by norm.
GRADIENT_CLIP_TYPE = 0
# Clipping threshold handed to ClipGradients as `clip_value`.
GRADIENT_CLIP_VALUE = 1.0


class ClipGradients(nn.Cell):
    ...
    # Same gradient-clipping implementation as in the low-level API example.
    ...
    ...
class TrainOneStepCellV2(TrainOneStepCell):
    ...  # docstring / other members elided in this excerpt

    def __init__(self, network, optimizer, sens=1.0):
        # Forward the caller's `sens`; hard-coding 1.0 here would silently
        # ignore any value passed to this constructor.
        super(TrainOneStepCellV2, self).__init__(network, optimizer, sens=sens)
        self.clip_gradients = ClipGradients()

    def construct(self, *inputs):
        ...  # elided in this excerpt; must produce `grads` used below
        # Clip with the module-level mode/threshold before the gradients are used.
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        ...  # elided in this excerpt
注:目前仅支持非数据下沉模式(即 `dataset_sink_mode=False`)。
梯度裁剪的低阶模型示例代码以及高阶模型示例代码请至附件处下载。