Diffusion fine-tuning error: RuntimeError: One of the differentiated Tensors does not require grad

While fine-tuning a Diffusion Model recently, I repeatedly hit the following error during backpropagation after choosing which layers to fine-tune:

RuntimeError: One of the differentiated Tensors does not require grad

There is little about this error online, so I am documenting it here.

Code: OpenAI-UNetModel
Tracking down the bug:

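  1. Enabling fine-tuning one layer at a time showed that the later layers could be fine-tuned normally; the error appeared only when certain layers were set to be fine-tuned.
  2. Looking at the code of the layers that triggered the error, I found that their forward function calls a function named checkpoint():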
```python
import torch


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        # Run without building a graph; activations are recomputed in backward.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        # ctx.input_tensors were just re-created with requires_grad=True, but
        # ctx.input_params are the block's parameters as they are: if any of
        # them is frozen (requires_grad=False), this call raises
        # "RuntimeError: One of the differentiated Tensors does not require grad".
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
```
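In backward above, torch.autograd.grad receives ctx.input_tensors + ctx.input_params as the tensors to differentiate with respect to. The message itself comes from torch.autograd.grad: it rejects any tensor in its inputs argument that does not require grad, and allow_unused=True does not relax this (it only covers inputs that do not contribute to the output). A minimal standalone check, where x and w are illustrative stand-ins rather than names from the UNet code:

```python
import torch

x = torch.randn(3, requires_grad=True)  # like ctx.input_tensors after requires_grad_(True)
w = torch.randn(3)                      # like a frozen parameter: requires_grad=False
y = (x * w).sum()

try:
    # w is used in y, it just does not require grad, so allow_unused=True does not help
    torch.autograd.grad(y, [x, w], allow_unused=True, retain_graph=True)
    message = None
except RuntimeError as err:
    message = str(err)
print(message)  # One of the differentiated Tensors does not require grad

# Dropping the frozen tensor from the inputs makes the same call succeed
(grad_x,) = torch.autograd.grad(y, [x])
print(torch.allclose(grad_x, w))  # True
```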

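Cause: when setting up fine-tuning, I set requires_grad=False on the layers that are not fine-tuned. Those frozen parameters still reach torch.autograd.grad through ctx.input_params as tensors to differentiate with respect to, and torch.autograd.grad raises this error for any such tensor that does not require grad. This is presumably also why only some layers failed in step 1: backward of a checkpointed block only runs when a trainable layer lies upstream of it, so the error appears when a checkpointed block with frozen parameters sits on the backward path of a layer being fine-tuned. Fix: set flag=False, i.e. disable gradient checkpointing for these blocks, at the cost of losing the memory savings checkpointing provides.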
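If you would rather keep gradient checkpointing for its memory savings, one option is to differentiate only with respect to tensors that require grad and return None for the frozen ones. This is a sketch of that idea, not part of the OpenAI code; FrozenSafeCheckpoint, Block, grad_of_first and the toy layer sizes are hypothetical names chosen for illustration. It checks that the result matches the non-checkpointed path (the flag=False behavior).

```python
import torch
import torch.nn as nn


class FrozenSafeCheckpoint(torch.autograd.Function):
    """Hypothetical variant of CheckpointFunction that skips frozen params."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        candidates = input_tensors + ctx.input_params
        # Only pass tensors that require grad to torch.autograd.grad
        trainable = [t for t in candidates if t.requires_grad]
        grads = iter(torch.autograd.grad(output_tensors, trainable, output_grads, allow_unused=True))
        # One gradient per forward arg; None for frozen parameters
        input_grads = tuple(next(grads) if t.requires_grad else None for t in candidates)
        del ctx.input_tensors, ctx.input_params
        return (None, None) + input_grads


class Block(nn.Module):
    """Toy stand-in for a checkpointed UNet block (hypothetical)."""

    def __init__(self, dim, use_checkpoint):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.use_checkpoint = use_checkpoint

    def _forward(self, x):
        return torch.relu(self.lin(x))

    def forward(self, x):
        if self.use_checkpoint:
            args = (x,) + tuple(self.parameters())
            return FrozenSafeCheckpoint.apply(self._forward, 1, *args)
        return self._forward(x)


def grad_of_first(use_checkpoint):
    torch.manual_seed(0)
    first = nn.Linear(4, 4)              # the layer being fine-tuned
    second = Block(4, use_checkpoint)    # checkpointed block downstream of it
    second.requires_grad_(False)         # frozen, as in the fine-tuning setup
    x = torch.randn(2, 4)
    second(first(x)).sum().backward()
    return first.weight.grad


g_ckpt = grad_of_first(use_checkpoint=True)    # no RuntimeError with the filtered variant
g_plain = grad_of_first(use_checkpoint=False)  # equivalent to flag=False
print(torch.allclose(g_ckpt, g_plain))  # True
```

PyTorch's built-in torch.utils.checkpoint.checkpoint does not take an explicit parameter list and is another possible replacement, though swapping it into the UNet code would need its own testing.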
