Modulated Deformable Convolution and Deformable Attention in PyTorch

This post collects two PyTorch building blocks: a deformable convolution layer with optional SE, SK, and channel-mix branches, and a window-masked multi-head attention layer.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import deform_conv2d  # deformable conv lives in torchvision.ops
import numpy as np
import math
import einops
from PIL import Image
import matplotlib.pyplot as plt

def to_2tuple(x):
    # Promote a scalar to a 2-tuple (per-spatial-dim conv arguments).
    if isinstance(x, tuple):
        return x
    return (x, x)

def to_4tuple(x):
    # Promote a scalar to a 4-tuple; a 2-tuple (a, b) becomes (a, b, a, b),
    # matching F.pad's (left, right, top, bottom) convention.
    if isinstance(x, tuple):
        return x + x
    return (x, x, x, x)

def get_padding(kernel_size, dilation):
    # "Same"-style padding for odd kernels: ((k - 1) * d) // 2 per spatial dim.
    kernel_size, dilation = to_2tuple(kernel_size), to_2tuple(dilation)
    padding = tuple(((k - 1) * d) // 2 for k, d in zip(kernel_size, dilation))
    return padding

def get_stride(kernel_size):
    # Patchify-style stride: step by one full kernel, so windows do not overlap.
    stride = to_2tuple(kernel_size)
    return stride

def get_dilation(kernel_size):
    # Heuristic: dilate by half the kernel size per spatial dim.
    kernel_size = to_2tuple(kernel_size)
    dilation = tuple(k // 2 for k in kernel_size)
    return dilation

def get_norm(norm_type, out_channels):
    # '' (or any unknown key) means "no normalization": return None explicitly.
    if norm_type == 'bn':
        return nn.BatchNorm2d(out_channels)
    elif norm_type == 'gn':
        return nn.GroupNorm(32, out_channels)
    return None
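
A quick sanity check of what the helpers produce for a 3x3 kernel:

# Helper sanity check for a 3x3 kernel.
print(to_2tuple(3))        # (3, 3)
print(get_padding(3, 1))   # (1, 1): "same" padding at stride 1
print(get_stride(3))       # (3, 3): non-overlapping, patchify-style windows
print(get_dilation(3))     # (1, 1)
print(get_norm('gn', 64))  # GroupNorm(32, 64, eps=1e-05, affine=True)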

class ModulatedDeformConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        deformable_groups=1,
        groups=1,
        bias=True,
        norm='',
        activation=F.relu,
        use_deform=True,
        use_sk=False,
        use_se=False,
        use_mix=False,
        attn_dim=64
    ):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Conv2d geometry is 2-D, so every argument is promoted to a 2-tuple.
        self.kernel_size = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.padding = to_2tuple(padding)
        self.dilation = to_2tuple(dilation)
        self.deformable_groups = deformable_groups
        self.groups = groups
        self.use_deform = use_deform
        self.use_sk = use_sk
        self.use_se = use_se
        self.use_mix = use_mix

        # Weight laid out like nn.Conv2d: (out_channels, in_channels // groups, kh, kw).
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        self.norm = get_norm(norm, out_channels)
        self.activation = activation

        # Predicts 2 offsets (dy, dx) per deformable group per kernel sampling point.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True
        )

        if self.use_sk:
            # Residual 1x1 refinement branch applied to the conv output.
            self.conv_sk = nn.Conv2d(
                self.out_channels,
                self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False
            )
            self.bn_sk = get_norm(norm, self.out_channels)
            self.activation_sk = nn.ReLU(inplace=True)

        if self.use_se:
            # Squeeze-and-excitation head over the conv output channels.
            self.linear1 = nn.Conv2d(self.out_channels, attn_dim, kernel_size=1, stride=1, padding=0)
            self.activation_se1 = nn.ReLU(inplace=True)
            self.linear2 = nn.Conv2d(attn_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
            self.activation_se2 = nn.Sigmoid()

        if self.use_mix:
            # 1x1 conv mixes channels; an nn.Linear on an NCHW tensor would act
            # on the width dimension instead of the channels.
            self.linear_mix = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=True)
            nn.init.normal_(self.linear_mix.weight, std=0.01)
            nn.init.constant_(self.linear_mix.bias, 0)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        # Zero-init the offset predictor so the layer starts as a plain conv
        # (the customary deformable-conv initialization).
        nn.init.constant_(self.conv_offset.weight, 0.)
        nn.init.constant_(self.conv_offset.bias, 0.)

    def forward(self, x):
        # Offset map: (B, deformable_groups * 2 * kh * kw, H_out, W_out), laid out
        # as (dy, dx) pairs per sampling point -- exactly the layout that
        # torchvision.ops.deform_conv2d expects.
        offset = self.conv_offset(x)
        # Diagnostic: mean absolute offset magnitude, handy for monitoring training.
        offset_mean = offset.abs().mean()
        if self.use_deform:
            # No modulation mask is passed, so despite the class name this is
            # plain (v1-style) deformable convolution.
            x = deform_conv2d(
                x,
                offset,
                self.weight,
                self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
            )
        else:
            x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        if self.use_mix:
            x = self.linear_mix(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)

        if self.use_se:
            # Squeeze (global average pool) -> excite (bottleneck 1x1 convs) -> rescale.
            attn = self.activation_se2(self.linear2(self.activation_se1(self.linear1(F.adaptive_avg_pool2d(x, 1)))))
            x = x * attn

        if self.use_sk:
            residual = x
            x = self.conv_sk(x)
            if self.bn_sk is not None:
                x = self.bn_sk(x)
            x = self.activation_sk(x)
            x = x + residual
        
        return x, offset_mean
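
A minimal smoke test of the layer (shape check only; the flag combination below is an arbitrary example):

# Build a 16 -> 32 channel deformable conv with GroupNorm and an SE branch.
layer = ModulatedDeformConv(
    in_channels=16, out_channels=32, kernel_size=3,
    stride=1, padding=get_padding(3, 1), dilation=1,
    norm='gn', use_se=True,
)
inp = torch.randn(2, 16, 24, 24)
out, offset_mean = layer(inp)
print(out.shape)           # torch.Size([2, 32, 24, 24])
print(offset_mean.item())  # 0.0 at init, since conv_offset starts at zero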


class DeformableAttention(nn.Module):

    def __init__(self, dim_in, dim_out, dim_inner, n_head, window_size, stride=1):
        super().__init__()

        self.dim = dim_out
        self.n_head = n_head
        self.head_dim = self.dim // self.n_head
        self.window_size = window_size
        self.stride = stride

        self.in_conv = nn.Conv2d(dim_in, dim_inner, 1, bias=False)
        self.qkv_conv = nn.Conv2d(dim_inner, self.dim * 3, 1, bias=False)
        self.attn_drop = nn.Dropout(0.1)

        self.norm = nn.BatchNorm2d(dim_in)
        self.proj = nn.Conv2d(self.dim, dim_out, 1)
        self.proj_drop = nn.Dropout(0.1)


    def forward(self, x):
        B, C, H, W = x.shape

        x = self.norm(x)
        x = self.in_conv(x)

        # Project to q, k, v and split heads: (3, B, n_head, head_dim, H, W).
        qkv = self.qkv_conv(x)
        qkv = qkv.reshape(B, 3, self.n_head, self.head_dim, H, W)
        qkv = qkv.permute(1, 0, 2, 3, 4, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Fold heads into the batch and flatten the spatial grid: (B*n_head, H*W, head_dim).
        q, k, v = map(lambda t: einops.rearrange(t, 'b hh c h w -> (b hh) (h w) c'), (q, k, v))

        # Binary local-attention mask: query i may attend to key j iff j lies in a
        # (2 * window_size + 1)^2 neighborhood of i on the H x W grid.
        # Naive O((H*W)^2) construction; fine for small feature maps.
        attn_mask = torch.zeros((H * W, H * W), dtype=torch.bool, device=x.device)
        for i in range(H * W):
            h, w = i // W, i % W
            for hh in range(max(0, h - self.window_size), min(H, h + self.window_size + 1)):
                for ww in range(max(0, w - self.window_size), min(W, w + self.window_size + 1)):
                    attn_mask[i, hh * W + ww] = True

        attn_mask = attn_mask.view(1, H * W, H * W).repeat(B * self.n_head, 1, 1)

        # Scaled dot-product scores; positions outside the window get -inf.
        q_dot_k = torch.einsum('bxd,byd->bxy', q, k)
        q_dot_k = q_dot_k * (self.head_dim ** -0.5)
        q_dot_k = q_dot_k.masked_fill(~attn_mask, float('-inf'))

        attn = q_dot_k.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Standard attention aggregation (reconstructed ending): weight the values,
        # restore the (B, C, H, W) layout, and project to the output dimension.
        out = torch.einsum('bxy,byd->bxd', attn, v)
        out = einops.rearrange(out, '(b hh) (h w) c -> b (hh c) h w', b=B, hh=self.n_head, h=H, w=W)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

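And a quick shape check for the attention block (keep H*W small, since the mask construction loops over all position pairs). Note that as written the block performs window-masked multi-head attention; there is no deformable sampling despite the name:

# 32 -> 64 channels, 4 heads, 5x5 local window (window_size=2).
attn_layer = DeformableAttention(dim_in=32, dim_out=64, dim_inner=32, n_head=4, window_size=2)
feat = torch.randn(2, 32, 8, 8)
print(attn_layer(feat).shape)  # torch.Size([2, 64, 8, 8])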