有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传
从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练
在本任务的训练数据中,我选择了金庸的15本小说,全部都是txt文件
数据打开后的样子
数据预处理需要做的事情就是使用huggingface的transformers包的tokenizer模块,将文本转化为token
最后生成的文件就是train_novel.pkl文件,就不用在训练的时候读txt文件了
数据预处理:preprocess.py
import argparse
from utils import set_logger
from transformers import CpmTokenizer
import os
import pickle
from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('--vocab_file', default='vocab/chinese_vocab.model', type=str, required=False,
help='词表路径')
parser.add_argument('--log_path', default='log/preprocess.log', type=str, required=False, help='日志存放位置')
parser.add_argument('--data_path', default='data/novel', type=str, required=False, help='数据集存放位置')
parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False,
help='对训练数据集进行tokenize之后的数据存放位置')
parser.add_argument('--win_size', default=200, type=int, required=False,
help='滑动窗口的大小,相当于每条数据的最大长度')
parser.add_argument('--step', default=200, type=int, required=False, help='滑动窗口的滑动步幅')
args = parser.parse_args()
logger = set_logger(args.log_path)
def set_logger(log_path):
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(filename=log_path)
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
console.setFormatter(formatter)
logger.addHandler(console)
return logger
logger = set_logger(args.log_path)
tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model") # pip install jieba
eod_id = tokenizer.convert_tokens_to_ids("" ) # 文档结束符
sep_id = tokenizer.sep_token_id
train_list = []
logger.info("start tokenizing data")
for file in tqdm(os.listdir(args.data_path)):
file = os.path.join(args.data_path, file)
with open(file, "r", encoding="utf8") as reader:
lines = reader.readlines()
for i in range(len(lines)):
if lines[i].isspace() != True and lines[i] != '\n':
token_ids = tokenizer.encode(lines[i].strip(), add_special_tokens=False) + [eod_id]
if i % 1000 == 0:
print('cur_step', i, lines[i].strip())
else:
continue
win_size = args.win_size
step = args.step
start_index = 0
end_index = win_size
data = token_ids[start_index:end_index]
train_list.append(data)
start_index += step
end_index += step
while end_index + 50 < len(token_ids): # 剩下的数据长度,大于或等于50,才加入训练数据集
data = token_ids[start_index:end_index]
train_list.append(data)
start_index += step
end_index += step
# 序列化训练数据
with open(args.save_path, "wb") as f:
pickle.dump(train_list, f)
os.listdir(args.data_path)
:得到该路径下所有文件的文件名字符串并返回一个字符串数组,for file in tqdm的for循环会打印读取进度的进度条file
路径、utf-8编码格式、只读模式打开文件从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练