pytorch 中register_buffer()

 

今天在看DSSINet代码的ssim.py时,遇到了一个用法

class NORMMSSSIM(torch.nn.Module):

    def __init__(self, sigma=1.0, levels=5, size_average=True, channel=1):
        super(NORMMSSSIM, self).__init__()
        self.sigma = sigma
        self.window_size = 5
        self.levels = levels
        self.size_average = size_average
        self.channel = channel
        self.register_buffer('window', create_window(self.window_size, self.channel, self.sigma))
        self.register_buffer('weights', torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))

那么这个register_buffer()是干什么用呢?官方解释如下

nn.modules.module.py
Adds a persistent buffer to the module.向模块添加持久缓冲区。

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm的“running_mean”不是参数,而是持久状态的一部分。
        Buffers can be accessed as attributes using given names.
缓冲区可以使用给定的名称作为属性访问。 
        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name 名称(字符串):缓冲区的名称。可以使用给定的名称从该模块访问缓冲区
            tensor (Tensor): buffer to be registered.
        Example::
            >>> self.register_buffer('running_mean', torch.zeros(num_features))        

应该就是在内存中定一个常量,同时,模型保存和加载的时候可以写入和读出。

pytorch一般情况下,是将网络中的参数保存成orderedDict形式的,这里的参数其实包含两种,一种是模型中各种module含的参数,即nn.Parameter,我们当然可以在网络中定义其他的nn.Parameter参数,另一种就是buffer,前者每次optim.step会得到更新,而不会更新后者。

class myModel(nn.Module):
    def __init__(self, kernel_size=3):
        super(Depth_guided1, self).__init__()
        self.kernel_size = kernel_size
        self.back_end = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, padding=1),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(64, 3, 3, padding=1),
            torch.nn.ReLU(True),
        )

        mybuffer = np.arange(1,10,1)
        self.mybuffer_tmp = np.randn((len(mybuffer), 1, 1, 10), dtype='float32')
        self.mybuffer_tmp = torch.from_numpy(self.mybuffer_tmp)
        # register preset variables as buffer
        # So that, in testing , we can use buffer variables.
        self.register_buffer('mybuffer', self.mybuffer_tmp)

        # Learnable weights
        self.conv_weights = nn.Parameter(torch.FloatTensor(64, 10).normal_(mean=0, std=0.01))
        # Other code
        def forward(self):
            ...
            # 这里使用 self.mybuffer!

注记:

1.定义parameter和buffer都只需要传入Tensor即可。也不需要将其转成gpu,这是因为,当网络进行.cuda时候,会自动将里面的层的参数,buffer等转换成相应的GPU上。

2. self.register_buffer可以将tensor注册成buffer,在forward中使用self.mybuffer,而不是self.mybuffer_tmp

3.网络存储时也会将buffer存下,当网络load模型时,会将存储的模型的buffer也进行赋值。

4.buffer的更新在forward中,optim.step只能更新nn.parameter类型的参数。

你可能感兴趣的:(pytorch 中register_buffer())