transformer使用示例

关于transformer的一些基础知识,之前在看李宏毅视频的时候总结了一些,可以看here,到写此文章时,也基本忘的差不多了,故也不深究,讲两个关于transformer的基本应用,来方便理解与应用。

序列标注

参考文件transformer_postag.py.

1. 加载数据

1
2
#加载数据
train_data, test_data, vocab, pos_vocab = load_treebank()

其中load_treebank代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def load_treebank():
# 需要下载,可以自行设置代码
nltk.set_proxy('http://192.168.0.28:1080')
# 如果没有的话那么则会下载,否则忽略
nltk.download('treebank')
from nltk.corpus import treebank

sents, postags = zip(*(zip(*sent) for sent in treebank.tagged_sents()))

vocab = Vocab.build(sents, reserved_tokens=[""])

tag_vocab = Vocab.build(postags)

train_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags)) for sentence, tags in zip(sents[:3000], postags[:3000])]
test_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags)) for sentence, tags in zip(sents[3000:], postags[3000:])]

return train_data, test_data, vocab, tag_vocab

加载后可以看到,train_datatest_data都是list,其中每一个sample都是tuple,分别是input和target。如下:

1
2
3
4
>>> train_data[0]
>>> Out[1]:
([2, 3, 4, 5, 6, 7, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[1, 1, 2, 3, 4, 5, 2, 6, 7, 8, 9, 10, 8, 5, 9, 1, 3, 11])

2. 数据处理

1
2
3
4
5
6
7
8
9

# 这个函数就是将其变成等长,填充使用,至于是0还是1还是其他值并不重要,因为还有mask~
def collate_fn(examples):
lengths = torch.tensor([len(ex[0]) for ex in examples])
inputs = [torch.tensor(ex[0]) for ex in examples]
targets = [torch.tensor(ex[1]) for ex in examples]
inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab[""])
targets = pad_sequence(targets, batch_first=True, padding_value=vocab[""])
return inputs, lengths, targets, inputs != vocab[""]

3. 模型部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=512):
super(PositionalEncoding, self).__init__()

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)

def forward(self, x):
x = x + self.pe[:x.size(0), :]
return x

class Transformer(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
super(Transformer, self).__init__()
# 词嵌入层
self.embedding_dim = embedding_dim
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
# 编码层:使用Transformer
encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
# 输出层
self.output = nn.Linear(hidden_dim, num_class)

def forward(self, inputs, lengths):
inputs = torch.transpose(inputs, 0, 1)
hidden_states = self.embeddings(inputs)
hidden_states = self.position_embedding(hidden_states)
attention_mask = length_to_mask(lengths) == False
hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
logits = self.output(hidden_states)
log_probs = F.log_softmax(logits, dim=-1)
return log_probs

这里有几点可能需要注意的:

  • PositionalEncoding

因为self attention是没有像rnn位置信息编码的,所以transformer引入了positional encoding,使用绝对位置进行编码,对每一个输入加上position信息,可以看self.pe,这个一个static lookup table。目前也出现一些使用relative positional encoding的,也就是加入相对位置编码,这个在ner任务中挺常见,比如TENER和Flat-Lattice-Transformer。但是最近google证明这种相对位置编码只是引入了更多的信息特征进来。。

扯完上面这个,进入正题,那就是如何计算的。

看forward部分,发现首先进行了torch.transpose操作,然后进行self.position_embedding,这个transpose是否让你感到困惑呢?
如果没有就不用看了。。。

一般输入Embedding的shape是(batch_size, seq_length),然后对每个seq_length那维的token进行编码获取对应的feature。但是这里将其transpose了,变成了(seq_length, batch_size),这种操作是否理解呢?ok,举个例子:

1
2
tensor([[1, 2, 3],
[4, 5, 6]])

这个就是我们通常理解的(batch_size, seq_length),如果将其transpose下就变成了:

1
2
3
tensor([[1, 4],
[2, 5],
[3, 6]])

囔,是不是理解了呢,是对position进行embedding,然后接着看PositionalEncoding是如何forward的。

1
2
3
4
5
6
7
8
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return x

>>> x.shape
Out[2]: torch.Size([70, 32, 128])
>>> self.pe.shape
Out[3]: torch.Size([512, 1, 128])

那么上述例子就是指获取self.pe前70个长度的位置编码信息,然后和x进行相加返回,从而带入了位置编码信息。

  • TransformerEncoder部分
1
2
3
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)

这部分就容易理解了,使用多少nhead和TransformerEncoder的num_layers。

句子极性二分类

这地方基本和之前一样,就是linear n_out=2,然后交叉熵算loss就行。
我稍微改动了下源码,这样理解起来会更方便。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Transformer(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=128, activation: str = "relu"):
super(Transformer, self).__init__()
# 词嵌入层
self.embedding_dim = embedding_dim
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
# 编码层:使用Transformer
encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
# 输出层
self.output = nn.Linear(hidden_dim, num_class)


def forward(self, inputs, lengths):
inputs = torch.transpose(inputs, 0, 1)
hidden_states = self.embeddings(inputs)
hidden_states = self.position_embedding(hidden_states)
lengths = lengths.cpu()
attention_mask = length_to_mask(lengths) == False
attention_mask = attention_mask.cuda()
hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask)

# 在这里,因为把seq_length那一维放前面,觉得有点怪怪的,所以这里transpose一下。
hidden_states = hidden_states.transpose(0, 1)
hidden_states = hidden_states[:, 0, :]
output = self.output(hidden_states)
log_probs = F.log_softmax(output, dim=1)
return log_probs

总结

目前nlp都变成了微调时代,关于transformer网络结构,感兴趣可以点击我上面链接,可以看看从代码层面如果实现encoder和decoder部分。

你可能感兴趣的:(transformer,深度学习,自然语言处理,人工智能,pytorch)