import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import deform_conv2d  # torch.nn.functional has no deform_conv2d
import numpy as np
import math
import einops
from PIL import Image
import matplotlib.pyplot as plt

def to_2tuple(x):
    if isinstance(x, tuple):
        return x
    return (x, x)

def to_4tuple(x):
    # Repeat a scalar into a 4-tuple, or duplicate a 2-tuple; pass 4-tuples through
    # unchanged (the original `x + x` would double a 4-tuple to length 8).
    if isinstance(x, tuple):
        return x if len(x) == 4 else x + x
    return (x, x, x, x)

def get_padding(kernel_size, dilation):
    kernel_size, dilation = to_2tuple(kernel_size), to_2tuple(dilation)
    padding = tuple(((k - 1) * d) // 2 for k, d in zip(kernel_size, dilation))
    return padding
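
# e.g. get_padding(3, 1) == (1, 1): "same"-style padding for a stride-1 3x3 conv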

def get_stride(kernel_size):
    # Stride equal to the kernel size, i.e. non-overlapping windows.
    kernel_size = to_2tuple(kernel_size)
    stride = kernel_size
    return stride

def get_dilation(kernel_size):
    # Heuristic: dilation of half the kernel extent per axis.
    kernel_size = to_2tuple(kernel_size)
    dilation = tuple(k // 2 for k in kernel_size)
    return dilation

def get_norm(norm_type, out_channels):
    if norm_type == 'bn':
        return nn.BatchNorm2d(out_channels)
    elif norm_type == 'gn':
        return nn.GroupNorm(32, out_channels)
    # '' (or any unrecognized type) means no normalization.
    return None

class ModulatedDeformConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        deformable_groups=1,
        groups=1,
        bias=True,
        norm='',
        activation=F.relu,
        use_deform=True,
        use_sk=False,
        use_se=False,
        use_mix=False,
        attn_dim=64
    ):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Conv2d and deform_conv2d expect 2-tuples for these arguments.
        self.kernel_size = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.padding = to_2tuple(padding)
        self.dilation = to_2tuple(dilation)
        self.deformable_groups = deformable_groups
        self.groups = groups
        self.use_deform = use_deform
        self.use_sk = use_sk
        self.use_se = use_se
        self.use_mix = use_mix
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None
        self.norm = get_norm(norm, out_channels)
        self.activation = activation
        # Offsets: two (y, x) values per kernel sample point, per deformable group.
        # Despite the class name, no modulation mask is predicted here.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * int(np.prod(self.kernel_size)),
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True
        )
        if self.use_sk:
            self.conv_sk = nn.Conv2d(
                self.in_channels,
                self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False
            )
            self.bn_sk = get_norm(norm, self.in_channels)
            self.activation_sk = nn.ReLU(inplace=True)
        if self.use_se:
            # Squeeze-and-excitation bottleneck; assumes in_channels == out_channels,
            # since it is applied to the conv output in forward().
            self.linear1 = nn.Conv2d(in_channels, attn_dim, kernel_size=1, stride=1, padding=0)
            self.activation_se1 = nn.ReLU(inplace=True)
            self.linear2 = nn.Conv2d(attn_dim, self.in_channels, kernel_size=1, stride=1, padding=0)
            self.activation_se2 = nn.Sigmoid()
        if self.use_mix:
            # Channel-mixing linear layer, applied over the channel axis in forward().
            self.linear_mix = nn.Linear(self.in_channels, self.in_channels, bias=True)
            nn.init.normal_(self.linear_mix.weight, std=0.01)
            nn.init.constant_(self.linear_mix.bias, 0)
        self.reset_parameters()

    def reset_parameters(self):
        # Mirror nn.Conv2d's default Kaiming-uniform initialization.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        # Start offsets at zero so the layer behaves like a regular conv at
        # initialization (standard practice for deformable convolutions).
        nn.init.constant_(self.conv_offset.weight, 0)
        nn.init.constant_(self.conv_offset.bias, 0)

    def forward(self, x):
        # Predicted sampling offsets: (B, deformable_groups * 2 * kH * kW, H, W).
        offset = self.conv_offset(x)
        # Mean absolute offset, returned alongside the output as a diagnostic of
        # how far the kernel sampling points deform.
        offset_mean = offset.abs().mean()
        if self.use_deform:
            # torchvision's deform_conv2d consumes the raw offset map; channel and
            # offset groups are inferred from the weight and offset shapes.
            x = deform_conv2d(
                x,
                offset,
                self.weight,
                self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
            )
        else:
            x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        if self.use_mix:
            # nn.Linear mixes the channel dimension, so move channels last and back.
            x = self.linear_mix(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.use_se:
            # Squeeze-and-excitation: global average pool -> bottleneck -> sigmoid gate.
            attn = F.adaptive_avg_pool2d(x, 1)
            attn = self.activation_se2(self.linear2(self.activation_se1(self.linear1(attn))))
            x = x * attn
        if self.use_sk:
            # Residual 1x1 refinement branch.
            residual = x
            x = self.conv_sk(x)
            x = self.bn_sk(x)
            x = self.activation_sk(x)
            x = x + residual
        return x, offset_mean
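
# Usage sketch (shapes illustrative; assumes a 3x3 deformable conv with BN):
#   conv = ModulatedDeformConv(64, 64, kernel_size=3, padding=1, norm='bn')
#   y, offset_mean = conv(torch.randn(2, 64, 32, 32))  # y: (2, 64, 32, 32)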

class DeformableAttention(nn.Module):
    def __init__(self, dim_in, dim_out, dim_inner, n_head, window_size, stride=1):
        super().__init__()
        self.dim = dim_out
        self.n_head = n_head
        self.head_dim = self.dim // self.n_head
        self.window_size = window_size
        self.stride = stride
        self.in_conv = nn.Conv2d(dim_in, dim_inner, 1, bias=False)
        self.qkv_conv = nn.Conv2d(dim_inner, self.dim * 3, 1, bias=False)
        self.attn_drop = nn.Dropout(0.1)
        self.norm = nn.BatchNorm2d(dim_in)
        self.proj = nn.Conv2d(self.dim, dim_out, 1)
        self.proj_drop = nn.Dropout(0.1)
        # LocalAttention is assumed to be defined elsewhere in this repo.
        self.local_attn = LocalAttention(self.dim, n_head, window_size, 0.1, 0.1)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.norm(x)
        x = self.in_conv(x)
        qkv = self.qkv_conv(x)
        qkv = qkv.reshape(B, 3, self.n_head, self.head_dim, H, W)
        qkv = qkv.permute(1, 0, 2, 3, 4, 5)  # (3, B, n_head, head_dim, H, W)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q, k, v = map(lambda t: einops.rearrange(t, 'b hh c h w -> (b hh) (h w) c'), (q, k, v))
        # Boolean mask of allowed positions: each query attends only to keys within
        # a window_size neighborhood. (Loop variables must not shadow the key tensor k.)
        attn_mask = torch.zeros((H * W, H * W), device=x.device, dtype=torch.bool)
        for i in range(H * W):
            qh, qw = divmod(i, W)
            for nh in range(max(0, qh - self.window_size), min(H, qh + self.window_size + 1)):
                for nw in range(max(0, qw - self.window_size), min(W, qw + self.window_size + 1)):
                    attn_mask[i, nh * W + nw] = True
        # Broadcast view over batch * heads; no copy needed.
        attn_mask = attn_mask.view(1, H * W, H * W).expand(B * self.n_head, -1, -1)
        q_dot_k = torch.einsum('bxd,byd->bxy', q, k)
        q_dot_k *= (self.head_dim ** -0.5)
        # Disallowed positions (mask False) are filled with -inf before the softmax.
        q_dot_k.masked_fill_(~attn_mask, float('-inf'))
        attn = q_dot_k.softmax(dim=-1)
        attn = self.attn_drop(attn)
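        # Assumed completion (the standard attention epilogue): weighted sum over
        # the values, merge heads back into channels, then the output projection.
        out = torch.einsum('bxy,byd->bxd', attn, v)
        out = einops.rearrange(out, '(b hh) (h w) c -> b (hh c) h w', b=B, hh=self.n_head, h=H, w=W)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

# Usage sketch (illustrative; assumes LocalAttention is importable):
#   attn = DeformableAttention(64, 64, 64, n_head=4, window_size=3)
#   y = attn(torch.randn(2, 64, 16, 16))  # y: (2, 64, 16, 16)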