Self-Attention (PyTorch Implementation)

Source: MEF-GAN: Multi-Exposure Image Fusion via Generative Adversarial Networks
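
The code below implements the attention branch: the under-exposed and over-exposed inputs are concatenated along the channel axis and downsampled by two strided convolutions with a max-pool in between; self-attention is then computed on the reduced grid, with value, query, and key features produced by 1×1 convolutions, and the attended features are upsampled back to the input resolution.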


import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

        # Downsampling stem: two strided 3x3 convolutions with a max-pool
        # in between, so attention is computed on a reduced spatial grid.
        self.conv1 = nn.Conv2d(6, 16, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)

        self.bn = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU()

        # 1x1 convolutions: Cv1 produces the 32-channel value features,
        # cv2 and cv3 produce the 8-channel query and key embeddings.
        self.Cv1 = nn.Conv2d(32, 32, kernel_size=1, stride=1)
        self.cv2 = nn.Conv2d(32, 8, kernel_size=1, stride=1)
        self.cv3 = nn.Conv2d(32, 8, kernel_size=1, stride=1)

    def forward(self, under, over):
        # Concatenate the under- and over-exposed inputs along the channel axis.
        x = torch.cat((under, over), dim=1)                 # (B, 6, H, W)

        output = self.relu(self.bn(self.conv1(x)))
        output = self.maxpool(output)
        output = self.relu(self.bn2(self.conv2(output)))    # (B, 32, h, w)

        b, ch, h, w = output.shape
        n = h * w  # number of positions on the reduced grid

        value = self.Cv1(output).view(b, ch, n)                  # (B, 32, N)
        query = self.cv2(output).view(b, 8, n).permute(0, 2, 1)  # (B, N, 8)
        key = self.cv3(output).view(b, 8, n)                     # (B, 8, N)

        # Attention weights: each row is a softmax over all positions.
        attention = F.softmax(torch.bmm(query, key), dim=-1)     # (B, N, N)

        # Aggregate the value features with the attention weights and
        # restore the spatial layout.
        attention_map = torch.bmm(value, attention.permute(0, 2, 1))  # (B, 32, N)
        attention_map = attention_map.view(b, ch, h, w)

        # Upsample back to the resolution of the inputs.
        attention_map = F.interpolate(
            attention_map, size=[under.shape[2], under.shape[3]])

        return attention_map
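
A minimal smoke test sketch; the batch size, channel count, and 128×128 input size below are illustrative assumptions, not values fixed by the paper:

# Usage sketch: shapes are assumed for illustration only.
model = Attention()
under = torch.randn(2, 3, 128, 128)  # under-exposed image batch (assumed shape)
over = torch.randn(2, 3, 128, 128)   # over-exposed image batch (assumed shape)
out = model(under, over)
print(out.shape)  # torch.Size([2, 32, 128, 128])

With 128×128 inputs, the stem reduces the grid to 15×15 (N = 225 positions), so the attention matrix is 225×225 per sample before the result is upsampled back to 128×128.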
