【极简】Pytorch中的register_buffer()

register buffer

定义模型能用torch.save保存的、但是不更新参数。

使用:只要是nn.Module的子类就能直接self.调用使用:

class A(nn.Module):
#...
self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
#...

手动定义参数

上述的参数显然可以直接用一个变量直接定义超参。但是缺点是在用torch.save()保存的时候不能保存在参数里面,只能用个文本文件保存在外面。不能直接用torch.load加载,不是很方便。

举个例子,假设你有100个超参,难不成要一个一个记录之后,手动造轮子解析保存的txt嘛?当然也行,但是麻烦。
就比如Diffusion Model中的beta和alpha,在每个timestep时候都是不一样的,这时候手动保存会相当麻烦,用register buffer会相当方便。

普通参数

一般来说模型中的可变参数都是nn.Parameter()类的,这些都是可变的,optimizer会去优化它们。

要是跟register buffer硬凑在一起,把Parameter的require_grad改成False也能充当。但是何必呢?

你可能感兴趣的:(pytorch,pytorch,人工智能,python)