torch_filter 一个即插即用的torch模块,用于实现图像自定义滤波。
有时,我们需要在神经网络中实现一些图像滤波操作。因此,我们设计了这个简单的torch_filter模块。模块代码如下:
import torch
import torch.nn as nn
import numpy as np
import cv2
class torch_filter(nn.Module):
def __init__(self, filter_weight, is_grad=False):
super(torch_filter, self).__init__()
assert type(filter_weight) == np.ndarray
k=filter_weight.shape[0]
filter=torch.tensor(filter_weight).unsqueeze(dim=0).unsqueeze(dim=0)
filters = torch.cat([filter, filter, filter], dim=0)
self.conv = nn.Conv2d(3, 3, kernel_size=k, groups=3, bias=False, padding=int((k-1)/2))
self.conv.weight.data.copy_(filters)
self.conv.requires_grad_(is_grad)
def forward(self,x):
output = self.conv(x)
output = torch.clip(output, 0, 1)
return output
if __name__ == '__main__':
weight = np.ones((5,5))
net=torch_filter(weight,is_grad=False)
img=torch.randn((9,3,256,256))
img=net(img)
print(img.shape)#torch.Size([9, 3, 256, 256])
filter_weight:自定义过滤器权重。请注意,filter的两个维度都必须是相等的和奇数的。例如(3×3)、(5×5)。类型必须为np.array。
is_grad:True or False,该模块是否参与反向传播过程。
我们用以下代码实现了图像锐化操作,filter_weight设置成:
filter_weight=np.array ([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1 ]])
测试代码如下:
import torch
import cv2
import numpy as np
from torch_filter import torch_filter
weight = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
net=torch_filter(weight,is_grad=False)
img=cv2.imread(r"images/img.png")#输入图片路径
image = np.transpose((np.array(img, np.float64))/255, [2, 0, 1])
image = torch.from_numpy(image).type(torch.FloatTensor)
image = image.unsqueeze(dim=0)
image_sharp=net(image)
image_sharp=image_sharp.cpu().detach().numpy().copy().squeeze()
predictimag=np.transpose(image_sharp, [1, 2, 0])*255
cv2.imwrite(r'images/new_img.png',predictimag)#输出图片路径
https://github.com/deepxzy/torch_filter