【pytorch】register_buffer的使用

这篇文章讲解很清晰,以下内容仅做补充,探讨哪些对象需要手动注册,哪些会自动注册

在 PyTorch 中,哪些对象会自动注册为模型的一部分取决于它们的类型以及你如何定义它们。下面列出不需要手动注册、会自动注册的几种情况:

1. nn.Parameter

  • 自动注册:任何你在 nn.Module 中定义为 nn.Parameter 的张量都会自动注册为模型的参数。它们会被视为模型的可训练参数,并且会被包含在模型的 state_dict() 中,也会参与反向传播和优化。

  • 如何使用

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册为模型参数
            self.weight = nn.Parameter(torch.randn(5, 5))
    
        def forward(self, x):
            return x * self.weight
    
  • 特点

    • 被自动注册为模型的参数,参与梯度计算和更新。
    • 会在 model.parameters() 中找到,并且会在模型保存、加载以及转移设备时自动管理。

2. nn.Module 子类(如 nn.Conv2d, nn.Linear 等)

  • 自动注册:所有继承自 nn.Module 的子类(如 nn.Conv2d, nn.Linear, nn.ReLU 等)会自动注册为模型的一部分。它们包含的权重和偏置会自动注册为模型参数,并参与训练和保存。

  • 如何使用

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册为模型的一部分
            self.conv = nn.Conv2d(1, 1, 3, 1, 1)
    
        def forward(self, x):
            return self.conv(x)
    
  • 特点

    • 子模块中的所有 nn.Parameter 自动成为模型的一部分。
    • 例如 nn.Conv2d 的权重和偏置会在 state_dict() 中找到。
    • 可以通过 model.parameters() 获取所有可训练参数。

3. nn.Module 内的属性如果是其他 nn.Module 的实例

  • 当你将另一个 nn.Module 对象作为属性放在自定义模型中时,这个子模块的参数会自动注册为主模块的一部分。

  • 如何使用

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册子模块
            self.linear1 = nn.Linear(10, 5)
            self.linear2 = nn.Linear(5, 1)
    
        def forward(self, x):
            x = self.linear1(x)
            return self.linear2(x)
    
  • 特点

    • 子模块会自动注册,子模块的参数也会作为主模块的一部分。
    • 子模块也可以递归地包含其他模块,这些都会自动注册。

4. buffers 使用 register_buffer 显式注册

  • 不会自动注册:对于不需要训练的张量或常量(例如 BN 层中的均值、方差、位置编码等),需要使用 register_buffer 手动注册。这些张量不会参与梯度更新,但会随着模型保存、加载以及转移到设备。

  • 如何使用

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 手动注册 buffer
            self.register_buffer('my_buffer', torch.randn(5, 5))
    
        def forward(self, x):
            return x + self.my_buffer
    
  • 特点

    • 不参与梯度计算,但在 state_dict 中可见。
    • 可以通过 .to(device) 自动转移到指定设备。

5. 不会自动注册的对象

  • 普通的 Python 对象、torch.Tensor 或者 list/dict 类型不会自动注册为模型的一部分。你需要手动使用 register_buffer 或者 nn.Parameter 来使其成为模型的成员,否则这些对象不会在模型的 state_dict() 中出现,也不会随着模型迁移到 GPU/CPU。

  • 例子

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 不会自动注册
            self.tensor = torch.randn(5, 5)  # 这个张量不会自动成为模型的部分
    
        def forward(self, x):
            return x + self.tensor
    
  • 解决方法:如果你希望 tensor 也能成为模型的一部分,使用 register_buffer

    self.register_buffer('my_buffer', self.tensor)
    

总结:

  • 自动注册

    • nn.Parameter: 任何 nn.Parameter 类型的属性会自动成为模型的参数。
    • nn.Module 子类: 任何包含在 nn.Module 中的子模块(如 nn.Conv2d)会自动注册为模型的一部分。
  • 需要手动注册

    • 非可训练的常量(buffers:需要使用 register_buffer 来显式注册,它们不参与梯度计算,但会保存、加载以及转移到设备。

你可能感兴趣的:(pytorch,人工智能,python)