Using the data processed in the previous section, which is stored under the match_data directory, we now convert it into the input format required by the BERT model.
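The loader below assumes that each line of the train/valid/test files under match_data holds a source sentence, a target sentence, and a label separated by tabs. A minimal sketch of this format (the sentence texts and the label shown here are made-up placeholders, not real data):

# Hypothetical example of one line in a match_data file (tab-separated)
sample_line = "sentence A text\tsentence B text\t1"
source, target, label = sample_line.strip().split("\t")
print(source, target, label)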
import torch
import os
import pickle as pkl
from tqdm import tqdm
from torch.utils.data import Dataset


class TextMatchDataset(Dataset):
    def __init__(self, config, path):
        self.config = config
        self.path = path
        self.inference = False
        self.max_len = self.config.pad_size
        self.contents = self.load_dataset_match(config)

    def load_dataset_match(self, config):
        # Test files are used for inference, so their labels are kept as raw strings
        if "test" in self.path:
            self.inference = True
        if self.config.token_type:
            pad, cls, sep = '[PAD]', '[CLS]', '[SEP]'
        else:
            pad, cls, sep = '', '', ''
        contents = []
        with open(self.path, 'r', encoding="utf-8") as file_stream:
            for line in tqdm(file_stream.readlines()):
                lin = line.strip()
                if not lin:
                    continue
                if len(lin.split("\t")) != 3:
                    print(line)
                    continue
                source, target, label = lin.split('\t')
                token_id_full = []
                # Truncate overly long sequences before tokenization
                seq_source = config.tokenizer.tokenize(source[:(self.max_len - 2)])
                seq_target = config.tokenizer.tokenize(target[:(self.max_len - 1)])
                # Wrap the two sentences with special tokens: [CLS] source [SEP] target [SEP]
                seq_token = [cls] + seq_source + [sep] + seq_target + [sep]
                # Segment (token type) ids: 0 for the first sentence, 1 for the second
                seq_segment = [0] * (len(seq_source) + 2) + [1] * (len(seq_target) + 1)
                # Convert the tokens to vocabulary ids
                seq_idx = self.config.tokenizer.convert_tokens_to_ids(seq_token)
                # Build the padding sequence from max_len and the length of seq_idx
                padding = [0] * ((self.max_len * 2) - len(seq_idx))
                # Attention mask: 1 for real tokens, 0 for padding
                seq_mask = [1] * len(seq_idx) + padding
                # Pad the token ids
                seq_idx = seq_idx + padding
                # Pad the segment ids
                seq_segment = seq_segment + padding
                assert len(seq_idx) == self.max_len * 2
                assert len(seq_mask) == self.max_len * 2
                assert len(seq_segment) == self.max_len * 2
                token_id_full.append(seq_idx)
                token_id_full.append(seq_mask)
                token_id_full.append(seq_segment)
                if self.inference:
                    token_id_full.append(label)
                else:
                    token_id_full.append(int(label))
                contents.append(token_id_full)
        return contents
    def __getitem__(self, index):
        elements = self.contents[index]
        seq_idx = torch.LongTensor(elements[0])
        seq_mask = torch.LongTensor(elements[1])
        seq_segment = torch.LongTensor(elements[2])
        if not self.inference:
            label = torch.LongTensor([elements[3]])
        else:
            label = [elements[3]]
        return (seq_idx, seq_mask, seq_segment), label

    def __len__(self):
        return len(self.contents)


from param import Param

if __name__ == '__main__':
    param = Param(base_path="./match_data", model_name="SimBERT_A")
    train_data = TextMatchDataset(param, param.dev_path)
    (token, mask, segment), label = train_data[0]
    print(train_data[4300])
    print(len(token))
    print(len(mask))
    print(len(segment))
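For training, the dataset can be fed to a standard torch.utils.data.DataLoader. The snippet below is only a sketch, reusing the param and train_data objects from the __main__ block above; shuffle=True is an assumption, not part of the original script:

# Sketch: batching the dataset for training
from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=param.batch_size, shuffle=True)
for (token_ids, attention_mask, segment_ids), labels in train_loader:
    # Each input tensor has shape [batch_size, pad_size * 2]; labels has shape [batch_size, 1]
    print(token_ids.shape, attention_mask.shape, segment_ids.shape, labels.shape)
    break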
The Param configuration class used above is defined in param.py:

import os.path as osp
# from util import mkdir_if_no_dir
import os
from transformers import BertTokenizer, ElectraTokenizer, AutoTokenizer


def mkdir_if_no_dir(path):
    """Create the directory if it does not already exist (os.mkdir assumes the parent directory exists)."""
    if not os.path.exists(path):
        os.mkdir(path)
class Param:
    def __init__(self, base_path, model_name):
        if "A" in model_name:
            self.train_path = osp.join(base_path, 'train_A.txt')    # training set
            self.dev_path = osp.join(base_path, 'valid_A.txt')      # validation set
            self.test_path = osp.join(base_path, 'test_A.txt')      # test set
            self.result_path = osp.join(base_path, "predict_A.csv")
        else:
            self.train_path = osp.join(base_path, 'train_B.txt')    # training set
            self.dev_path = osp.join(base_path, 'valid_B.txt')      # validation set
            self.test_path = osp.join(base_path, 'test_B.txt')      # test set
            self.result_path = osp.join(base_path, "predict_B.csv")
        print([self.train_path, self.dev_path, self.test_path, self.result_path])
        mkdir_if_no_dir(osp.join(base_path, "saved_dict"))
        mkdir_if_no_dir(osp.join(base_path, "log"))
        self.save_path = osp.join(base_path, 'saved_dict', model_name + '.pt')  # path of the saved model checkpoint
        self.log_path = osp.join(base_path, "log", model_name)                  # path of the training logs
        self.vocab_path = osp.join(base_path, "vocab.pkl")
        self.class_path = osp.join(base_path, "class.txt")
        self.vocab = {}
        self.device = None
        self.token_type = True
        self.model_name = "BERT"
        self.warmup_steps = 1000
        self.t_total = 100000
        self.class_list = {}
        # Read the label list from class.txt and map each label to an index
        with open(self.class_path, "r", encoding="utf-8") as fr:
            idx = 0
            for line in fr:
                line = line.strip("\n")
                self.class_list[line] = idx
                idx += 1
        self.class_list_verse = {v: k for k, v in self.class_list.items()}
        self.num_epochs = 5                  # number of training epochs
        self.batch_size = 32                 # mini-batch size
        self.pad_size = 256                  # length each sentence is padded or truncated to
        self.learning_rate = 1e-5            # learning rate
        self.require_improvement = 10000000  # stop training early if no improvement after this many batches
        self.multi_gpu = True
        self.device_ids = [0, 1]
        self.full_fine_tune = True
        self.use_adamW = True
        self.input_language = "multi"        # ["eng", "original", "multi"]
        self.MAX_VOCAB_SIZE = 20000
        self.min_vocab_freq = 1
        if "BERT" in model_name:
            print("Load BERT Tokenizer")
            self.bert_path = "bert-base-chinese"
            self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        else:
            print("Load BERT Tokenizer")
            self.bert_path = "bert-base-chinese"
            self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
The class.txt file under match_data lists one label per line, here the two match classes:
0
1
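To see the [CLS]/[SEP] layout and segment ids that load_dataset_match builds, the pair encoding can be reproduced by hand with the same tokenizer. A small illustrative sketch (the two example sentences are placeholders):

# Sketch: how a sentence pair is laid out before padding (example sentences are hypothetical)
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
source, target = "今天天气很好", "今天天气不错"
source_tokens = tokenizer.tokenize(source)
target_tokens = tokenizer.tokenize(target)
tokens = ['[CLS]'] + source_tokens + ['[SEP]'] + target_tokens + ['[SEP]']
segment_ids = [0] * (len(source_tokens) + 2) + [1] * (len(target_tokens) + 1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)       # [CLS] 今 天 天 气 很 好 [SEP] 今 天 天 气 不 错 [SEP]
print(segment_ids)  # 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
print(input_ids)    # vocabulary ids, later padded with 0 up to pad_size * 2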