Deep Learning Notes--Understanding and Implementing the Source Code of Position Encoding in Transformer

1--Source code

import torch
import math
import numpy as np
import torch.nn as nn

class Pos_Embed(nn.Module):
    def __init__(self, channels, num_frames, num_joints):
        super().__init__()
        
        # Build one position index per (frame, joint) pair; only the joint index st is stored
        pos_list = [] 
        for tk in range(num_frames):
            for st in range(num_joints):
                pos_list.append(st)

        position = torch.from_numpy(np.array(pos_list)).unsqueeze(1).float()  # num_frames*num_joints, 1

        pe = torch.zeros(num_frames * num_joints, channels)  # T*N, C

        div_term = torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels))

        pe[:, 0::2] = torch.sin(position * div_term)  # even columns: sin on the even channel indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd columns: cos on the odd channel indices
        pe = pe.view(num_frames, num_joints, channels).permute(2, 0, 1).unsqueeze(0)  # T N C -> C T N -> 1 C T N
        self.register_buffer('pe', pe)

    def forward(self, x):  # nctv # BCTN
        x = self.pe[:, :, :x.size(2)]
        return x

if __name__ == "__main__":
    B = 2
    C = 4
    T = 120
    N = 25
    x = torch.rand((B, C, T, N))

    Pos_embed_1 = Pos_Embed(C, T, N)
    PE = Pos_embed_1(x)
    # print(PE.shape) # 1 C T N
    x = x + PE

    print("All Done !")
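
One detail of the constructor is easy to miss: the nested loop appends only the joint index st, so the position index repeats for every frame. The snippet below is a minimal sketch of that loop in plain Python, using small hypothetical sizes (3 frames, 4 joints) chosen only for illustration.

```python
# Hypothetical small sizes, chosen only to make the pattern visible
num_frames, num_joints = 3, 4

# Same nested loop as in Pos_Embed.__init__: only the joint index st is stored
pos_list = []
for tk in range(num_frames):
    for st in range(num_joints):
        pos_list.append(st)

print(pos_list)  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
```

Because the index depends only on the joint, the resulting encoding is identical across frames: it encodes the joint order, not the frame order.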

2--Source code analysis and explanation

Background: Positional Encoding

Code explanation:

[Figure 1: the sinusoidal positional encoding formula, highlighted in a green box]

① Code: div_term = torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels)):

Let channels = C and torch.arange(0, channels, 2).float() = k (so k = 0, 2, ..., C-2);

-(math.log(10000.0) / channels) $= \frac{-\log_{e}10000}{C}$

Then: torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels) $= \frac{-k\log_{e}10000}{C}$

torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels)) $= e^{\frac{-k\log_{e}10000}{C}} = e^{\log_{e}10000^{-\frac{k}{C}}} = 10000^{-\frac{k}{C}} = \frac{1}{10000^{\frac{k}{C}}}$;
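
The identity exp(-k ln 10000 / C) = 1 / 10000^(k/C) can be checked numerically. The following is a minimal sketch using only the standard library in place of torch; the value C = 8 is a hypothetical channel count chosen for illustration.

```python
import math

C = 8  # hypothetical channel count, for illustration only

# k = 0, 2, ..., C-2, mirroring torch.arange(0, channels, 2)
ks = list(range(0, C, 2))

# Same arithmetic as the torch expression for div_term
div_term = [math.exp(k * -(math.log(10000.0) / C)) for k in ks]

# Closed form: 1 / 10000^(k/C)
closed_form = [1.0 / (10000.0 ** (k / C)) for k in ks]

for k, d, c in zip(ks, div_term, closed_form):
    assert math.isclose(d, c, rel_tol=1e-9)
    print(k, d, c)
```

With C = 8 the four values come out at approximately 1, 0.1, 0.01 and 0.001, so each successive channel pair uses a frequency ten times lower than the previous one.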

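② Code: pe[:, 0::2] = torch.sin(position * div_term) and pe[:, 1::2] = torch.cos(position * div_term):

Let position = p; then position * div_term $= p \cdot \frac{1}{10000^{\frac{k}{C}}} = \frac{p}{10000^{\frac{k}{C}}}$;

Since k is equivalent to 2i, and pe[:, 0::2] and pe[:, 1::2] select the even and odd columns respectively, this gives the formula shown in the green box in the figure above: $PE(p, 2i) = \sin\left(\frac{p}{10000^{2i/C}}\right)$ and $PE(p, 2i+1) = \cos\left(\frac{p}{10000^{2i/C}}\right)$.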
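
To make the column layout concrete, here is a minimal sketch that builds a few rows of the encoding directly from the closed-form formula, without torch. The helper name sinusoid_row and the sizes are hypothetical, and C is assumed to be even.

```python
import math

def sinusoid_row(p, C):
    """Encoding of a single position p over C channels (C assumed even)."""
    row = [0.0] * C
    for k in range(0, C, 2):
        angle = p / (10000.0 ** (k / C))
        row[k] = math.sin(angle)      # even column, index k = 2i
        row[k + 1] = math.cos(angle)  # odd column, index 2i + 1
    return row

C = 4  # hypothetical channel count
table = [sinusoid_row(p, C) for p in range(3)]

for p, row in enumerate(table):
    print(p, row)
```

Position 0 always maps to alternating 0 and 1, since sin(0) = 0 and cos(0) = 1; the later rows differ from it mostly in the first channel pair, where the frequency is highest.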

