在面试时,经常有面试官问有没有做过 BERT 的预训练,今天特地记录一下。
参考链接:
https://blog.csdn.net/weixin_43476533/article/details/107512497
https://blog.csdn.net/qq_22472047/article/details/115528031
"""
transformers 版本:4.5.1
个人感觉:数据预处理是比较麻烦的一个步骤,因为每个人的数据格式是不同的;而运行代码则大同小异。
本次预训练仅使用了 MLM(掩码语言模型)任务,未使用 NSP(下一句预测)任务(需要额外构造 NSP 标签,比较费劲)。
GitHub 链接:
https://github.com/tyistyler/some_little_program/tree/main/pretrain_new_bert
"""
# coding: utf-8
# Name: do_pretrain
# Author: dell
# Date: 2021/11/8
import os
import torch
import random
import warnings
import numpy as np
from argparse import ArgumentParser
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from transformers import BertTokenizer, BertConfig, BertForMaskedLM
from transformers.trainer_utils import get_last_checkpoint
from transformers import TextDataset
# 设置随机种子
def setup_seed(seed):
    """Seed every random number generator used during training.

    Seeds PyTorch (CPU, and every visible CUDA device when CUDA is available),
    Python's ``random`` module and NumPy, then configures cuDNN to use
    deterministic kernels with the auto-tuner disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)  # CPU generator
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # current GPU
        torch.cuda.manual_seed_all(seed)  # all GPUs, for multi-GPU runs
    random.seed(seed)
    np.random.seed(seed)
    # Deterministic cuDNN kernels, and no benchmark auto-tuning: the tuner may
    # pick different algorithms from run to run, changing numerical results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def _build_arg_parser():
    """Return the command-line parser for the pre-training script."""
    parser = ArgumentParser()
    parser.add_argument("--pretrain_data_path", type=str, default="./pretrain_data/preprocessed_data.txt")
    # NOTE(review): --pretrain_model_path and --data_caches are parsed but never
    # read below; the model is built from --config_path with fresh weights.
    parser.add_argument("--pretrain_model_path", type=str, default="./ckpt/bert-base-chinese")
    parser.add_argument("--data_caches", type=str, default="./caches")
    parser.add_argument("--vocab_path", type=str, default="./pretrain_data/vocab.txt")
    parser.add_argument("--config_path", type=str, default="./pretrain_data/config.json")
    parser.add_argument("--checkpoint_save_path", type=str, default="./ckpt/checkpoint")
    parser.add_argument("--save_path", type=str, default="./ckpt/bert-base-patent")
    parser.add_argument("--num_train_epochs", type=int, default=50)
    parser.add_argument("--max_seq_len", type=int, default=300)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--save_steps", type=int, default=5000)
    parser.add_argument("--logging_steps", type=int, default=500)
    parser.add_argument("--save_total_limit", type=int, default=5)  # keep at most 5 checkpoints
    return parser


def main():
    """Pre-train a BERT masked-language model (MLM only, no NSP).

    Builds a ``BertForMaskedLM`` from ``--config_path`` (no pretrained weights
    are loaded). It tokenizes ``--pretrain_data_path`` into blocks of
    ``--max_seq_len`` tokens with ``TextDataset`` and masks 15% of tokens with
    ``DataCollatorForLanguageModeling``, then trains with ``Trainer``.

    Checkpoints are written to ``--checkpoint_save_path``. If that directory
    already holds a checkpoint, training resumes from the latest one. The final
    model and vocabulary are saved to ``--save_path``.
    """
    parser = _build_arg_parser()
    # Suppress all Python warnings (e.g. library deprecation notices) to keep logs readable.
    warnings.filterwarnings("ignore")
    args = parser.parse_args()
    setup_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create the parent of the final save directory. dirname() returns "" for a
    # bare name such as "my_model", and os.makedirs("") would raise.
    save_parent = os.path.dirname(args.save_path)
    if save_parent:
        os.makedirs(save_parent, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_path, model_max_length=args.max_seq_len)
    bert_config = BertConfig.from_pretrained(args.config_path)
    model = BertForMaskedLM(config=bert_config)
    model = model.to(device)
    # Dynamic masking: 15% of tokens are selected for the MLM objective per batch.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
    training_args = TrainingArguments(
        seed=args.seed,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        output_dir=args.checkpoint_save_path,
        learning_rate=args.learning_rate,
        save_total_limit=args.save_total_limit,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size
    )
    print("=========loading TextDateset=========")
    dataset = TextDataset(tokenizer=tokenizer, block_size=args.max_seq_len, file_path=args.pretrain_data_path)
    print("=========TextDateset loaded =========")
    trainer = Trainer(model, args=training_args, train_dataset=dataset, data_collator=data_collator)

    # get_last_checkpoint lists the directory, so only call it when the
    # directory exists; a missing directory simply means "start fresh".
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None:
        print(f"=========resuming training from {last_checkpoint}=========")
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        print("=========training=========")
        train_result = trainer.train()
    print(train_result)
    trainer.save_model(args.save_path)
    tokenizer.save_vocabulary(args.save_path)
if __name__ == "__main__":
    # Script entry point: run pre-training only when executed directly, not on import.
    main()