https://github.com/pkumivision/FFC
The real and imaginary parts of the spectrum are concatenated (cat) along the channel dimension and then mixed with a 1×1 convolution.
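A quick shape check (my own sketch, using the modern torch.fft API rather than the legacy torch.rfft used in the code below) shows why the 1×1 convolution inside FourierUnit sees twice the channels:

import torch

x = torch.randn(2, 64, 8, 8)                                   # real feature map [b, c, h, w]
spec = torch.fft.rfft2(x, norm="ortho")                        # complex spectrum [2, 64, 8, 5]
spec_ri = torch.stack((spec.real, spec.imag), dim=-1)          # [2, 64, 8, 5, 2]
spec_ri = spec_ri.permute(0, 1, 4, 2, 3).reshape(2, -1, 8, 5)  # [2, 128, 8, 5]
print(spec_ri.shape)                                           # torch.Size([2, 128, 8, 5])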
Code implementation
import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups
        # The spectrum's real and imaginary parts are folded into the channel
        # dimension, hence the 1x1 convolution uses doubled in/out channels.
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        batch, c, h, w = x.size()
        r_size = x.size()
        # Legacy real FFT API (removed in PyTorch 1.8+): (batch, c, h, w/2+1, 2)
        ffted = torch.rfft(x, signal_ndim=2, normalized=True)
        # (batch, c, 2, h, w/2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        # Fold real/imag into channels: (batch, c*2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])
        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))
        # Unfold channels back into real/imag pairs: (batch, c, h, w/2+1, 2)
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()
        # Inverse real FFT back to the spatial domain: (batch, c, h, w)
        output = torch.irfft(ffted, signal_ndim=2,
                             signal_sizes=r_size[2:], normalized=True)
        return output
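Because torch.rfft and torch.irfft were removed in PyTorch 1.8, running this module on a current PyTorch requires porting it to the torch.fft API. The following is a sketch of such a port, not the repository's code; the class name FourierUnitModern is my own, and norm="ortho" is chosen to mirror the legacy normalized=True behaviour:

import torch
import torch.nn as nn


class FourierUnitModern(nn.Module):
    """Same idea as FourierUnit, written against torch.fft (PyTorch >= 1.8)."""

    def __init__(self, in_channels, out_channels, groups=1):
        super().__init__()
        self.conv_layer = nn.Conv2d(in_channels * 2, out_channels * 2,
                                    kernel_size=1, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        batch, c, h, w = x.shape
        # Complex spectrum: (batch, c, h, w//2 + 1)
        ffted = torch.fft.rfft2(x, norm="ortho")
        # Stack real/imag and fold them into channels: (batch, c*2, h, w//2 + 1)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).reshape(batch, -1, h, w // 2 + 1)
        ffted = self.relu(self.bn(self.conv_layer(ffted)))
        # Unfold channels back into a complex spectrum
        ffted = ffted.reshape(batch, -1, 2, h, w // 2 + 1).permute(0, 1, 3, 4, 2)
        ffted = torch.complex(ffted[..., 0].contiguous(), ffted[..., 1].contiguous())
        # Inverse real FFT back to (batch, c, h, w)
        return torch.fft.irfft2(ffted, s=(h, w), norm="ortho")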
Code walkthrough
class SpectralTransform(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()
        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2,
                      kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True)
        )
        # Global Fourier unit on the full map, optional local Fourier unit (LFU) on patches
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups)
        if self.enable_lfu:
            self.lfu = FourierUnit(
                out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        # Shape comments assume in_channels=512, out_channels=128 and an 8x8 input.
        x = self.downsample(x)   # [b, c, h, w]: [2, 512, 8, 8]
        x = self.conv1(x)        # [b, out_channels/2, h, w]: [b, 64, 8, 8]
        output = self.fu(x)      # [b, out_channels/2, h, w]: [b, 64, 8, 8]
        if self.enable_lfu:
            n, c, h, w = x.shape
            split_no = 2
            split_s_h = h // split_no
            split_s_w = w // split_no
            # Take the first c/4 channels and stack the 2x2 spatial quarters
            # along the channel dimension, so the LFU operates on local patches.
            xs = torch.cat(torch.split(
                x[:, :c // 4], split_s_h, dim=-2), dim=1).contiguous()  # [b, 32, 4, 8]
            xs = torch.cat(torch.split(xs, split_s_w, dim=-1),
                           dim=1).contiguous()                           # [b, 64, 4, 4]
            xs = self.lfu(xs)                                            # [b, 64, 4, 4]
            # Tile the local result back to the full spatial size.
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()        # [b, 64, 8, 8]
        else:
            xs = 0
        # Fuse the input, the global spectral branch and the local branch.
        output = self.conv2(x + output + xs)
        return output
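A quick smoke test (my own usage example, not part of the repository; it needs a PyTorch version that still provides torch.rfft, i.e. < 1.8) confirms the shapes annotated above:

import torch

# Configuration chosen to match the shape comments: 512 -> 128 channels, 8x8 map.
st = SpectralTransform(in_channels=512, out_channels=128, stride=1, enable_lfu=True)
st.eval()

x = torch.randn(2, 512, 8, 8)
with torch.no_grad():
    y = st(x)
print(y.shape)  # torch.Size([2, 128, 8, 8])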