有时需要自定义卷积核的权重。torch.nn.Conv2d 和 torch.nn.Conv3d 会自行创建并随机初始化可训练的卷积核参数,不便直接传入一组固定的卷积核,因此可以改用 torch.nn.functional 中的 conv2d 或 conv3d 函数,并通过 weight 参数显式传入卷积核张量。
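空间富模型(SRM, Spatial Rich Model)过滤器层使用 3 个参数不同的卷积核,分别提取 3 种不同的高频残差信号,且这些卷积核参数设置为不可训练。SRM 过滤器所使用的 3 个卷积核(即下方代码中的 filter1、filter2、filter3 分别除以归一化系数 4、12、4)为:

$$
K_1=\frac{1}{4}\begin{bmatrix}0&0&0&0&0\\0&-1&2&-1&0\\0&2&-4&2&0\\0&-1&2&-1&0\\0&0&0&0&0\end{bmatrix},\quad
K_2=\frac{1}{12}\begin{bmatrix}-1&2&-2&2&-1\\2&-6&8&-6&2\\-2&8&-12&8&-2\\2&-6&8&-6&2\\-1&2&-2&2&-1\end{bmatrix},\quad
K_3=\frac{1}{4}\begin{bmatrix}0&0&0&0&0\\0&0&-1&0&0\\0&-1&4&-1&0\\0&0&-1&0&0\\0&0&0&0&0\end{bmatrix}
$$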
在实现上,这相当于自定义一个参数不可训练的卷积层,并将卷积输出截断到 [-2, 2] 区间。自定义层需要继承 nn.Module 类,并重写 __init__
和 forward 两个方法。
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
class SRM2D(nn.Module):
    """Fixed (non-trainable) SRM high-pass filter layer for 2D images.

    Three 5x5 SRM residual kernels are applied to a 3-channel input, and the
    responses are clamped to [-2, 2].

    Input:  tensor of shape (N, 3, H, W).
    Output: tensor of shape (N, 3, H, W). Output channel k is the sum over the
            three input channels of their responses to the k-th kernel.
    """

    def __init__(self):
        super().__init__()
        q = [4.0, 12.0, 4.0]  # normalisation divisor for each kernel
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, -1, 0, 0],
                   [0, -1, 4, -1, 0],
                   [0, 0, -1, 0, 0],
                   [0, 0, 0, 0, 0]]
        kernels = np.stack([
            np.asarray(k, dtype=float) / d
            for k, d in zip((filter1, filter2, filter3), q)
        ])  # (3, 5, 5)
        # weight[o, i] = kernel o for every input channel i:
        # shape (out_channels=3, in_channels=3, 5, 5).
        weight = torch.from_numpy(kernels).float().unsqueeze(1).repeat(1, 3, 1, 1)
        # A buffer rather than a Parameter: the optimiser never sees it, but it
        # follows module.to(device). persistent=False keeps it out of
        # state_dict, since __init__ always rebuilds it.
        self.register_buffer("filter", weight, persistent=False)

    def forward(self, input):
        """Apply the three SRM kernels to ``input`` and clamp to [-2, 2].

        Args:
            input: tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 3, H, W).
        """
        # Match the input's floating dtype (e.g. float64/float16). Non-floating
        # inputs are left to conv2d, which rejects them, rather than rounding
        # the fractional kernel to integers.
        dtype = input.dtype if input.is_floating_point() else self.filter.dtype
        weight = self.filter.to(device=input.device, dtype=dtype)
        # 5x5 kernel with stride 1: padding 2 keeps H and W unchanged.
        result = F.conv2d(input, weight, stride=1, padding=2)
        # Truncate the residuals to [-2, 2].
        return torch.clamp(result, -2.0, 2.0)
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
class SRM3D(nn.Module):
    """Fixed (non-trainable) SRM high-pass filter layer for 3D (volumetric) input.

    The three 5x5 SRM residual kernels get a depth dimension of 1, so each
    depth slice is filtered independently. The responses are clamped to
    [-2, 2].

    Input:  tensor of shape (N, 3, D, H, W).
    Output: tensor of shape (N, 3, D, H, W). Output channel k is the sum over
            the three input channels of their responses to the k-th kernel.
    """

    def __init__(self):
        super().__init__()
        q = [4.0, 12.0, 4.0]  # normalisation divisor for each kernel
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, -1, 0, 0],
                   [0, -1, 4, -1, 0],
                   [0, 0, -1, 0, 0],
                   [0, 0, 0, 0, 0]]
        kernels = np.stack([
            np.asarray(k, dtype=float) / d
            for k, d in zip((filter1, filter2, filter3), q)
        ])  # (3, 5, 5)
        # weight[o, i] = kernel o for every input channel i, plus a depth axis
        # of size 1 for conv3d: shape (3, 3, 1, 5, 5).
        weight = (torch.from_numpy(kernels).float()
                  .unsqueeze(1).repeat(1, 3, 1, 1)
                  .unsqueeze(2))
        # A buffer rather than a Parameter: the optimiser never sees it, but it
        # follows module.to(device). persistent=False keeps it out of
        # state_dict, since __init__ always rebuilds it.
        self.register_buffer("filter", weight, persistent=False)

    def forward(self, input):
        """Apply the three SRM kernels slice-wise to ``input`` and clamp to [-2, 2].

        Args:
            input: tensor of shape (N, 3, D, H, W).

        Returns:
            Tensor of shape (N, 3, D, H, W).
        """
        # Match the input's floating dtype (e.g. float64/float16). Non-floating
        # inputs are left to conv3d, which rejects them, rather than rounding
        # the fractional kernel to integers.
        dtype = input.dtype if input.is_floating_point() else self.filter.dtype
        weight = self.filter.to(device=input.device, dtype=dtype)
        # Kernel is 1x5x5 with stride 1: padding (0, 2, 2) keeps D, H, W unchanged.
        result = F.conv3d(input, weight, stride=1, padding=(0, 2, 2))
        # Truncate the residuals to [-2, 2].
        return torch.clamp(result, -2.0, 2.0)