torch实现图像滤波(可反向传播)

torch实现图像滤波(可反向传播)

torch_filter 一个即插即用的torch模块,用于实现图像自定义滤波。
有时,我们需要在神经网络中实现一些图像滤波操作。因此,我们设计了这个简单的torch_filter模块。模块代码如下:

import torch
import torch.nn as nn
import numpy as np
import cv2


class torch_filter(nn.Module):
    def __init__(self, filter_weight, is_grad=False):
        super(torch_filter, self).__init__()
        assert type(filter_weight) == np.ndarray
        k=filter_weight.shape[0]
        filter=torch.tensor(filter_weight).unsqueeze(dim=0).unsqueeze(dim=0)
        filters = torch.cat([filter, filter, filter], dim=0)

        self.conv = nn.Conv2d(3, 3, kernel_size=k, groups=3, bias=False, padding=int((k-1)/2))
        self.conv.weight.data.copy_(filters)
        self.conv.requires_grad_(is_grad)


    def forward(self,x):
        output = self.conv(x)
        output = torch.clip(output, 0, 1)
        return output

if __name__ == '__main__':
    weight = np.ones((5,5))
    net=torch_filter(weight,is_grad=False)
    img=torch.randn((9,3,256,256))
    img=net(img)
    print(img.shape)#torch.Size([9, 3, 256, 256])


torch_filter.py的输入参数

filter_weight:自定义过滤器权重。请注意,filter的两个维度都必须是相等的和奇数的。例如(3×3)、(5×5)。类型必须为np.array。
is_grad:True or False,该模块是否参与反向传播过程。

测试

我们用以下代码实现了图像锐化操作,filter_weight设置成:

 filter_weight=np.array ([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1 ]])

测试代码如下:

import torch
import cv2
import numpy as np
from torch_filter import torch_filter
weight = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
net=torch_filter(weight,is_grad=False)
img=cv2.imread(r"images/img.png")#输入图片路径
image = np.transpose((np.array(img, np.float64))/255, [2, 0, 1])
image = torch.from_numpy(image).type(torch.FloatTensor)
image = image.unsqueeze(dim=0)
image_sharp=net(image)
image_sharp=image_sharp.cpu().detach().numpy().copy().squeeze()
predictimag=np.transpose(image_sharp, [1, 2, 0])*255
cv2.imwrite(r'images/new_img.png',predictimag)#输出图片路径

输入:torch实现图像滤波(可反向传播)_第1张图片
输出:

效果符合预期

参考源码

https://github.com/deepxzy/torch_filter

你可能感兴趣的:(python,深度学习,计算机视觉)