Fourier卷积

Fast Fourier Convolution (FFC) for Image Classification

https://github.com/pkumivision/FFC

设计模块

将频谱的实部和虚部在通道维上拼接(cat)在一起,再做 1×1 卷积

Fourier卷积_第1张图片

 

 代码实现

class FourierUnit(nn.Module):
    """Pointwise convolution applied in the 2-D Fourier domain.

    The input is transformed with a real 2-D FFT, the real and imaginary
    parts are stacked along the channel axis, a 1x1 convolution followed by
    BatchNorm and ReLU mixes them, and the result is transformed back with
    the inverse real FFT to the input's spatial size.

    Args:
        in_channels: Number of channels of the spatial input.
        out_channels: Number of channels of the spatial output.
        groups: ``groups`` argument forwarded to the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, groups=1):
        super(FourierUnit, self).__init__()
        self.groups = groups
        # Channel counts are doubled because each complex channel is fed to
        # the convolution as two real channels (real part, imaginary part).
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: Real tensor of shape ``(batch, in_channels, h, w)``.

        Returns:
            Real tensor of shape ``(batch, out_channels, h, w)``.
        """
        batch, _, height, width = x.size()

        # torch.rfft / torch.irfft were removed in PyTorch 1.8; the
        # torch.fft module with norm="ortho" is the equivalent of the old
        # ``normalized=True`` one-sided transform.
        # complex (batch, c, h, w/2+1) -> real (batch, c, h, w/2+1, 2)
        ffted = torch.view_as_real(torch.fft.rfft2(x, norm="ortho"))
        # (batch, c, 2, h, w/2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        # Fold the real/imag axis into channels: (batch, c*2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        ffted = self.conv_layer(ffted)  # (batch, out_c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        # Unfold channels back into (real, imag) pairs as the last axis:
        # (batch, out_c, h, w/2+1, 2). view_as_complex needs that last axis
        # to be contiguous, hence the explicit .contiguous().
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()

        # Passing the original spatial size restores odd widths correctly,
        # which the one-sided spectrum alone cannot disambiguate.
        output = torch.fft.irfft2(torch.view_as_complex(ffted),
                                  s=(height, width), norm="ortho")

        return output

具体使用

Fourier卷积_第2张图片

 代码操作

class SpectralTransform(nn.Module):
    """Spectral branch built from a global and an optional local Fourier unit.

    The input is optionally average-pooled (when ``stride == 2``), reduced to
    ``out_channels // 2`` channels by a 1x1 conv + BatchNorm + ReLU, passed
    through a FourierUnit (``fu``), optionally combined with a second
    FourierUnit applied to spatial tiles (``lfu``), and finally projected to
    ``out_channels`` by a 1x1 convolution.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: When equal to 2 the input is downsampled by 2x2 average
            pooling; any other value leaves the input resolution unchanged.
        groups: ``groups`` argument forwarded to every convolution.
        enable_lfu: Whether to build and use the local Fourier unit.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True):
        super().__init__()
        hidden = out_channels // 2

        self.enable_lfu = enable_lfu
        self.downsample = (
            nn.AvgPool2d(kernel_size=(2, 2), stride=2)
            if stride == 2
            else nn.Identity()
        )
        self.stride = stride

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(hidden, hidden, groups)
        if self.enable_lfu:
            self.lfu = FourierUnit(hidden, hidden, groups)
        self.conv2 = torch.nn.Conv2d(
            hidden, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        """Run the spectral branch on ``x`` and return the projected result."""
        reduced = self.conv1(self.downsample(x))  # (b, out_channels // 2, h, w)
        global_feat = self.fu(reduced)            # same shape as `reduced`

        if not self.enable_lfu:
            local_feat = 0
        else:
            _, channels, height, width = reduced.shape
            tiles_per_side = 2
            # NOTE(review): the tiling below appears to rely on h and w being
            # even and the channel count being divisible by 4 so that the
            # stacked tiles match what `lfu` expects -- confirm with callers.
            quarter = reduced[:, :channels // 4]
            # Cut the first quarter of the channels into row strips and stack
            # the strips along the channel axis ...
            row_strips = torch.split(quarter, height // tiles_per_side, dim=-2)
            stacked = torch.cat(row_strips, dim=1).contiguous()
            # ... then do the same along the width.
            col_strips = torch.split(stacked, width // tiles_per_side, dim=-1)
            stacked = torch.cat(col_strips, dim=1).contiguous()
            local_feat = self.lfu(stacked)
            # Tile the result 2x2 spatially before summing with the others.
            local_feat = local_feat.repeat(
                1, 1, tiles_per_side, tiles_per_side).contiguous()

        return self.conv2(reduced + global_feat + local_feat)

你可能感兴趣的:(Fourier/phase信息,深度学习)