[PyTorch] Text Similarity Analysis with a Transformer + Siamese Network

Abstract

Use a Transformer instead of an RNN to encode sentence-level context.


Requirements

  1. PyTorch >= 0.4.1
  2. Python >= 3.5

General Usage Transformer

The Transformer module used for encoding is already written and available on GitHub; feel free to grab it, and don't forget to star the repo~

https://github.com/WenYanger/General-Transformer-Pytorch
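
A minimal usage sketch, based on how the block is called in the dummy model below (the constructor arguments mirror that code; see the repo for the authoritative signature):

import torch
from Transformer import TransformerBlock

# hidden size, attention heads, feed-forward size, dropout; is_cuda=True puts the weights on the GPU
block = TransformerBlock(256, 4, 256 * 4, 0.1, is_cuda=True)

x = torch.rand(16, 10, 256).cuda()  # (batch, timesteps, hidden)
out = block(x, mask=None)           # encoded sequence, same (batch, timesteps, hidden) layout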


Performance

Training speed improves from an average of 59 it/s with the RNN encoder to an average of 77 it/s with the Transformer encoder.

1. RNN

(Figure 1: training speed with the RNN encoder)

2. Transformer

(Figure 2: training speed with the Transformer encoder)


Dummy Siamese

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm
import math,time

from Transformer import TransformerBlock

######################################
###
###  A Dummy Model of Siamese Net
###
######################################

'''
Initialization && HyperParameters
'''
np.random.seed(123)

SAMPLE_SIZE = 10000
EPOCH = 20
BATCH_SIZE = 16
LR = 1e-4
BATCH_NUMS = int(math.ceil(SAMPLE_SIZE/BATCH_SIZE))
RAND_ID = np.random.permutation(list(range(SAMPLE_SIZE)))

RNN_TYPE = 'LSTM'  ## OPTIONAL: LSTM/GRU
RNN_TIMESTEPS = 10
RNN_EMB_SIZE = 300
RNN_HIDDEN_SIZE = 256
TRANSFORMER_HIDDEN_SIZE = RNN_HIDDEN_SIZE
TRANSFORMER_HEADS = 4
TRANSFORMER_DROPOUT = 0.1
RNN_MERGE_MODE = 'CONCATE'

DEVICE = 'cuda'  # not used below; tensors are moved to the GPU with .cuda() directly


'''
Data Preparation
'''
# Dummy inputs: random stand-ins for embedded sentence pairs, shape (SAMPLE_SIZE, RNN_TIMESTEPS, TRANSFORMER_HIDDEN_SIZE).
# Stored as float16 to save memory; torch.Tensor() converts them to float32 below.
x1 = np.random.rand(SAMPLE_SIZE, RNN_TIMESTEPS, TRANSFORMER_HIDDEN_SIZE).astype(np.float16)
x2 = np.random.rand(SAMPLE_SIZE, RNN_TIMESTEPS, TRANSFORMER_HIDDEN_SIZE).astype(np.float16)

# Dummy binary labels (random, since this is only a speed/shape test)
y = np.random.randint(low=0, high=2, size=SAMPLE_SIZE)


'''
Define Network
'''
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Shared Transformer encoder: (hidden size, num heads, feed-forward size, dropout)
        self.transformer = TransformerBlock(TRANSFORMER_HIDDEN_SIZE, TRANSFORMER_HEADS, TRANSFORMER_HIDDEN_SIZE * 4, TRANSFORMER_DROPOUT, is_cuda=True)

        # Interaction heads: element-wise product and difference each pass through Linear + BatchNorm + ReLU
        self.FC_mul = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.FC_minus = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.final = nn.Sequential(
            nn.Linear(128*2, 2),
            nn.LogSoftmax(dim=1)
        )
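        # LogSoftmax here pairs with the NLLLoss used in the training loop;
        # together they are equivalent to CrossEntropyLoss on raw logits.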

    def forward(self, input):
        text_1 = input[0]
        text_2 = input[1]

        # Shared (Siamese) Transformer encoder: the same weights encode both texts
        enc1 = self.transformer(text_1, mask=None)
        enc2 = self.transformer(text_2, mask=None)

        # Approach 1: use the final timestep as the sentence context
        context_1 = enc1[:, -1, :].clone()  # clone() the sliced view so later ops work on a copy
        context_2 = enc2[:, -1, :].clone()
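
        # A hedged alternative (not part of the original model): Approach 2,
        # mean-pool the encoder output over timesteps instead of taking the final state:
        #   context_1 = enc1.mean(dim=1)
        #   context_2 = enc2.mean(dim=1)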


        # Interaction features: element-wise product and difference
        # (out-of-place ops keep context_1 intact between the two lines)
        mul = context_1 * context_2
        minus = context_1 - context_2

        interaction_feature = torch.cat(
            [
                self.FC_mul(mul),
                self.FC_minus(minus),
            ], 1)

        output = self.final(interaction_feature)

        return output


class RNN(nn.Module):
    # Baseline bidirectional RNN encoder used for the speed comparison above;
    # it is not wired into the dummy training run below.
    def __init__(self, rnn_type='LSTM'):
        super(RNN, self).__init__()

        self.rnn_type = rnn_type
        self.rnn_hidden_state = None

        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(
                input_size=RNN_EMB_SIZE,
                hidden_size=RNN_HIDDEN_SIZE,  # rnn hidden units
                num_layers=1,  # number of stacked RNN layers
                batch_first=True,  # input & output are batch-first, e.g. (batch, time_step, input_size)
                bidirectional=True,
            )
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(
                input_size=RNN_EMB_SIZE,
                hidden_size=RNN_HIDDEN_SIZE,  # rnn hidden units
                num_layers=1,  # number of stacked RNN layers
                batch_first=True,  # input & output are batch-first, e.g. (batch, time_step, input_size)
                bidirectional=True,
            )

    def forward(self, input):

        output, self.rnn_hidden_state = self.rnn(input, None)

        return output
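
# A hedged sketch (not in the original) of how this RNN baseline could replace the
# Transformer inside Net: the bidirectional RNN returns (batch, time, 2 * RNN_HIDDEN_SIZE),
# so the interaction heads would need nn.Linear(2 * RNN_HIDDEN_SIZE, 128) instead of
# nn.Linear(256, 128), and the dummy inputs would need feature size RNN_EMB_SIZE.
#
#     self.encoder = RNN(rnn_type=RNN_TYPE)   # in Net.__init__, replacing self.transformer
#     enc1 = self.encoder(text_1)             # in Net.forward (no mask argument)
#     enc2 = self.encoder(text_2)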



def summary(mode='console'):

    params_values  = [RNN_TYPE, RNN_HIDDEN_SIZE, LR]
    params_keys    = ['RNN_TYPE', 'RNN_HIDDEN_SIZE', 'LR']

    params_values = [str(item) for item in params_values]
    max_len = max([len(item) for item in params_keys]) + 1

    def format(key, value, max_len):
        return key + ':' + ' ' * (max_len-len(key)) + value

    if mode == 'console':
        print('#' * 30)
        print('#' * 5)
        print('#' * 5 + '  Model Summary')
        print('#' * 5 + '  ' + '-' * 23)
        for i in range(len(params_keys)):
            print('#' * 5 + '  ' + format(params_keys[i], params_values[i], max_len))
        print('#' * 5)
        print('#' * 30)

    print('-' * 30)

'''
Training Phase
'''
net = Net().cuda()
summary(mode='console')
loss_func = nn.NLLLoss().cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)

for epoch in range(EPOCH):

    time.sleep(0.1)  # brief pause so the epoch print and the tqdm bar do not interleave
    loss_history = 0
    for i in tqdm(range(BATCH_NUMS)):
        ids = RAND_ID[i*BATCH_SIZE : (i+1)*BATCH_SIZE]  # indices of a shuffled mini-batch
        batch_x1 = torch.Tensor(x1[ids]).cuda()
        batch_x2 = torch.Tensor(x2[ids]).cuda()
        batch_y = torch.Tensor(y[ids]).long().cuda()


        output = net([batch_x1, batch_x2])
        loss = loss_func(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_history += loss.item()
    time.sleep(0.1)
    print('EPOCH: {}, Loss: {} \n'.format(epoch, loss_history/BATCH_NUMS))
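
After training, predictions can be read off the log-probabilities. A minimal inference sketch (not part of the original listing), reusing the dummy tensors defined above; argmax over the two classes gives the predicted label:

'''
Inference (sketch)
'''
net.eval()  # use BatchNorm running statistics and disable dropout
with torch.no_grad():
    test_x1 = torch.Tensor(x1[:BATCH_SIZE]).cuda()
    test_x2 = torch.Tensor(x2[:BATCH_SIZE]).cuda()
    log_probs = net([test_x1, test_x2])   # (BATCH_SIZE, 2) log-probabilities
    preds = log_probs.argmax(dim=1)       # predicted class per pair
    probs = log_probs.exp()               # back to probabilities if needed
print(preds.cpu().numpy(), probs.cpu().numpy())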

Reference

"The Annotated Transformer" by Harvard NLP: http://nlp.seas.harvard.edu/2018/04/03/attention.html
