compressai中的上下文预测模块

理论

1、自回归模型Autoregressive model

用自身产生的参数去预测下一个值,即使用 x 1 , x 2 , . . . , x t x_1,x_2,...,x_t x1,x2,...,xt去预测 x t + 1 x_{t+1} xt+1,用 x 1 , x 2 , . . . , x t , x t + 1 x_1,x_2,...,x_t,x_{t+1} x1,x2,...,xt,xt+1去预测 x t + 2 x_{t+2} xt+2

2、Masked 2D convolution

compressai使用的是PixelCNN模块中的Masked 2D convolution,我理解的就是正常的2D卷积操作,由于当前像素是根据之前解码的像素得到的,所以要让卷积块中位于当前像素的之后的权重为0;

比如说5*5的卷积核:

compressai中的上下文预测模块_第1张图片
Introduced in "Conditional Image Generation with PixelCNN Decoders" _.

代码

简单到我都怀疑作者是不是没按照论文写

class MaskedConv2d(nn.Conv2d):
    r"""Masked 2D convolution implementation, mask future "unseen" pixels.
    Useful for building auto-regressive network components.
    """
    
    def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
        super().__init__(*args, **kwargs)

        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid "mask_type" value "{mask_type}"')

        # self.register_buffer('name',Tensor)定义一组参数,模型训练时不会更新
        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        self.mask[:, :, h // 2, w // 2 + (mask_type == "B") :] = 0  # 当前像素所在行像素点所在列及其右侧所有列为0
        self.mask[:, :, h // 2 + 1 :] = 0  # 当前像素下面所有行为0

    def forward(self, x: Tensor) -> Tensor:
        # TODO(begaintj): weight assigment is not supported by torchscript
        self.weight.data *= self.mask
        return super().forward(x)

你可能感兴趣的:(人工智能,python)