主要分为以下几个步骤:
encode
和decode
函数分别编码与解码,注意参数add_special_tokens
和skip_special_tokens
# -*- encoding: utf-8 -*-
import warnings
warnings.filterwarnings('ignore')
from transformers import BertModel, BertTokenizer, BertConfig
import os
from os.path import dirname, abspath
root_dir = dirname(dirname(dirname(abspath(__file__))))
import torch
# 把预训练的模型从官网下载下来放到目录中
pretrained_path = os.path.join(root_dir, 'pretrained/bert_zh')
# 从文件中加载bert模型
model = BertModel.from_pretrained(pretrained_path)
# 从bert目录中加载词典
tokenizer = BertTokenizer.from_pretrained(pretrained_path)
print(f'vocab size :{tokenizer.vocab_size}')
# 把'[PAD]'编码
print(tokenizer.encode('[PAD]'))
print(tokenizer.encode('[SEP]'))
# 把中文句子编码,默认加入了special tokens了,也就是句子开头加入了[CLS] 句子结尾加入了[SEP]
ids = tokenizer.encode("我是中国人", add_special_tokens=True)
# 从结果中看,101是[CLS]的id,而2769是"我"的id
# [101, 2769, 3221, 704, 1744, 782, 102]
print(ids)
# 把ids解码为中文,默认是没有跳过特殊字符的
print(tokenizer.decode([101, 2769, 3221, 704, 1744, 782, 102], skip_special_tokens=False))
# print(model)
inputs = torch.tensor(ids).unsqueeze(0)
# forward,result是一个tuple,第一个tensor是最后的hidden-state
result = model(torch.tensor(inputs))
# [1, 5, 768]
print(result[0].size())
# [1, 768]
print(result[1].size())
for name, parameter in model.named_parameters():
# 打印每一层,及每一层的参数
print(name)
# 每一层的参数默认都requires_grad=True的,参数是可以学习的
print(parameter.requires_grad)
# 如果只想训练第11层transformer的参数的话:
if '11' in name:
parameter.requires_grad = True
else:
parameter.requires_grad = False
print([p.requires_grad for name, p in model.named_parameters()])
添加atten_mask的方法:
其中101是[CLS],102是[SEP],0是[PAD]
>>> a
tensor([[101, 3, 4, 23, 11, 1, 102, 0, 0, 0]])
>>> notpad = a!=0
>>> notpad
tensor([[ True, True, True, True, True, True, True, False, False, False]])
>>> notcls = a!=101
>>> notcls
tensor([[False, True, True, True, True, True, True, True, True, True]])
>>> notsep = a!=102
>>> notsep
tensor([[ True, True, True, True, True, True, False, True, True, True]])
>>> mask = notpad & notcls & notsep
>>> mask
tensor([[False, True, True, True, True, True, False, False, False, False]])
>>>