一.参考链接
- BERT理论解读:https://blog.csdn.net/ACM_hades/article/details/89375058
- 代码参考连接:https://github.com/graykode/nlp-tutorial
二.代码
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
maxlen = 30
batch_size = 6
max_pred = 5
n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768*4
d_k = d_v = 64
n_segments = 2
text = (
'Hello, how are you? I am Romeo.\n'
'Hello, Romeo My name is Juliet. Nice to meet you.\n'
'Nice meet you too. How are you today?\n'
'Great. My baseball team won the competition.\n'
'Oh Congratulations, Juliet\n'
'Thanks you Romeo'
)
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
word_list = list(set(" ".join(sentences).split()))
word_dict = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list):
word_dict[w] = i + 4
number_dict = {i: w for i, w in enumerate(word_dict)}
vocab_size = len(word_dict)
token_list = list()
for sentence in sentences:
arr = [word_dict[s] for s in sentence.split()]
token_list.append(arr)
def make_batch():
batch = []
positive = negative = 0
while positive != batch_size/2 or negative != batch_size/2:
tokens_a_index, tokens_b_index= randrange(len(sentences)), randrange(len(sentences))
tokens_a, tokens_b= token_list[tokens_a_index], token_list[tokens_b_index]
input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
cand_maked_pos = [i for i, token in enumerate(input_ids)
if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
shuffle(cand_maked_pos)
masked_tokens, masked_pos = [], []
for pos in cand_maked_pos[:n_pred]:
masked_pos.append(pos)
masked_tokens.append(input_ids[pos])
if random() < 0.8:
input_ids[pos] = word_dict['[MASK]']
elif random() < 0.5:
index = randint(0, vocab_size - 1)
input_ids[pos] = word_dict[number_dict[index]]
n_pad = maxlen - len(input_ids)
input_ids.extend([0] * n_pad)
segment_ids.extend([0] * n_pad)
if max_pred > n_pred:
n_pad = max_pred - n_pred
masked_tokens.extend([0] * n_pad)
masked_pos.extend([0] * n_pad)
if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
positive += 1
elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
negative += 1
return batch
def get_attn_pad_mask(seq_q, seq_k):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
return pad_attn_mask.expand(batch_size, len_q, len_k)
def gelu(x):
"Implementation of the gelu activation function by Hugging Face"
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class Embedding(nn.Module):
def __init__(self):
super(Embedding, self).__init__()
self.tok_embed = nn.Embedding(vocab_size, d_model)
self.pos_embed = nn.Embedding(maxlen, d_model)
self.seg_embed = nn.Embedding(n_segments, d_model)
self.norm = nn.LayerNorm(d_model)
def forward(self, x, seg):
seq_len = x.size(1)
pos = torch.arange(seq_len, dtype=torch.long)
pos = pos.unsqueeze(0).expand_as(x)
embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
return self.norm(embedding)
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self, Q, K, V, attn_mask):
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
scores.masked_fill_(attn_mask, -1e9)
attn = nn.Softmax(dim=-1)(scores)
context = torch.matmul(attn, V)
return context, attn
class MultiHeadAttention(nn.Module):
def __init__(self):
super(MultiHeadAttention, self).__init__()
self.W_Q = nn.Linear(d_model, d_k * n_heads)
self.W_K = nn.Linear(d_model, d_k * n_heads)
self.W_V = nn.Linear(d_model, d_v * n_heads)
def forward(self, Q, K, V, attn_mask):
residual, batch_size = Q, Q.size(0)
q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)
k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)
v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
output = nn.Linear(n_heads * d_v, d_model)(context)
return nn.LayerNorm(d_model)(output + residual), attn
class PoswiseFeedForwardNet(nn.Module):
def __init__(self):
super(PoswiseFeedForwardNet, self).__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.fc2(gelu(self.fc1(x)))
class EncoderLayer(nn.Module):
def __init__(self):
super(EncoderLayer, self).__init__()
self.enc_self_attn = MultiHeadAttention()
self.pos_ffn = PoswiseFeedForwardNet()
def forward(self, enc_inputs, enc_self_attn_mask):
enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
enc_outputs = self.pos_ffn(enc_outputs)
return enc_outputs, attn
class BERT(nn.Module):
def __init__(self):
super(BERT, self).__init__()
self.embedding = Embedding()
self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
self.fc = nn.Linear(d_model, d_model)
self.activ1 = nn.Tanh()
self.linear = nn.Linear(d_model, d_model)
self.activ2 = gelu
self.norm = nn.LayerNorm(d_model)
self.classifier = nn.Linear(d_model, 2)
embed_weight = self.embedding.tok_embed.weight
n_vocab, n_dim = embed_weight.size()
self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
self.decoder.weight = embed_weight
self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
def forward(self, input_ids, segment_ids, masked_pos):
output = self.embedding(input_ids, segment_ids)
enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
for layer in self.layers:
output, enc_self_attn = layer(output, enc_self_attn_mask)
h_pooled = self.activ1(self.fc(output[:, 0]))
logits_clsf = self.classifier(h_pooled)
masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
h_masked = torch.gather(output, 1, masked_pos)
h_masked = self.norm(self.activ2(self.linear(h_masked)))
logits_lm = self.decoder(h_masked) + self.decoder_bias
return logits_lm, logits_clsf
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
torch.LongTensor(masked_pos), torch.LongTensor(isNext)
for epoch in range(1):
optimizer.zero_grad()
logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens)
loss_lm = (loss_lm.float()).mean()
loss_clsf = criterion(logits_clsf, isNext)
loss = loss_lm + loss_clsf
if (epoch + 1) % 10 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
loss.backward()
optimizer.step()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[0]
print(text)
print([number_dict[w] for w in input_ids if number_dict[w] != '[PAD]'])
logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])
logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)