使用Transformer编码器进行序列数据分类(上)

使用Transformer编码器进行序列数据分类(上)

搭建 Encoder 类（Encoder Class）

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---- data shape --------------------------------------------------------
batch_size = 3     # number of trajectories per batch (dim 0 of the input)
len_traj = 13      # time steps per trajectory (dim 1 of the input)
d_obs = 6          # features per time step (dim 2 of the input)

# ---- model hyper-parameters -------------------------------------------
d_embed = 7        # width of the embedding / every encoder layer
n_heads = 8        # attention heads per encoder layer
d_k = 16           # per-head size of the query/key/value projections
d_hidden = 16      # hidden channels of the position-wise feed-forward net
d_class = 2        # number of output classes
n_layers = 4       # number of EncoderLayer blocks stacked inside Encoder

# Example input batch: [batch_size, len_traj, d_obs], float32 values in [0, 1).
trajectory = torch.rand(batch_size, len_traj, d_obs)

class Embedding(nn.Module):
    '''Project every time step of a trajectory into the embedding space.

    A single Linear layer acting on the last dimension:
    [..., inpt_dim] -> [..., embed_dim].

    Args:
        inpt_dim: number of features per time step.
        embed_dim: size of the embedding produced for each time step.
    '''

    def __init__(self, inpt_dim, embed_dim):
        super().__init__()
        self.fc = nn.Linear(inpt_dim, embed_dim)

    def forward(self, x):
        '''Return ``fc(x)``; only the last dimension of ``x`` changes size.'''
        return self.fc(x)

