Self-Attention Mechanism
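
A brief sketch of what the code below computes (the notation here is mine, inferred from the implementation rather than stated in the original post): for an input feature map $x \in \mathbb{R}^{B \times C \times H \times W}$, flatten the spatial positions to $N = HW$ and let $W_q$, $W_k$, $W_v$ denote the three 1×1 convolutions. Then

$$
Q = W_q x,\quad K = W_k x,\quad V = W_v x \;\in\; \mathbb{R}^{B \times C \times N},
$$

$$
A = \operatorname{softmax}\!\left(Q^{\top} K\right) \in \mathbb{R}^{B \times N \times N},
\qquad
\mathrm{out} = \gamma \,\bigl(V A^{\top}\bigr) + x,
$$

where the softmax runs over the last (key) dimension and $\gamma$ is a learnable scalar initialized to zero, so the module starts out as an identity mapping.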

Self-attention module code (PyTorch version):

```python
import torch
from torch import nn


class SelfAttention(nn.Module):
    """Self-attention module over the spatial positions of a feature map."""

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.channel_in = in_dim

        # 1x1 convolutions projecting the input into query, key and value
        # (all three keep the full channel dimension here)
        self.query = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.key = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.value = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

        # learnable residual scale, initialized to zero so the module starts as an identity
        self.gamma = nn.Parameter(torch.zeros(1))

        # softmax over the last dimension (the key positions)
        self.softmax = nn.Softmax(dim=-1)

    def forward_single(self, x):
        """
        inputs:
            x : input feature map (B x C x H x W)
        returns:
            out : gamma * attention-weighted values + the input feature, same shape as x
        """
        m_batchsize, C, height, width = x.size()
        # Q: (B, N, C) and K: (B, C, N), with N = H*W spatial positions
        proj_query = self.query(x).reshape(
            m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key(x).reshape(m_batchsize, -1, width*height)
        # energy[b, i, j] = <query_i, key_j>  ->  (B, N, N)
        energy = proj_query.bmm(proj_key)
        attention = self.softmax(energy)
        proj_value = self.value(x).reshape(m_batchsize, -1, width*height)

        # attention-weighted sum of values: (B, C, N), then back to (B, C, H, W)
        out = proj_value.bmm(attention.permute(0, 2, 1))
        out = out.reshape(m_batchsize, C, height, width)

        # residual connection scaled by the learnable gamma
        out = self.gamma * out + x
        return out

    def forward(self, x):
        if x.ndim == 5:
            # (B, T, C, H, W): fold time into the batch, apply spatial
            # self-attention per frame, then restore the time dimension
            B, T = x.shape[:2]
            x = self.forward_single(x.flatten(0, 1)).unflatten(0, (B, T))
            return x
        else:
            return self.forward_single(x)

```
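
A minimal usage sketch (the channel count and spatial sizes below are example values of my own, not from the original post): the module accepts either a feature map of shape (B, C, H, W) or a frame sequence of shape (B, T, C, H, W); in the 5D case the time dimension is folded into the batch, so attention is still computed per frame.

```python
import torch

# Example usage of the SelfAttention module defined above.
# The channel count and spatial sizes here are assumed for illustration.
attn = SelfAttention(in_dim=64)

# 4D input: a batch of feature maps (B, C, H, W)
feat = torch.randn(2, 64, 16, 16)
print(attn(feat).shape)   # torch.Size([2, 64, 16, 16])

# 5D input: a batch of frame sequences (B, T, C, H, W)
clip = torch.randn(2, 8, 64, 16, 16)
print(attn(clip).shape)   # torch.Size([2, 8, 64, 16, 16])
```

Because gamma starts at zero, the output initially equals the input; how much attention to mix in is learned during training.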
