Transformer (ViT) building blocks in PyTorch: patch embedding and multi-head self-attention

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange


class PatchEmbedding(nn.Module):
    def __init__(self, inchannel, patch_size=16, emb_size=768):
        super(PatchEmbedding, self).__init__()
        self.projection = nn.Sequential(
            # cut the image into non-overlapping patch_size x patch_size patches and embed each into emb_size dims
            nn.Conv2d(inchannel, emb_size, kernel_size=patch_size, stride=patch_size),
            # (b, emb_size, num_h, num_w) -> (b, num_h * num_w, emb_size)
            Rearrange('b e (num_h) (num_w) -> b (num_h num_w) e')
        )

    def forward(self, x):
        x = self.projection(x)
        return x
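
For a quick sanity check, here is a minimal usage sketch (the 224x224 input size and the dummy tensor are illustrative assumptions, not part of the original post): a 3-channel 224x224 image cut into 16x16 patches gives a 14x14 grid, i.e. a sequence of 196 patch tokens, each embedded into 768 dimensions.

# hypothetical sanity check: shapes only, values are random (assumed sizes, not from the original post)
img = torch.randn(1, 3, 224, 224)            # one 3-channel 224x224 image
patch_embed = PatchEmbedding(inchannel=3)    # 16x16 patches, 768-dim embeddings (defaults above)
tokens = patch_embed(img)
print(tokens.shape)                          # torch.Size([1, 196, 768]) -> 14 * 14 = 196 patches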

class MHA(nn.Module):
    def __init__(self, patchsize, inchannel, emb_size, num_heads, dropout=0):
        super(MHA, self).__init__()
        self.patchsize = patchsize
        self.inchannel = inchannel
        self.emb_size = emb_size
        self.num_heads = num_heads
        # one linear layer produces queries, keys and values in a single pass
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x, mask=None):
        # (b, n, 3 * emb_size) -> three tensors of shape (b, num_heads, n, head_dim)
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d",
                        h=self.num_heads, qkv=3)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # attention scores: (batch, num_heads, query_len, key_len)
        energy = torch.einsum('bhqd,bhkd->bhqk', q, k)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(mask == 0, fill_value)
        # scaled dot-product attention: scale by sqrt(head_dim) before the softmax
        scaling = (self.emb_size // self.num_heads) ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values over the key axis
        out = torch.einsum('bhqk,bhkd->bhqd', att, v)
        # merge the heads back: (b, h, n, d) -> (b, n, h * d)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
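
To see the two modules together, a similarly hedged sketch (the image size, num_heads=8, and the patchsize/inchannel arguments passed to MHA are assumptions for illustration): multi-head self-attention maps the (batch, num_tokens, emb_size) sequence to an output of the same shape, which is what lets it be stacked repeatedly inside a Transformer encoder.

# hypothetical end-to-end shape check (illustrative values, not from the original post)
tokens = PatchEmbedding(inchannel=3)(torch.randn(1, 3, 224, 224))   # (1, 196, 768)
mha = MHA(patchsize=16, inchannel=3, emb_size=768, num_heads=8)     # 8 heads, head_dim = 96
out = mha(tokens)
print(out.shape)                             # torch.Size([1, 196, 768]) -- sequence shape preserved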

