Custom network layers in PyTorch

Sometimes we need to specify the convolution kernel weights ourselves. torch.nn.Conv2d and torch.nn.Conv3d create and manage their own weight parameters, which makes it awkward to plug in a fixed, hand-specified kernel, so it is more convenient to use torch.nn.functional.conv2d or torch.nn.functional.conv3d, which take the weight tensor as an explicit argument.
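
For example, a minimal sketch (the tensor names here are placeholders, not from the article) of passing a hand-built kernel to torch.nn.functional.conv2d; the weight must have shape (out_channels, in_channels, kernel_height, kernel_width):

import torch
import torch.nn.functional as F

# a single 3x3 Laplacian-like kernel applied to a 1-channel input
custom_weight = torch.tensor([[[[ 0., -1.,  0.],
                                [-1.,  4., -1.],
                                [ 0., -1.,  0.]]]])  # shape (1, 1, 3, 3)

x = torch.randn(1, 1, 8, 8)                       # (batch, channels, H, W)
y = F.conv2d(x, weight=custom_weight, padding=1)  # same spatial size as the input
print(y.shape)                                    # torch.Size([1, 1, 8, 8])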

Let us take a custom SRM layer as an example.

The Spatial Rich Model (SRM) filter layer uses 3 convolution kernels with different parameters to extract 3 different kinds of high-frequency residual signals, and these kernel parameters are set to be non-trainable. The 3 kernels used by the SRM filter are:

[Figure 1: the three 5×5 SRM kernels, normalized by 1/4, 1/12 and 1/4 respectively; their exact values appear as filter1, filter2 and filter3 in the code below.]

In effect this is just a convolutional layer whose parameters are not trainable. A custom layer must subclass nn.Module and override the __init__ and forward methods.
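
In its most stripped-down form, the pattern looks like the sketch below (the class name MyLayer is just a placeholder); the full SRM examples that follow fill in both methods.

import torch.nn as nn

class MyLayer(nn.Module):
    def __init__(self):
        super().__init__()   # build sub-modules / fixed weights here

    def forward(self, x):
        return x             # implement the actual computation here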

Conv2d

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

class SRM2D(nn.Module):
    def __init__(self):
        super().__init__()

        q = [4.0, 12.0, 4.0]
        filter1 = [[0,  0,  0,  0, 0],
                   [0, -1,  2, -1, 0],
                   [0,  2, -4,  2, 0],
                   [0, -1,  2, -1, 0],
                   [0,  0,  0,  0, 0]]

        filter2 = [[-1, 2, -2,  2,-1],
                   [ 2,-6,  8, -6, 2],
                   [-2, 8, -12, 8,-2],
                   [ 2,-6,  8, -6, 2],
                   [-1, 2, -2,  2,-1]]

        filter3 = [[0,  0,  0,  0, 0],
                   [0,  0, -1,  0, 0],
                   [0, -1, +4, -1, 0],
                   [0,  0, -1,  0, 0],
                   [0,  0,  0,  0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / q[0]
        filter2 = np.asarray(filter2, dtype=float) / q[1]
        filter3 = np.asarray(filter3, dtype=float) / q[2]
        
        # custom convolution kernel weights, stacked into shape (out_channels, in_channels, kH, kW) = (3, 3, 5, 5)
        self.filter = torch.tensor([[filter1, filter1, filter1], 
                                    [filter2, filter2, filter2], 
                                    [filter3, filter3, filter3]],
                                    dtype=torch.float32)

    def forward(self, input):

        def truncate(x):
            # element-wise truncation to [-2, 2]; equivalent to torch.clamp(x, -2, 2)
            neg = ((x + 2) + abs(x + 2)) / 2 - 2
            return -(-neg + 2 + abs(- neg + 2)) / 2 + 2

        result = F.conv2d(input,
                          weight=nn.Parameter(self.filter, requires_grad=False), # non-trainable kernel weights
                          stride=(1, 1),
                          # the kernel is 5x5 with stride 1, so padding must be 2
                          # to keep the output the same size as the input
                          padding=(2, 2))

        result = truncate(result)
        return result
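
A quick usage sketch, continuing from the code above (the input sizes are made up for illustration): SRM2D expects a 4D tensor of shape (batch, 3, H, W) and returns a tensor of the same shape with values truncated to [-2, 2].

srm = SRM2D()
x = torch.randn(2, 3, 64, 64)              # (batch, channels, H, W)
out = srm(x)
print(out.shape)                           # torch.Size([2, 3, 64, 64])
print(out.min().item(), out.max().item())  # both within [-2, 2]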

Conv3d

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

class SRM3D(nn.Module):
    def __init__(self):
        super().__init__()

        q = [4.0, 12.0, 4.0]
        filter1 = [[0,  0,  0,  0, 0],
                   [0, -1,  2, -1, 0],
                   [0,  2, -4,  2, 0],
                   [0, -1,  2, -1, 0],
                   [0,  0,  0,  0, 0]]

        filter2 = [[-1, 2, -2,  2,-1],
                   [ 2,-6,  8, -6, 2],
                   [-2, 8, -12, 8,-2],
                   [ 2,-6,  8, -6, 2],
                   [-1, 2, -2,  2,-1]]

        filter3 = [[0,  0,  0,  0, 0],
                   [0,  0, -1,  0, 0],
                   [0, -1, +4, -1, 0],
                   [0,  0, -1,  0, 0],
                   [0,  0,  0,  0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / q[0]
        filter2 = np.asarray(filter2, dtype=float) / q[1]
        filter3 = np.asarray(filter3, dtype=float) / q[2]
        
        # custom convolution kernel weights, stacked into shape (3, 3, 5, 5)
        weight = torch.tensor([[filter1, filter1, filter1], 
                               [filter2, filter2, filter2], 
                               [filter3, filter3, filter3]],
                               dtype=torch.float32)
        
        # conv3d expects a 5D weight of shape (out_channels, in_channels, kD, kH, kW),
        # so insert a depth dimension of size 1, giving (3, 3, 1, 5, 5)
        self.filter = torch.unsqueeze(weight, 2)

    def forward(self, input):

        def truncate(x):
            # element-wise truncation to [-2, 2]; equivalent to torch.clamp(x, -2, 2)
            neg = ((x + 2) + abs(x + 2)) / 2 - 2
            return -(-neg + 2 + abs(- neg + 2)) / 2 + 2

        result = F.conv3d(input,
                          weight=nn.Parameter(self.filter, requires_grad=False), # non-trainable kernel weights
                          stride=(1, 1, 1),
                          # the kernel is 1x5x5 with stride 1, so padding (0, 2, 2)
                          # keeps the output the same size as the input
                          padding=(0, 2, 2))

        result = truncate(result)
        return result
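
A usage sketch for the 3D version (sizes are again illustrative): SRM3D expects a 5D tensor of shape (batch, 3, D, H, W); since the kernel depth is 1 and the depth padding is 0, the output keeps the same shape.

srm3d = SRM3D()
x = torch.randn(2, 3, 8, 64, 64)   # (batch, channels, depth, H, W)
out = srm3d(x)
print(out.shape)                   # torch.Size([2, 3, 8, 64, 64])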