class ScaledDotProductAttention(nn.Module):
    '''Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        d_k: size of the last dimension of Q and K; used for the
            1/sqrt(d_k) scaling that keeps the softmax inputs well-conditioned.
    '''

    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
        '''Attend from the queries Q over the keys K and mix the values V.

        Args:
            Q: [..., len_q, d_k] queries (e.g. [batch, n_heads, len_q, d_k]).
            K: [..., len_k, d_k] keys.
            V: [..., len_k, d_v] values.
            mask: optional boolean tensor broadcastable to [..., len_q, len_k].
                ``True`` marks key positions a query must NOT attend to, for
                example padding. If a whole row is masked, its softmax is
                undefined and yields NaN. ``None`` (the default) means no
                masking, which is the original behaviour.

        Returns:
            (context, attn): context is [..., len_q, d_v]; attn holds the
            attention weights, [..., len_q, len_k], and each row sums to 1.
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        if mask is not None:
            # -inf becomes exactly 0 after the softmax.
            scores = scores.masked_fill(mask, float('-inf'))
        # Functional softmax avoids building a new nn.Softmax module on every call.
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    '''Multi-head self-attention followed by a residual connection and LayerNorm.

    The same input ``x`` is projected to queries, keys and values, so this
    module performs self-attention only. The output is
    ``LayerNorm(x + fc(concat_of_heads))``.

    Args:
        d_embed: size of the last dimension of both the input and the output.
        d_k: per-head size of the query/key/value projections.
        n_heads: number of attention heads.
    '''

    def __init__(self, d_embed, d_k, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        proj_dim = d_k * n_heads  # all heads are produced by one Linear each
        # The layers are created in this order so that parameter initialisation
        # consumes the random-number generator in a fixed sequence.
        self.W_Q = nn.Linear(d_embed, proj_dim)
        self.W_K = nn.Linear(d_embed, proj_dim)
        self.W_V = nn.Linear(d_embed, proj_dim)
        self.fc = nn.Linear(proj_dim, d_embed)  # merges the concatenated heads back to d_embed
        self.layer_norm = nn.LayerNorm(d_embed)
        self.DotProduct = ScaledDotProductAttention(d_k)

    def forward(self, x):
        '''Apply multi-head self-attention to ``x``.

        Args:
            x: expected shape [batch_size, seq_len, d_embed].

        Returns:
            (out, attn): out has the same shape as ``x``; attn holds the
            attention weights, shape [batch_size, n_heads, seq_len, seq_len].
        '''
        n_batch = x.size(0)

        def split_heads(projected):
            # [B, L, H*d_k] -> [B, L, H, d_k] -> [B, H, L, d_k]
            return projected.view(n_batch, -1, self.n_heads, self.d_k).transpose(1, 2)

        queries = split_heads(self.W_Q(x))
        keys = split_heads(self.W_K(x))
        values = split_heads(self.W_V(x))

        context, attn = self.DotProduct(queries, keys, values)  # context = attn @ V

        # [B, H, L, d_k] -> [B, L, H*d_k]. transpose() only swaps strides, so
        # contiguous() is needed before view() can reinterpret the memory.
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.n_heads * self.d_k)

        out = self.layer_norm(self.fc(merged) + x)  # residual connection, then LayerNorm
        return out, attn

class PoswiseFeedForwardNet(nn.Module):
    '''Position-wise feed-forward network with residual connection and LayerNorm.

    The two kernel-size-1 convolutions process every time step independently.
    That is the same as Linear(d_embed, d_hidden) -> ReLU ->
    Linear(d_hidden, d_embed) applied per position.

    Args:
        d_embed: size of the last dimension of the input and output.
        d_hidden: number of hidden channels between the two convolutions.
    '''

    def __init__(self, d_embed, d_hidden):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=d_embed, out_channels=d_hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_hidden, out_channels=d_embed, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_embed)

    def forward(self, x):
        '''Map x of shape [batch_size, seq_len, d_embed] to the same shape.'''
        channels_first = x.transpose(1, 2)  # Conv1d expects [batch, channels, length]
        hidden = F.relu(self.conv1(channels_first))
        projected = self.conv2(hidden).transpose(1, 2)  # back to [batch, length, channels]
        return self.layer_norm(projected + x)  # residual connection, then LayerNorm

class EncoderLayer(nn.Module):
    '''One encoder block: multi-head self-attention, then a position-wise FFN.

    The residual connections and LayerNorms live inside the two sub-modules,
    so this class only chains them.

    Args:
        d_embed: model width (last dimension of input and output).
        d_k: per-head query/key/value size.
        n_heads: number of attention heads.
        d_hidden: hidden channels of the feed-forward sub-layer.
    '''

    def __init__(self, d_embed, d_k, n_heads, d_hidden):
        super().__init__()
        self.MultiHeadAttention = MultiHeadAttention(d_embed, d_k, n_heads)
        self.PoswiseFeedForwardNet = PoswiseFeedForwardNet(d_embed, d_hidden)

    def forward(self, x):
        '''Return ``(output, attn)``.

        ``output`` has the shape of ``x`` ([batch_size, seq_len, d_embed]).
        ``attn`` holds the self-attention weights of this layer.
        '''
        attended, attn = self.MultiHeadAttention(x)  # x serves as Q, K and V
        return self.PoswiseFeedForwardNet(attended), attn

class Encoder(nn.Module):
    '''
    Transformer-encoder classifier for sequential data.

    Pipeline: per-step linear embedding -> ``n_layers`` EncoderLayer blocks
    -> mean over the time axis -> Linear -> log_softmax.

    NOTE: no positional encoding is added, and the time axis is reduced with a
    plain mean. Every stage before the mean treats positions symmetrically,
    so (up to floating-point rounding) the prediction does not change if the
    time steps are permuted.

    Args:
        d_obs: features per time step of the input.
        d_embed: embedding / model width.
        d_class: number of output classes.
        d_k: per-head query/key/value size.
        d_hidden: hidden channels of each feed-forward sub-layer.
        n_heads: attention heads per layer.
        n_layers: number of stacked EncoderLayer blocks.
    '''

    def __init__(self, d_obs, d_embed, d_class, d_k, d_hidden, n_heads, n_layers):
        super().__init__()
        self.embedding = Embedding(inpt_dim=d_obs, embed_dim=d_embed)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_embed, d_k, n_heads, d_hidden) for _ in range(n_layers)]
        )
        self.fc = nn.Linear(d_embed, d_class)

    def forward(self, x):
        '''
        Args:
            x: input sequences, [batch_size, seq_len, d_obs].

        Returns:
            (log_probs, attentions):
                log_probs is [batch_size, d_class].
                attentions is a list with one attention-weight tensor per
                layer, in layer order.
        '''
        hidden = self.embedding(x)
        attention_maps = []
        for block in self.layers:
            hidden, weights = block(hidden)
            attention_maps.append(weights)

        pooled = hidden.mean(dim=1)  # average over time steps -> [batch_size, d_embed]
        log_probs = F.log_softmax(self.fc(pooled), dim=-1)
        return log_probs, attention_maps

if __name__ == '__main__':
    # Smoke test: run a few random batches through an untrained model and
    # print the predicted class index of each trajectory.
    encoder = Encoder(d_obs, d_embed, d_class, d_k, d_hidden, n_heads, n_layers)
    encoder.eval()
    with torch.no_grad():  # inference only, so no autograd graph is needed
        for _ in range(10):
            # nn.Linear parameters are float32 by default. A float64 input
            # would make the first Linear raise a dtype-mismatch RuntimeError,
            # so the batch is created as float32.
            trajectory = torch.rand(batch_size, len_traj, d_obs, dtype=torch.float32)
            pred, _ = encoder(trajectory)  # pred: log-probabilities, [batch_size, d_class]
            print(torch.argmax(pred, -1))

    # Optional: print a layer-by-layer summary (requires the third-party
    # `torchinfo` package):
    #   from torchinfo import summary
    #   summary(encoder, (batch_size, len_traj, d_obs))

你可能感兴趣的：机器学习、Transformer、分类、深度学习、序列数据