全变分(Total Variation,TV)模型

在图像复原和图像去噪方面使用。

全变分(Total Variation Loss)源自于图像处理中的全变分去噪(Total Variation Denoising),全变分去噪的有点是既能去噪声,又能保留原图像中的边界信息。而其他简单的去噪方法,如果线性平滑或者中值滤波,在去噪的同时会平滑图像中的边界等信息,损害图像所表达的信息。

全变分去噪的基本思想是,如果图像的细节有很多高频信息(如尖刺、噪点等),那么整幅图像的梯度幅值之和(全变分)是比较大的,如果能够使整幅图像的梯度积分之和降低,就达到了去噪的目的。

参考:(29 封私信 / 2 条消息) 如何理解全变分(Total Variation,TV)模型? - 知乎 (zhihu.com)

PyTorch实现

因为 TV Loss 是对整个 batch 的每一幅图像计算的,所以除以batch_size * c * h * w是一个比较适合作为loss的值,创建TV Loss时可传入一个可选的权重参数(惩罚系数)

作者:imxtx
链接:https://www.zhihu.com/question/47162419/answer/2585330101
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

import torch
import torch.nn as nn
import numpy as np


class TVLoss(nn.Module):
    def __init__(self, weight: float=1) -> None:
        """Total Variation Loss

        Args:
            weight (float): weight of TV loss
        """
        super().__init__()
        self.weight = weight
    
    def forward(self, x):
        batch_size, c, h, w = x.size()
        tv_h = torch.abs(x[:,:,1:,:] - x[:,:,:-1,:]).sum()
        tv_w = torch.abs(x[:,:,:,1:] - x[:,:,:,:-1]).sum()
        return self.weight * (tv_h + tv_w) / (batch_size * c * h * w)


def main():
    tv_loss = TVLoss()
    x = torch.rand([1, 3, 3, 3])
    print(f'Input:\n{x}')
    print(f'The TV Loss is {tv_loss(x).item()}')


if __name__ == "__main__":
    main()

你可能感兴趣的:(计算机视觉,深度学习,人工智能,计算机视觉,深度学习)